diff --git a/open_notebook/domain/base.py b/open_notebook/domain/base.py index 76b1a32..84a57c9 100644 --- a/open_notebook/domain/base.py +++ b/open_notebook/domain/base.py @@ -68,11 +68,9 @@ class ObjectModel(BaseModel): return None def save(self) -> None: - from open_notebook.domain.models import DefaultModels - from open_notebook.models import get_model + from open_notebook.models import model_manager - DEFAULT_MODELS = DefaultModels.load() - EMBEDDING_MODEL = get_model(DEFAULT_MODELS.default_embedding_model) + EMBEDDING_MODEL = model_manager.get_default_model("embedding") try: logger.debug(f"Validating {self.__class__.__name__}") diff --git a/open_notebook/domain/notebook.py b/open_notebook/domain/notebook.py index d64ee08..c15e8e2 100644 --- a/open_notebook/domain/notebook.py +++ b/open_notebook/domain/notebook.py @@ -9,14 +9,13 @@ from open_notebook.database.repository import ( repo_query, ) from open_notebook.domain.base import ObjectModel -from open_notebook.domain.models import DefaultModels from open_notebook.exceptions import ( DatabaseOperationError, InvalidInputError, ) # from temp.recursive_toc import graph as toc_graph -from open_notebook.models import get_model +from open_notebook.models import model_manager from open_notebook.utils import split_text, surreal_clean @@ -140,8 +139,7 @@ class Source(ObjectModel): raise DatabaseOperationError(e) def vectorize(self) -> None: - DEFAULT_MODELS = DefaultModels.load() - EMBEDDING_MODEL = get_model(DEFAULT_MODELS.default_embedding_model) + EMBEDDING_MODEL = model_manager.get_default_model("embedding") try: if not self.full_text: @@ -192,8 +190,7 @@ class Source(ObjectModel): raise DatabaseOperationError("Failed to search sources") def add_insight(self, insight_type: str, content: str) -> Any: - DEFAULT_MODELS = DefaultModels.load() - EMBEDDING_MODEL = get_model(DEFAULT_MODELS.default_embedding_model) + EMBEDDING_MODEL = model_manager.get_default_model("embedding") if not insight_type or not content: raise InvalidInputError("Insight type and content must be provided") diff --git a/open_notebook/graphs/content_processing/audio.py b/open_notebook/graphs/content_processing/audio.py index ad2d7d7..ac81481 100644 --- a/open_notebook/graphs/content_processing/audio.py +++ b/open_notebook/graphs/content_processing/audio.py @@ -4,9 +4,8 @@ from math import ceil from loguru import logger from pydub import AudioSegment -from open_notebook.domain.models import DefaultModels from open_notebook.graphs.content_processing.state import SourceState -from open_notebook.models import get_model +from open_notebook.models import model_manager # future: parallelize the transcription process @@ -73,8 +72,7 @@ def split_audio(input_file, segment_length_minutes=15, output_prefix=None): def extract_audio(data: SourceState): - DEFAULT_MODELS = DefaultModels.load() - SPEECH_TO_TEXT_MODEL = get_model(DEFAULT_MODELS.default_speech_to_text_model) + SPEECH_TO_TEXT_MODEL = model_manager.get_default_model("speech_to_text") input_audio_path = data.get("file_path") audio_files = [] diff --git a/open_notebook/graphs/utils.py b/open_notebook/graphs/utils.py index 0c0e137..4b4a896 100644 --- a/open_notebook/graphs/utils.py +++ b/open_notebook/graphs/utils.py @@ -2,7 +2,7 @@ from langchain.output_parsers import OutputFixingParser from loguru import logger from open_notebook.domain.models import DefaultModels -from open_notebook.models import get_model +from open_notebook.models import model_manager from open_notebook.prompter import Prompter from open_notebook.utils import token_count @@ -33,12 +33,12 @@ def run_pattern( or DEFAULT_MODELS.default_chat_model ) - chain = get_model(model_id) + chain = model_manager.get_default_model("transformation") if parser: chain = chain | parser if output_fixing_model_id and parser: - output_fix_model = get_model(output_fixing_model_id) + output_fix_model = model_manager.get_model(output_fixing_model_id) chain = chain | OutputFixingParser.from_llm( parser=parser, llm=output_fix_model, diff --git a/open_notebook/models/__init__.py b/open_notebook/models/__init__.py index 8cd9067..23623bb 100644 --- a/open_notebook/models/__init__.py +++ b/open_notebook/models/__init__.py @@ -1,4 +1,6 @@ -from open_notebook.domain.models import Model +from typing import Dict, Optional + +from open_notebook.domain.models import DefaultModels, Model from open_notebook.models.embedding_models import ( GeminiEmbeddingModel, OllamaEmbeddingModel, @@ -49,34 +51,137 @@ MODEL_CLASS_MAP = { } -def get_model(model_id, **kwargs): - """ - Get a model instance based on model_id and type. +# def get_model(model_id, **kwargs): +# """ +# Get a model instance based on model_id and type. - Args: - model_id: The ID of the model to retrieve - **kwargs: Additional arguments to pass to the model constructor - """ - assert model_id, "Model ID cannot be empty" - model: Model = Model.get(model_id) +# Args: +# model_id: The ID of the model to retrieve +# **kwargs: Additional arguments to pass to the model constructor +# """ +# assert model_id, "Model ID cannot be empty" +# model: Model = Model.get(model_id) - if not model: - raise ValueError(f"Model with ID {model_id} not found") +# if not model: +# raise ValueError(f"Model with ID {model_id} not found") - if not model.type or model.type not in MODEL_CLASS_MAP: - raise ValueError(f"Invalid model type: {model.type}") +# if not model.type or model.type not in MODEL_CLASS_MAP: +# raise ValueError(f"Invalid model type: {model.type}") - provider_map = MODEL_CLASS_MAP[model.type] - if model.provider not in provider_map: - raise ValueError( - f"Provider {model.provider} not compatible with {model.type} models" - ) +# provider_map = MODEL_CLASS_MAP[model.type] +# if model.provider not in provider_map: +# raise ValueError( +# f"Provider {model.provider} not compatible with {model.type} models" +# ) - model_class = provider_map[model.provider] - model_instance = model_class(model_name=model.name, **kwargs) +# model_class = provider_map[model.provider] +# model_instance = model_class(model_name=model.name, **kwargs) - # Special handling for language models that need langchain conversion - if model.type == "language": - return model_instance.to_langchain() +# # Special handling for language models that need langchain conversion +# if model.type == "language": +# return model_instance.to_langchain() - return model_instance +# return model_instance + + +class ModelManager: + _instance = None + _model_cache: Dict[str, object] = {} + _default_models: Optional[DefaultModels] = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(ModelManager, cls).__new__(cls) + return cls._instance + + def __init__(self): + if not hasattr(self, "_initialized"): + self._initialized = True + self.refresh_defaults() + + def refresh_defaults(self): + """Refresh the default models from the database""" + self._default_models = DefaultModels.load() + + @property + def defaults(self) -> DefaultModels: + """Get the default models configuration""" + if not self._default_models: + self.refresh_defaults() + return self._default_models + + def get_model(self, model_id: str, **kwargs) -> object: + """ + Get a model instance based on model_id. Uses caching to avoid recreating instances. + + Args: + model_id: The ID of the model to retrieve + **kwargs: Additional arguments to pass to the model constructor + """ + cache_key = f"{model_id}:{str(kwargs)}" + + if cache_key in self._model_cache: + return self._model_cache[cache_key] + + assert model_id, "Model ID cannot be empty" + model: Model = Model.get(model_id) + + if not model: + raise ValueError(f"Model with ID {model_id} not found") + + if not model.type or model.type not in MODEL_CLASS_MAP: + raise ValueError(f"Invalid model type: {model.type}") + + provider_map = MODEL_CLASS_MAP[model.type] + if model.provider not in provider_map: + raise ValueError( + f"Provider {model.provider} not compatible with {model.type} models" + ) + + model_class = provider_map[model.provider] + model_instance = model_class(model_name=model.name, **kwargs) + + # Special handling for language models that need langchain conversion + if model.type == "language": + model_instance = model_instance.to_langchain() + + self._model_cache[cache_key] = model_instance + return model_instance + + def get_default_model(self, model_type: str, **kwargs) -> object: + """ + Get the default model for a specific type. + + Args: + model_type: The type of model to retrieve (e.g., 'chat', 'embedding', etc.) + **kwargs: Additional arguments to pass to the model constructor + """ + model_id = None + + if model_type == "chat": + model_id = self.defaults.default_chat_model + elif model_type == "transformation": + model_id = ( + self.defaults.default_transformation_model + or self.defaults.default_chat_model + ) + elif model_type == "embedding": + model_id = self.defaults.default_embedding_model + elif model_type == "text_to_speech": + model_id = self.defaults.default_text_to_speech_model + elif model_type == "speech_to_text": + model_id = self.defaults.default_speech_to_text_model + elif model_type == "large_context": + model_id = self.defaults.large_context_model + + if not model_id: + raise ValueError(f"No default model configured for type: {model_type}") + + return self.get_model(model_id, **kwargs) + + def clear_cache(self): + """Clear the model cache""" + self._model_cache.clear() + + +model_manager = ModelManager() diff --git a/prompts/patterns/default/toc.jinja b/prompts/patterns/default/toc.jinja index c78f159..23b84f0 100644 --- a/prompts/patterns/default/toc.jinja +++ b/prompts/patterns/default/toc.jinja @@ -10,6 +10,6 @@ Analyze the provided content and create a Table of Contents: # INPUT -{{content}} +{{input_text}} # OUTPUT \ No newline at end of file