model manager

This commit is contained in:
LUIS NOVO 2024-11-01 20:37:23 -03:00
parent 3b7dd5f25f
commit a9ac4a6dc8
6 changed files with 141 additions and 43 deletions

View file

@ -68,11 +68,9 @@ class ObjectModel(BaseModel):
return None
def save(self) -> None:
from open_notebook.domain.models import DefaultModels
from open_notebook.models import get_model
from open_notebook.models import model_manager
DEFAULT_MODELS = DefaultModels.load()
EMBEDDING_MODEL = get_model(DEFAULT_MODELS.default_embedding_model)
EMBEDDING_MODEL = model_manager.get_default_model("embedding")
try:
logger.debug(f"Validating {self.__class__.__name__}")

View file

@ -9,14 +9,13 @@ from open_notebook.database.repository import (
repo_query,
)
from open_notebook.domain.base import ObjectModel
from open_notebook.domain.models import DefaultModels
from open_notebook.exceptions import (
DatabaseOperationError,
InvalidInputError,
)
# from temp.recursive_toc import graph as toc_graph
from open_notebook.models import get_model
from open_notebook.models import model_manager
from open_notebook.utils import split_text, surreal_clean
@ -140,8 +139,7 @@ class Source(ObjectModel):
raise DatabaseOperationError(e)
def vectorize(self) -> None:
DEFAULT_MODELS = DefaultModels.load()
EMBEDDING_MODEL = get_model(DEFAULT_MODELS.default_embedding_model)
EMBEDDING_MODEL = model_manager.get_default_model("embedding")
try:
if not self.full_text:
@ -192,8 +190,7 @@ class Source(ObjectModel):
raise DatabaseOperationError("Failed to search sources")
def add_insight(self, insight_type: str, content: str) -> Any:
DEFAULT_MODELS = DefaultModels.load()
EMBEDDING_MODEL = get_model(DEFAULT_MODELS.default_embedding_model)
EMBEDDING_MODEL = model_manager.get_default_model("embedding")
if not insight_type or not content:
raise InvalidInputError("Insight type and content must be provided")

View file

@ -4,9 +4,8 @@ from math import ceil
from loguru import logger
from pydub import AudioSegment
from open_notebook.domain.models import DefaultModels
from open_notebook.graphs.content_processing.state import SourceState
from open_notebook.models import get_model
from open_notebook.models import model_manager
# future: parallelize the transcription process
@ -73,8 +72,7 @@ def split_audio(input_file, segment_length_minutes=15, output_prefix=None):
def extract_audio(data: SourceState):
DEFAULT_MODELS = DefaultModels.load()
SPEECH_TO_TEXT_MODEL = get_model(DEFAULT_MODELS.default_speech_to_text_model)
SPEECH_TO_TEXT_MODEL = model_manager.get_default_model("speech_to_text")
input_audio_path = data.get("file_path")
audio_files = []

View file

@ -2,7 +2,7 @@ from langchain.output_parsers import OutputFixingParser
from loguru import logger
from open_notebook.domain.models import DefaultModels
from open_notebook.models import get_model
from open_notebook.models import model_manager
from open_notebook.prompter import Prompter
from open_notebook.utils import token_count
@ -33,12 +33,12 @@ def run_pattern(
or DEFAULT_MODELS.default_chat_model
)
chain = get_model(model_id)
chain = model_manager.get_default_model("transformation")
if parser:
chain = chain | parser
if output_fixing_model_id and parser:
output_fix_model = get_model(output_fixing_model_id)
output_fix_model = model_manager.get_model(output_fixing_model_id)
chain = chain | OutputFixingParser.from_llm(
parser=parser,
llm=output_fix_model,

View file

@ -1,4 +1,6 @@
from open_notebook.domain.models import Model
from typing import Dict, Optional
from open_notebook.domain.models import DefaultModels, Model
from open_notebook.models.embedding_models import (
GeminiEmbeddingModel,
OllamaEmbeddingModel,
@ -49,34 +51,137 @@ MODEL_CLASS_MAP = {
}
def get_model(model_id, **kwargs):
"""
Get a model instance based on model_id and type.
# def get_model(model_id, **kwargs):
# """
# Get a model instance based on model_id and type.
Args:
model_id: The ID of the model to retrieve
**kwargs: Additional arguments to pass to the model constructor
"""
assert model_id, "Model ID cannot be empty"
model: Model = Model.get(model_id)
# Args:
# model_id: The ID of the model to retrieve
# **kwargs: Additional arguments to pass to the model constructor
# """
# assert model_id, "Model ID cannot be empty"
# model: Model = Model.get(model_id)
if not model:
raise ValueError(f"Model with ID {model_id} not found")
# if not model:
# raise ValueError(f"Model with ID {model_id} not found")
if not model.type or model.type not in MODEL_CLASS_MAP:
raise ValueError(f"Invalid model type: {model.type}")
# if not model.type or model.type not in MODEL_CLASS_MAP:
# raise ValueError(f"Invalid model type: {model.type}")
provider_map = MODEL_CLASS_MAP[model.type]
if model.provider not in provider_map:
raise ValueError(
f"Provider {model.provider} not compatible with {model.type} models"
)
# provider_map = MODEL_CLASS_MAP[model.type]
# if model.provider not in provider_map:
# raise ValueError(
# f"Provider {model.provider} not compatible with {model.type} models"
# )
model_class = provider_map[model.provider]
model_instance = model_class(model_name=model.name, **kwargs)
# model_class = provider_map[model.provider]
# model_instance = model_class(model_name=model.name, **kwargs)
# Special handling for language models that need langchain conversion
if model.type == "language":
return model_instance.to_langchain()
# # Special handling for language models that need langchain conversion
# if model.type == "language":
# return model_instance.to_langchain()
return model_instance
# return model_instance
class ModelManager:
    """Singleton that resolves model IDs to ready-to-use model instances.

    Instances built by ``get_model`` are cached per (model_id, kwargs)
    combination until ``clear_cache`` is called. The default-model
    configuration is loaded once and kept until ``refresh_defaults`` is
    called; refreshing the defaults does not touch the instance cache.
    """

    _instance: Optional["ModelManager"] = None
    # Held on the class: there is only ever one instance, so the cache is
    # effectively process-wide.
    _model_cache: Dict[str, object] = {}
    _default_models: Optional["DefaultModels"] = None

    # model_type -> attribute names on the defaults object, tried in order.
    # The first attribute holding a truthy value wins.
    _DEFAULT_MODEL_ATTRS: Dict[str, tuple] = {
        "chat": ("default_chat_model",),
        "transformation": (
            "default_transformation_model",
            "default_chat_model",
        ),
        "embedding": ("default_embedding_model",),
        "text_to_speech": ("default_text_to_speech_model",),
        "speech_to_text": ("default_speech_to_text_model",),
        "large_context": ("large_context_model",),
    }

    def __new__(cls):
        # Always hand back the same object so every caller shares one cache.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every ModelManager() call, even though __new__
        # returns the existing object; only load the defaults the first time.
        if not hasattr(self, "_initialized"):
            self._initialized = True
            self.refresh_defaults()

    def refresh_defaults(self):
        """Reload the default models configuration via DefaultModels.load()."""
        self._default_models = DefaultModels.load()

    @property
    def defaults(self) -> "DefaultModels":
        """Return the default models configuration, loading it if missing."""
        if not self._default_models:
            self.refresh_defaults()
        return self._default_models

    @staticmethod
    def _cache_key(model_id: str, **kwargs) -> str:
        """Build a cache key that does not depend on keyword-argument order."""
        # Keyword names are unique, so sorting the items never compares values.
        return f"{model_id}:{sorted(kwargs.items())!r}"

    def get_model(self, model_id: str, **kwargs) -> object:
        """
        Get a model instance based on model_id, reusing a cached instance
        when the same model_id and kwargs were requested before.

        Args:
            model_id: The ID of the model to retrieve
            **kwargs: Additional arguments to pass to the model constructor

        Returns:
            The model instance. For models whose type is "language" this is
            the result of calling ``to_langchain()`` on the instance.

        Raises:
            ValueError: If model_id is empty, the model is not found, its
                type is not in MODEL_CLASS_MAP, or its provider is not
                registered for that type.
        """
        # Explicit raise instead of assert: asserts are stripped under -O.
        if not model_id:
            raise ValueError("Model ID cannot be empty")

        cache_key = self._cache_key(model_id, **kwargs)
        if cache_key in self._model_cache:
            return self._model_cache[cache_key]

        model: Model = Model.get(model_id)

        if not model:
            raise ValueError(f"Model with ID {model_id} not found")

        if not model.type or model.type not in MODEL_CLASS_MAP:
            raise ValueError(f"Invalid model type: {model.type}")

        provider_map = MODEL_CLASS_MAP[model.type]
        if model.provider not in provider_map:
            raise ValueError(
                f"Provider {model.provider} not compatible with {model.type} models"
            )

        model_class = provider_map[model.provider]
        model_instance = model_class(model_name=model.name, **kwargs)

        # Special handling for language models that need langchain conversion
        if model.type == "language":
            model_instance = model_instance.to_langchain()

        self._model_cache[cache_key] = model_instance
        return model_instance

    def get_default_model(self, model_type: str, **kwargs) -> object:
        """
        Get the default model for a specific type.

        Args:
            model_type: One of 'chat', 'transformation', 'embedding',
                'text_to_speech', 'speech_to_text' or 'large_context'.
                'transformation' falls back to the default chat model when
                no transformation model is configured.
            **kwargs: Additional arguments to pass to the model constructor

        Raises:
            ValueError: If model_type is unknown or no default model is
                configured for it.
        """
        model_id = None
        attr_names = self._DEFAULT_MODEL_ATTRS.get(model_type, ())
        if attr_names:
            defaults = self.defaults
            for attr_name in attr_names:
                model_id = getattr(defaults, attr_name)
                if model_id:
                    break

        if not model_id:
            raise ValueError(f"No default model configured for type: {model_type}")

        return self.get_model(model_id, **kwargs)

    def clear_cache(self):
        """Clear the model cache"""
        self._model_cache.clear()
# Shared module-level instance that callers import as
# `from open_notebook.models import model_manager`.
# NOTE: constructing it here runs ModelManager.__init__, which calls
# refresh_defaults() and therefore DefaultModels.load() at import time.
model_manager = ModelManager()

View file

@ -10,6 +10,6 @@ Analyze the provided content and create a Table of Contents:
# INPUT
{{content}}
{{input_text}}
# OUTPUT