From feabfaed0171ab008d29dcc6f898e94b29164b8c Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Fri, 1 Nov 2024 19:56:27 -0300 Subject: [PATCH] remove defaultmodel from config file --- open_notebook/config.py | 22 ---------- open_notebook/database/migrate.py | 12 +++++- open_notebook/domain/base.py | 29 +++++++++++-- open_notebook/domain/models.py | 42 ++++++++++--------- open_notebook/domain/notebook.py | 11 +++-- open_notebook/graphs/chat.py | 5 ++- .../graphs/content_processing/audio.py | 6 ++- open_notebook/graphs/multipattern.py | 4 +- open_notebook/graphs/utils.py | 8 ++-- open_notebook/models/__init__.py | 13 +++--- pages/2_📒_Notebooks.py | 4 -- stream_app/chat.py | 1 - stream_app/source.py | 12 ++---- 13 files changed, 89 insertions(+), 80 deletions(-) diff --git a/open_notebook/config.py b/open_notebook/config.py index c07ab03..096850c 100644 --- a/open_notebook/config.py +++ b/open_notebook/config.py @@ -3,9 +3,6 @@ import os import yaml from loguru import logger -from open_notebook.domain.models import DefaultModels -from open_notebook.models import get_model - current_dir = os.path.dirname(os.path.abspath(__file__)) project_root = os.path.dirname(current_dir) config_path = os.path.join(project_root, "open_notebook_config.yaml") @@ -34,22 +31,3 @@ os.makedirs(UPLOADS_FOLDER, exist_ok=True) # PODCASTS FOLDER PODCASTS_FOLDER = f"{DATA_FOLDER}/podcasts" os.makedirs(PODCASTS_FOLDER, exist_ok=True) - - -def load_default_models(): - default_models = DefaultModels.load() - embedding_model = ( - get_model(default_models.default_embedding_model, model_type="embedding") - if default_models.default_embedding_model - else None - ) - - speech_to_text_model = ( - get_model( - default_models.default_speech_to_text_model, model_type="speech_to_text" - ) - if default_models.default_speech_to_text_model - else None - ) - - return default_models, embedding_model, speech_to_text_model diff --git a/open_notebook/database/migrate.py b/open_notebook/database/migrate.py index 7d99fbd..f890091 100644 --- a/open_notebook/database/migrate.py +++ b/open_notebook/database/migrate.py @@ -18,8 +18,16 @@ class MigrationManager: database=os.environ["SURREAL_DATABASE"], encrypted=False, # Set to True if using SSL ) - self.up_migrations = [Migration.from_file("migrations/1.surrealql")] - self.down_migrations = [Migration.from_file("migrations/1_down.surrealql")] + self.up_migrations = [ + Migration.from_file("migrations/1.surrealql"), + Migration.from_file("migrations/2.surrealql"), + ] + self.down_migrations = [ + Migration.from_file( + "migrations/1_down.surrealql", + ), + Migration.from_file("migrations/2_down.surrealql"), + ] self.runner = MigrationRunner( up_migrations=self.up_migrations, down_migrations=self.down_migrations, diff --git a/open_notebook/domain/base.py b/open_notebook/domain/base.py index 2aa648e..76b1a32 100644 --- a/open_notebook/domain/base.py +++ b/open_notebook/domain/base.py @@ -68,9 +68,11 @@ class ObjectModel(BaseModel): return None def save(self) -> None: - from open_notebook.config import load_default_models + from open_notebook.domain.models import DefaultModels + from open_notebook.models import get_model - DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() + DEFAULT_MODELS = DefaultModels.load() + EMBEDDING_MODEL = get_model(DEFAULT_MODELS.default_embedding_model) try: logger.debug(f"Validating {self.__class__.__name__}") @@ -88,7 +90,11 @@ class ObjectModel(BaseModel): logger.debug("Creating new record") repo_result = repo_create(self.__class__.table_name, data) else: - data["created"] = self.created.strftime("%Y-%m-%d %H:%M:%S") + data["created"] = ( + self.created.strftime("%Y-%m-%d %H:%M:%S") + if type(self.created) == datetime + else self.created + ) logger.debug(f"Updating record with id {self.id}") repo_result = repo_update(self.id, data) @@ -148,3 +154,20 @@ class ObjectModel(BaseModel): if isinstance(value, str): return datetime.fromisoformat(value.replace("Z", "+00:00")) return value + + +class RecordModel(BaseModel): + record_id: ClassVar[str] = "open_notebook:default_models" + + @classmethod + def load(cls): + result = repo_query(f"SELECT * FROM {cls.record_id};") + if result: + result = result[0] + dm = cls(**result) + return dm + return cls() + + @classmethod + def update(self, data): + repo_update(self.record_id, data) diff --git a/open_notebook/domain/models.py b/open_notebook/domain/models.py index 5699147..de65c53 100644 --- a/open_notebook/domain/models.py +++ b/open_notebook/domain/models.py @@ -1,12 +1,7 @@ from typing import ClassVar, Optional -from pydantic import BaseModel - -from open_notebook.database.repository import ( - repo_query, - repo_update, -) -from open_notebook.domain.base import ObjectModel +from open_notebook.database.repository import repo_query +from open_notebook.domain.base import ObjectModel, RecordModel class Model(ObjectModel): @@ -23,7 +18,10 @@ class Model(ObjectModel): return [Model(**model) for model in models] -class DefaultModels(BaseModel): +# todo: future: colocar um cache aqui +class DefaultModels(RecordModel): + record_id: ClassVar[str] = "open_notebook:default_models" + default_chat_model: Optional[str] = None default_transformation_model: Optional[str] = None large_context_model: Optional[str] = None @@ -32,15 +30,21 @@ class DefaultModels(BaseModel): # default_vision_model: Optional[str] = None default_embedding_model: Optional[str] = None - @classmethod - def load(self): - result = repo_query("SELECT * FROM open_notebook:default_models;") - if result: - result = result[0] - dm = DefaultModels(**result) - return dm - return DefaultModels() - @classmethod - def update(self, data): - repo_update("open_notebook:default_models", data) +# def load_default_models(): +# default_models = DefaultModels.load() +# embedding_model = ( +# get_model(default_models.default_embedding_model, model_type="embedding") +# if default_models.default_embedding_model +# else None +# ) + +# speech_to_text_model = ( +# get_model( +# default_models.default_speech_to_text_model, model_type="speech_to_text" +# ) +# if default_models.default_speech_to_text_model +# else None +# ) + +# return default_models, embedding_model, speech_to_text_model diff --git a/open_notebook/domain/notebook.py b/open_notebook/domain/notebook.py index 4321231..d64ee08 100644 --- a/open_notebook/domain/notebook.py +++ b/open_notebook/domain/notebook.py @@ -4,18 +4,19 @@ from typing import Any, ClassVar, Dict, List, Literal, Optional from loguru import logger from pydantic import BaseModel, Field, field_validator -from open_notebook.config import load_default_models from open_notebook.database.repository import ( repo_create, repo_query, ) from open_notebook.domain.base import ObjectModel +from open_notebook.domain.models import DefaultModels from open_notebook.exceptions import ( DatabaseOperationError, InvalidInputError, ) # from temp.recursive_toc import graph as toc_graph +from open_notebook.models import get_model from open_notebook.utils import split_text, surreal_clean @@ -139,7 +140,8 @@ class Source(ObjectModel): raise DatabaseOperationError(e) def vectorize(self) -> None: - DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() + DEFAULT_MODELS = DefaultModels.load() + EMBEDDING_MODEL = get_model(DEFAULT_MODELS.default_embedding_model) try: if not self.full_text: @@ -190,7 +192,8 @@ class Source(ObjectModel): raise DatabaseOperationError("Failed to search sources") def add_insight(self, insight_type: str, content: str) -> Any: - DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() + DEFAULT_MODELS = DefaultModels.load() + EMBEDDING_MODEL = get_model(DEFAULT_MODELS.default_embedding_model) if not insight_type or not content: raise InvalidInputError("Insight type and content must be provided") @@ -214,7 +217,7 @@ class Source(ObjectModel): class Note(ObjectModel): table_name: ClassVar[str] = "note" title: Optional[str] = None - note_type: Optional[Literal["human", "ai"]] = "human" + note_type: Optional[Literal["human", "ai"]] = None content: Optional[str] = None @field_validator("content") diff --git a/open_notebook/graphs/chat.py b/open_notebook/graphs/chat.py index 7735cd8..5d75151 100644 --- a/open_notebook/graphs/chat.py +++ b/open_notebook/graphs/chat.py @@ -9,11 +9,12 @@ from langgraph.graph import END, START, StateGraph from langgraph.graph.message import add_messages from typing_extensions import TypedDict -from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE, load_default_models +from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE +from open_notebook.domain.models import DefaultModels from open_notebook.domain.notebook import Notebook from open_notebook.graphs.utils import run_pattern -DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() +DEFAULT_MODELS = DefaultModels.load() class ThreadState(TypedDict): diff --git a/open_notebook/graphs/content_processing/audio.py b/open_notebook/graphs/content_processing/audio.py index b7c31be..ad2d7d7 100644 --- a/open_notebook/graphs/content_processing/audio.py +++ b/open_notebook/graphs/content_processing/audio.py @@ -4,8 +4,9 @@ from math import ceil from loguru import logger from pydub import AudioSegment -from open_notebook.config import load_default_models +from open_notebook.domain.models import DefaultModels from open_notebook.graphs.content_processing.state import SourceState +from open_notebook.models import get_model # future: parallelize the transcription process @@ -72,7 +73,8 @@ def split_audio(input_file, segment_length_minutes=15, output_prefix=None): def extract_audio(data: SourceState): - DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() + DEFAULT_MODELS = DefaultModels.load() + SPEECH_TO_TEXT_MODEL = get_model(DEFAULT_MODELS.default_speech_to_text_model) input_audio_path = data.get("file_path") audio_files = [] diff --git a/open_notebook/graphs/multipattern.py b/open_notebook/graphs/multipattern.py index 75d499a..17febca 100644 --- a/open_notebook/graphs/multipattern.py +++ b/open_notebook/graphs/multipattern.py @@ -7,10 +7,10 @@ from langchain_core.runnables import ( from langgraph.graph import END, START, StateGraph from typing_extensions import Annotated, TypedDict -from open_notebook.config import load_default_models +from open_notebook.domain.models import DefaultModels from open_notebook.graphs.utils import run_pattern -DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() +DEFAULT_MODELS = DefaultModels.load() class PatternChainState(TypedDict): diff --git a/open_notebook/graphs/utils.py b/open_notebook/graphs/utils.py index 39871ee..0c0e137 100644 --- a/open_notebook/graphs/utils.py +++ b/open_notebook/graphs/utils.py @@ -1,7 +1,7 @@ from langchain.output_parsers import OutputFixingParser from loguru import logger -from open_notebook.config import load_default_models +from open_notebook.domain.models import DefaultModels from open_notebook.models import get_model from open_notebook.prompter import Prompter from open_notebook.utils import token_count @@ -18,7 +18,7 @@ def run_pattern( system_prompt = Prompter(prompt_template=pattern_name, parser=parser).render( data=state ) - DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() + DEFAULT_MODELS = DefaultModels.load() tokens = token_count(str(system_prompt) + str(messages)) if tokens > 105_000: @@ -33,12 +33,12 @@ def run_pattern( or DEFAULT_MODELS.default_chat_model ) - chain = get_model(model_id, model_type="language") + chain = get_model(model_id) if parser: chain = chain | parser if output_fixing_model_id and parser: - output_fix_model = get_model(output_fixing_model_id, model_type="language") + output_fix_model = get_model(output_fixing_model_id) chain = chain | OutputFixingParser.from_llm( parser=parser, llm=output_fix_model, diff --git a/open_notebook/models/__init__.py b/open_notebook/models/__init__.py index 9473c9e..8cd9067 100644 --- a/open_notebook/models/__init__.py +++ b/open_notebook/models/__init__.py @@ -49,13 +49,12 @@ MODEL_CLASS_MAP = { } -def get_model(model_id, model_type="language", **kwargs): +def get_model(model_id, **kwargs): """ Get a model instance based on model_id and type. Args: model_id: The ID of the model to retrieve - model_type: Type of model ('language', 'embedding', or 'speech_to_text') **kwargs: Additional arguments to pass to the model constructor """ assert model_id, "Model ID cannot be empty" @@ -64,20 +63,20 @@ def get_model(model_id, model_type="language", **kwargs): if not model: raise ValueError(f"Model with ID {model_id} not found") - if model_type not in MODEL_CLASS_MAP: - raise ValueError(f"Invalid model type: {model_type}") + if not model.type or model.type not in MODEL_CLASS_MAP: + raise ValueError(f"Invalid model type: {model.type}") - provider_map = MODEL_CLASS_MAP[model_type] + provider_map = MODEL_CLASS_MAP[model.type] if model.provider not in provider_map: raise ValueError( - f"Provider {model.provider} not compatible with {model_type} models" + f"Provider {model.provider} not compatible with {model.type} models" ) model_class = provider_map[model.provider] model_instance = model_class(model_name=model.name, **kwargs) # Special handling for language models that need langchain conversion - if model_type == "language": + if model.type == "language": return model_instance.to_langchain() return model_instance diff --git a/pages/2_📒_Notebooks.py b/pages/2_📒_Notebooks.py index b0a6c07..e6302a3 100644 --- a/pages/2_📒_Notebooks.py +++ b/pages/2_📒_Notebooks.py @@ -1,7 +1,6 @@ import streamlit as st from humanize import naturaltime -from open_notebook.config import load_default_models from open_notebook.domain.notebook import Notebook from stream_app.chat import chat_sidebar from stream_app.note import add_note, note_card @@ -71,9 +70,6 @@ def notebook_page(current_notebook_id): sources = current_notebook.sources notes = current_notebook.notes - # Load the default models dynamically - DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() - notebook_header(current_notebook) work_tab, chat_tab = st.columns([4, 2]) diff --git a/stream_app/chat.py b/stream_app/chat.py index c883780..00d8bd5 100644 --- a/stream_app/chat.py +++ b/stream_app/chat.py @@ -111,6 +111,5 @@ def chat_sidebar(session_id): make_note_from_chat( content=msg.content, notebook_id=st.session_state[session_id]["notebook"].id, - type="ai", ) st.rerun() diff --git a/stream_app/source.py b/stream_app/source.py index 64fd54f..25f880b 100644 --- a/stream_app/source.py +++ b/stream_app/source.py @@ -7,7 +7,7 @@ import yaml from humanize import naturaltime from loguru import logger -from open_notebook.config import UPLOADS_FOLDER, load_default_models +from open_notebook.config import UPLOADS_FOLDER from open_notebook.domain.notebook import Asset, Source from open_notebook.exceptions import UnsupportedTypeException from open_notebook.graphs.content_processing import graph @@ -16,8 +16,6 @@ from open_notebook.utils import surreal_clean from .consts import context_icons -DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() - def run_patterns(input_text, patterns): output = transform_graph.invoke(dict(content_stack=[input_text], patterns=patterns)) @@ -26,18 +24,16 @@ def run_patterns(input_text, patterns): # moved it here to replace it with the pipeline on 0.1.0 def generate_toc_and_title(source) -> "Source": - DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() - try: patterns = ["patterns/default/toc"] result = run_patterns(source.full_text, patterns=patterns) source.add_insight("Table of Contents", surreal_clean(result)) if not source.title: - transformations = [ + patterns = [ "Based on the Table of Contents below, please provide a Title for this content, with max 15 words" ] - output = run_patterns(result["toc"], transformations=transformations) - source.title = surreal_clean(output["output"]) + output = run_patterns(result, patterns=patterns) + source.title = surreal_clean(output) source.save() return source except Exception as e: