remove defaultmodel from config file
This commit is contained in:
parent
a525fba1d2
commit
feabfaed01
13 changed files with 89 additions and 80 deletions
|
|
@ -3,9 +3,6 @@ import os
|
|||
import yaml
|
||||
from loguru import logger
|
||||
|
||||
from open_notebook.domain.models import DefaultModels
|
||||
from open_notebook.models import get_model
|
||||
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.dirname(current_dir)
|
||||
config_path = os.path.join(project_root, "open_notebook_config.yaml")
|
||||
|
|
@ -34,22 +31,3 @@ os.makedirs(UPLOADS_FOLDER, exist_ok=True)
|
|||
# PODCASTS FOLDER
|
||||
PODCASTS_FOLDER = f"{DATA_FOLDER}/podcasts"
|
||||
os.makedirs(PODCASTS_FOLDER, exist_ok=True)
|
||||
|
||||
|
||||
def load_default_models():
    """Load the stored default models and instantiate two of them.

    Returns:
        A 3-tuple ``(default_models, embedding_model, speech_to_text_model)``.
        The first item is the result of ``DefaultModels.load()``. Each of the
        other two is the ``get_model`` result for the matching configured id,
        or ``None`` when that id is falsy.
    """
    defaults = DefaultModels.load()

    # Only instantiate a model when an id has actually been configured.
    embedder = None
    if defaults.default_embedding_model:
        embedder = get_model(
            defaults.default_embedding_model, model_type="embedding"
        )

    transcriber = None
    if defaults.default_speech_to_text_model:
        transcriber = get_model(
            defaults.default_speech_to_text_model, model_type="speech_to_text"
        )

    return defaults, embedder, transcriber
|
||||
|
|
|
|||
|
|
@ -18,8 +18,16 @@ class MigrationManager:
|
|||
database=os.environ["SURREAL_DATABASE"],
|
||||
encrypted=False, # Set to True if using SSL
|
||||
)
|
||||
self.up_migrations = [Migration.from_file("migrations/1.surrealql")]
|
||||
self.down_migrations = [Migration.from_file("migrations/1_down.surrealql")]
|
||||
self.up_migrations = [
|
||||
Migration.from_file("migrations/1.surrealql"),
|
||||
Migration.from_file("migrations/2.surrealql"),
|
||||
]
|
||||
self.down_migrations = [
|
||||
Migration.from_file(
|
||||
"migrations/1_down.surrealql",
|
||||
),
|
||||
Migration.from_file("migrations/2_down.surrealql"),
|
||||
]
|
||||
self.runner = MigrationRunner(
|
||||
up_migrations=self.up_migrations,
|
||||
down_migrations=self.down_migrations,
|
||||
|
|
|
|||
|
|
@ -68,9 +68,11 @@ class ObjectModel(BaseModel):
|
|||
return None
|
||||
|
||||
def save(self) -> None:
|
||||
from open_notebook.config import load_default_models
|
||||
from open_notebook.domain.models import DefaultModels
|
||||
from open_notebook.models import get_model
|
||||
|
||||
DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
|
||||
DEFAULT_MODELS = DefaultModels.load()
|
||||
EMBEDDING_MODEL = get_model(DEFAULT_MODELS.default_embedding_model)
|
||||
|
||||
try:
|
||||
logger.debug(f"Validating {self.__class__.__name__}")
|
||||
|
|
@ -88,7 +90,11 @@ class ObjectModel(BaseModel):
|
|||
logger.debug("Creating new record")
|
||||
repo_result = repo_create(self.__class__.table_name, data)
|
||||
else:
|
||||
data["created"] = self.created.strftime("%Y-%m-%d %H:%M:%S")
|
||||
data["created"] = (
|
||||
self.created.strftime("%Y-%m-%d %H:%M:%S")
|
||||
if type(self.created) == datetime
|
||||
else self.created
|
||||
)
|
||||
logger.debug(f"Updating record with id {self.id}")
|
||||
repo_result = repo_update(self.id, data)
|
||||
|
||||
|
|
@ -148,3 +154,20 @@ class ObjectModel(BaseModel):
|
|||
if isinstance(value, str):
|
||||
return datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||
return value
|
||||
|
||||
|
||||
class RecordModel(BaseModel):
    """Base class for models persisted as a single, fixed database record.

    The record is addressed by the class-level ``record_id``, which is
    interpolated into the ``SELECT`` issued by :meth:`load` and passed to
    ``repo_update`` by :meth:`update`.
    """

    # Identifier of the backing record; read by both load() and update().
    record_id: ClassVar[str] = "open_notebook:default_models"

    @classmethod
    def load(cls):
        """Build an instance of ``cls`` from the stored record.

        Returns:
            ``cls(**row)`` for the first row returned by the query, or a
            default-constructed ``cls()`` when the query result is falsy.
        """
        rows = repo_query(f"SELECT * FROM {cls.record_id};")
        if not rows:
            return cls()
        return cls(**rows[0])

    @classmethod
    def update(cls, data):
        """Forward ``data`` to ``repo_update`` for this class's record.

        Args:
            data: Payload passed unchanged to ``repo_update``.
        """
        # A classmethod receives the class, so the first parameter is ``cls``.
        repo_update(cls.record_id, data)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,7 @@
|
|||
from typing import ClassVar, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from open_notebook.database.repository import (
|
||||
repo_query,
|
||||
repo_update,
|
||||
)
|
||||
from open_notebook.domain.base import ObjectModel
|
||||
from open_notebook.database.repository import repo_query
|
||||
from open_notebook.domain.base import ObjectModel, RecordModel
|
||||
|
||||
|
||||
class Model(ObjectModel):
|
||||
|
|
@ -23,7 +18,10 @@ class Model(ObjectModel):
|
|||
return [Model(**model) for model in models]
|
||||
|
||||
|
||||
class DefaultModels(BaseModel):
|
||||
# TODO(future): add a cache here
|
||||
class DefaultModels(RecordModel):
|
||||
record_id: ClassVar[str] = "open_notebook:default_models"
|
||||
|
||||
default_chat_model: Optional[str] = None
|
||||
default_transformation_model: Optional[str] = None
|
||||
large_context_model: Optional[str] = None
|
||||
|
|
@ -32,15 +30,21 @@ class DefaultModels(BaseModel):
|
|||
# default_vision_model: Optional[str] = None
|
||||
default_embedding_model: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def load(self):
|
||||
result = repo_query("SELECT * FROM open_notebook:default_models;")
|
||||
if result:
|
||||
result = result[0]
|
||||
dm = DefaultModels(**result)
|
||||
return dm
|
||||
return DefaultModels()
|
||||
|
||||
@classmethod
|
||||
def update(self, data):
|
||||
repo_update("open_notebook:default_models", data)
|
||||
# def load_default_models():
|
||||
# default_models = DefaultModels.load()
|
||||
# embedding_model = (
|
||||
# get_model(default_models.default_embedding_model, model_type="embedding")
|
||||
# if default_models.default_embedding_model
|
||||
# else None
|
||||
# )
|
||||
|
||||
# speech_to_text_model = (
|
||||
# get_model(
|
||||
# default_models.default_speech_to_text_model, model_type="speech_to_text"
|
||||
# )
|
||||
# if default_models.default_speech_to_text_model
|
||||
# else None
|
||||
# )
|
||||
|
||||
# return default_models, embedding_model, speech_to_text_model
|
||||
|
|
|
|||
|
|
@ -4,18 +4,19 @@ from typing import Any, ClassVar, Dict, List, Literal, Optional
|
|||
from loguru import logger
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from open_notebook.config import load_default_models
|
||||
from open_notebook.database.repository import (
|
||||
repo_create,
|
||||
repo_query,
|
||||
)
|
||||
from open_notebook.domain.base import ObjectModel
|
||||
from open_notebook.domain.models import DefaultModels
|
||||
from open_notebook.exceptions import (
|
||||
DatabaseOperationError,
|
||||
InvalidInputError,
|
||||
)
|
||||
|
||||
# from temp.recursive_toc import graph as toc_graph
|
||||
from open_notebook.models import get_model
|
||||
from open_notebook.utils import split_text, surreal_clean
|
||||
|
||||
|
||||
|
|
@ -139,7 +140,8 @@ class Source(ObjectModel):
|
|||
raise DatabaseOperationError(e)
|
||||
|
||||
def vectorize(self) -> None:
|
||||
DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
|
||||
DEFAULT_MODELS = DefaultModels.load()
|
||||
EMBEDDING_MODEL = get_model(DEFAULT_MODELS.default_embedding_model)
|
||||
|
||||
try:
|
||||
if not self.full_text:
|
||||
|
|
@ -190,7 +192,8 @@ class Source(ObjectModel):
|
|||
raise DatabaseOperationError("Failed to search sources")
|
||||
|
||||
def add_insight(self, insight_type: str, content: str) -> Any:
|
||||
DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
|
||||
DEFAULT_MODELS = DefaultModels.load()
|
||||
EMBEDDING_MODEL = get_model(DEFAULT_MODELS.default_embedding_model)
|
||||
|
||||
if not insight_type or not content:
|
||||
raise InvalidInputError("Insight type and content must be provided")
|
||||
|
|
@ -214,7 +217,7 @@ class Source(ObjectModel):
|
|||
class Note(ObjectModel):
|
||||
table_name: ClassVar[str] = "note"
|
||||
title: Optional[str] = None
|
||||
note_type: Optional[Literal["human", "ai"]] = "human"
|
||||
note_type: Optional[Literal["human", "ai"]] = None
|
||||
content: Optional[str] = None
|
||||
|
||||
@field_validator("content")
|
||||
|
|
|
|||
|
|
@ -9,11 +9,12 @@ from langgraph.graph import END, START, StateGraph
|
|||
from langgraph.graph.message import add_messages
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE, load_default_models
|
||||
from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE
|
||||
from open_notebook.domain.models import DefaultModels
|
||||
from open_notebook.domain.notebook import Notebook
|
||||
from open_notebook.graphs.utils import run_pattern
|
||||
|
||||
DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
|
||||
DEFAULT_MODELS = DefaultModels.load()
|
||||
|
||||
|
||||
class ThreadState(TypedDict):
|
||||
|
|
|
|||
|
|
@ -4,8 +4,9 @@ from math import ceil
|
|||
from loguru import logger
|
||||
from pydub import AudioSegment
|
||||
|
||||
from open_notebook.config import load_default_models
|
||||
from open_notebook.domain.models import DefaultModels
|
||||
from open_notebook.graphs.content_processing.state import SourceState
|
||||
from open_notebook.models import get_model
|
||||
|
||||
# future: parallelize the transcription process
|
||||
|
||||
|
|
@ -72,7 +73,8 @@ def split_audio(input_file, segment_length_minutes=15, output_prefix=None):
|
|||
|
||||
|
||||
def extract_audio(data: SourceState):
|
||||
DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
|
||||
DEFAULT_MODELS = DefaultModels.load()
|
||||
SPEECH_TO_TEXT_MODEL = get_model(DEFAULT_MODELS.default_speech_to_text_model)
|
||||
|
||||
input_audio_path = data.get("file_path")
|
||||
audio_files = []
|
||||
|
|
|
|||
|
|
@ -7,10 +7,10 @@ from langchain_core.runnables import (
|
|||
from langgraph.graph import END, START, StateGraph
|
||||
from typing_extensions import Annotated, TypedDict
|
||||
|
||||
from open_notebook.config import load_default_models
|
||||
from open_notebook.domain.models import DefaultModels
|
||||
from open_notebook.graphs.utils import run_pattern
|
||||
|
||||
DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
|
||||
DEFAULT_MODELS = DefaultModels.load()
|
||||
|
||||
|
||||
class PatternChainState(TypedDict):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from langchain.output_parsers import OutputFixingParser
|
||||
from loguru import logger
|
||||
|
||||
from open_notebook.config import load_default_models
|
||||
from open_notebook.domain.models import DefaultModels
|
||||
from open_notebook.models import get_model
|
||||
from open_notebook.prompter import Prompter
|
||||
from open_notebook.utils import token_count
|
||||
|
|
@ -18,7 +18,7 @@ def run_pattern(
|
|||
system_prompt = Prompter(prompt_template=pattern_name, parser=parser).render(
|
||||
data=state
|
||||
)
|
||||
DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
|
||||
DEFAULT_MODELS = DefaultModels.load()
|
||||
tokens = token_count(str(system_prompt) + str(messages))
|
||||
|
||||
if tokens > 105_000:
|
||||
|
|
@ -33,12 +33,12 @@ def run_pattern(
|
|||
or DEFAULT_MODELS.default_chat_model
|
||||
)
|
||||
|
||||
chain = get_model(model_id, model_type="language")
|
||||
chain = get_model(model_id)
|
||||
if parser:
|
||||
chain = chain | parser
|
||||
|
||||
if output_fixing_model_id and parser:
|
||||
output_fix_model = get_model(output_fixing_model_id, model_type="language")
|
||||
output_fix_model = get_model(output_fixing_model_id)
|
||||
chain = chain | OutputFixingParser.from_llm(
|
||||
parser=parser,
|
||||
llm=output_fix_model,
|
||||
|
|
|
|||
|
|
@ -49,13 +49,12 @@ MODEL_CLASS_MAP = {
|
|||
}
|
||||
|
||||
|
||||
def get_model(model_id, model_type="language", **kwargs):
|
||||
def get_model(model_id, **kwargs):
|
||||
"""
|
||||
Get a model instance based on model_id and type.
|
||||
|
||||
Args:
|
||||
model_id: The ID of the model to retrieve
|
||||
model_type: Type of model ('language', 'embedding', or 'speech_to_text')
|
||||
**kwargs: Additional arguments to pass to the model constructor
|
||||
"""
|
||||
assert model_id, "Model ID cannot be empty"
|
||||
|
|
@ -64,20 +63,20 @@ def get_model(model_id, model_type="language", **kwargs):
|
|||
if not model:
|
||||
raise ValueError(f"Model with ID {model_id} not found")
|
||||
|
||||
if model_type not in MODEL_CLASS_MAP:
|
||||
raise ValueError(f"Invalid model type: {model_type}")
|
||||
if not model.type or model.type not in MODEL_CLASS_MAP:
|
||||
raise ValueError(f"Invalid model type: {model.type}")
|
||||
|
||||
provider_map = MODEL_CLASS_MAP[model_type]
|
||||
provider_map = MODEL_CLASS_MAP[model.type]
|
||||
if model.provider not in provider_map:
|
||||
raise ValueError(
|
||||
f"Provider {model.provider} not compatible with {model_type} models"
|
||||
f"Provider {model.provider} not compatible with {model.type} models"
|
||||
)
|
||||
|
||||
model_class = provider_map[model.provider]
|
||||
model_instance = model_class(model_name=model.name, **kwargs)
|
||||
|
||||
# Special handling for language models that need langchain conversion
|
||||
if model_type == "language":
|
||||
if model.type == "language":
|
||||
return model_instance.to_langchain()
|
||||
|
||||
return model_instance
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import streamlit as st
|
||||
from humanize import naturaltime
|
||||
|
||||
from open_notebook.config import load_default_models
|
||||
from open_notebook.domain.notebook import Notebook
|
||||
from stream_app.chat import chat_sidebar
|
||||
from stream_app.note import add_note, note_card
|
||||
|
|
@ -71,9 +70,6 @@ def notebook_page(current_notebook_id):
|
|||
sources = current_notebook.sources
|
||||
notes = current_notebook.notes
|
||||
|
||||
# Load the default models dynamically
|
||||
DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
|
||||
|
||||
notebook_header(current_notebook)
|
||||
|
||||
work_tab, chat_tab = st.columns([4, 2])
|
||||
|
|
|
|||
|
|
@ -111,6 +111,5 @@ def chat_sidebar(session_id):
|
|||
make_note_from_chat(
|
||||
content=msg.content,
|
||||
notebook_id=st.session_state[session_id]["notebook"].id,
|
||||
type="ai",
|
||||
)
|
||||
st.rerun()
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import yaml
|
|||
from humanize import naturaltime
|
||||
from loguru import logger
|
||||
|
||||
from open_notebook.config import UPLOADS_FOLDER, load_default_models
|
||||
from open_notebook.config import UPLOADS_FOLDER
|
||||
from open_notebook.domain.notebook import Asset, Source
|
||||
from open_notebook.exceptions import UnsupportedTypeException
|
||||
from open_notebook.graphs.content_processing import graph
|
||||
|
|
@ -16,8 +16,6 @@ from open_notebook.utils import surreal_clean
|
|||
|
||||
from .consts import context_icons
|
||||
|
||||
DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
|
||||
|
||||
|
||||
def run_patterns(input_text, patterns):
|
||||
output = transform_graph.invoke(dict(content_stack=[input_text], patterns=patterns))
|
||||
|
|
@ -26,18 +24,16 @@ def run_patterns(input_text, patterns):
|
|||
|
||||
# moved it here to replace it with the pipeline on 0.1.0
|
||||
def generate_toc_and_title(source) -> "Source":
|
||||
DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
|
||||
|
||||
try:
|
||||
patterns = ["patterns/default/toc"]
|
||||
result = run_patterns(source.full_text, patterns=patterns)
|
||||
source.add_insight("Table of Contents", surreal_clean(result))
|
||||
if not source.title:
|
||||
transformations = [
|
||||
patterns = [
|
||||
"Based on the Table of Contents below, please provide a Title for this content, with max 15 words"
|
||||
]
|
||||
output = run_patterns(result["toc"], transformations=transformations)
|
||||
source.title = surreal_clean(output["output"])
|
||||
output = run_patterns(result, patterns=patterns)
|
||||
source.title = surreal_clean(output)
|
||||
source.save()
|
||||
return source
|
||||
except Exception as e:
|
||||
|
|
|
|||
Loading…
Reference in a new issue