From 859b7f6e7ec47e54003341fd5b282dfd282c86f8 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Wed, 30 Oct 2024 14:30:29 -0300 Subject: [PATCH] simplify the model selector --- open_notebook/config.py | 13 +- open_notebook/graphs/utils.py | 6 +- open_notebook/models/__init__.py | 141 ++++++++++++++++++ open_notebook/models/embedding_models.py | 19 --- open_notebook/models/llms.py | 27 ---- open_notebook/models/speech_to_text_models.py | 20 --- 6 files changed, 151 insertions(+), 75 deletions(-) diff --git a/open_notebook/config.py b/open_notebook/config.py index 38a6972..9b15710 100644 --- a/open_notebook/config.py +++ b/open_notebook/config.py @@ -4,8 +4,7 @@ import yaml from loguru import logger from open_notebook.domain.models import DefaultModels -from open_notebook.models.embedding_models import get_embedding_model -from open_notebook.models.speech_to_text_models import get_speech_to_text_model +from open_notebook.models import get_model # todo: enable config file overwrite with env vars current_dir = os.path.dirname(os.path.abspath(__file__)) @@ -40,8 +39,10 @@ os.makedirs(PODCASTS_FOLDER, exist_ok=True) DEFAULT_MODELS = DefaultModels.load() -EMBEDDING_MODEL = get_embedding_model(DEFAULT_MODELS.default_embedding_model) - -SPEECH_TO_TEXT_MODEL = get_speech_to_text_model( - DEFAULT_MODELS.default_speech_to_text_model +EMBEDDING_MODEL = get_model( + DEFAULT_MODELS.default_embedding_model, model_type="embedding" +) + +SPEECH_TO_TEXT_MODEL = get_model( + DEFAULT_MODELS.default_speech_to_text_model, model_type="speech_to_text" ) diff --git a/open_notebook/graphs/utils.py b/open_notebook/graphs/utils.py index 9e95fcd..921afa5 100644 --- a/open_notebook/graphs/utils.py +++ b/open_notebook/graphs/utils.py @@ -1,7 +1,7 @@ from langchain.output_parsers import OutputFixingParser from open_notebook.config import DEFAULT_MODELS -from open_notebook.models.llms import get_langchain_model +from open_notebook.models import get_model from open_notebook.prompter import Prompter @@ -16,12 +16,12 @@ def run_pattern( if not model_name: model_name = DEFAULT_MODELS.default_transformation_model - chain = get_langchain_model(model_name) + chain = get_model(model_name, model_type="language") if parser: chain = chain | parser if output_fixing_model_name and parser: - output_fix_model = get_langchain_model(output_fixing_model_name) + output_fix_model = get_model(output_fixing_model_name, model_type="language") chain = chain | OutputFixingParser.from_llm( parser=parser, llm=output_fix_model, diff --git a/open_notebook/models/__init__.py b/open_notebook/models/__init__.py index e69de29..84907bc 100644 --- a/open_notebook/models/__init__.py +++ b/open_notebook/models/__init__.py @@ -0,0 +1,141 @@ +from open_notebook.domain.models import Model +from open_notebook.models.embedding_models import OpenAIEmbeddingModel +from open_notebook.models.llms import ( + AnthropicLanguageModel, + GeminiLanguageModel, + LiteLLMLanguageModel, + OllamaLanguageModel, + OpenAILanguageModel, + OpenRouterLanguageModel, + VertexAILanguageModel, + VertexAnthropicLanguageModel, +) +from open_notebook.models.speech_to_text_models import OpenAISpeechToTextModel + +# Unified model class map with type information +MODEL_CLASS_MAP = { + "language": { + "ollama": OllamaLanguageModel, + "openrouter": OpenRouterLanguageModel, + "vertexai-anthropic": VertexAnthropicLanguageModel, + "litellm": LiteLLMLanguageModel, + "vertexai": VertexAILanguageModel, + "anthropic": AnthropicLanguageModel, + "openai": OpenAILanguageModel, + "gemini": GeminiLanguageModel, + }, + "embedding": { + "openai": OpenAIEmbeddingModel, + }, + "speech_to_text": { + "openai": OpenAISpeechToTextModel, + }, +} + + +def get_model(model_id, model_type="language", **kwargs): + """ + Get a model instance based on model_id and type. + + Args: + model_id: The ID of the model to retrieve + model_type: Type of model ('language', 'embedding', or 'speech_to_text') + **kwargs: Additional arguments to pass to the model constructor + """ + assert model_id, "Model ID cannot be empty" + model = Model.get(model_id) + + if not model: + raise ValueError(f"Model with ID {model_id} not found") + + if model_type not in MODEL_CLASS_MAP: + raise ValueError(f"Invalid model type: {model_type}") + + provider_map = MODEL_CLASS_MAP[model_type] + if model.provider not in provider_map: + raise ValueError( + f"Provider {model.provider} not compatible with {model_type} models" + ) + + model_class = provider_map[model.provider] + model_instance = model_class(model_name=model.name, **kwargs) + + # Special handling for language models that need langchain conversion + if model_type == "language": + return model_instance.to_langchain() + + return model_instance + + +# from open_notebook.domain.models import Model +# from open_notebook.models.embedding_models import OpenAIEmbeddingModel +# from open_notebook.models.llms import ( +# AnthropicLanguageModel, +# GeminiLanguageModel, +# LiteLLMLanguageModel, +# OllamaLanguageModel, +# OpenAILanguageModel, +# OpenRouterLanguageModel, +# VertexAILanguageModel, +# VertexAnthropicLanguageModel, +# ) +# from open_notebook.models.speech_to_text_models import OpenAISpeechToTextModel + +# SPEECH_TO_TEXT_CLASS_MAP = { +# "openai": OpenAISpeechToTextModel, +# } + + +# # todo: acho que dá pra juntar todos os get models em uma coisa só +# def get_speech_to_text_model(model_id): +# assert model_id, "Model ID cannot be empty" +# model = Model.get(model_id) +# if not model: +# raise ValueError(f"Model with ID {model_id} not found") +# if model.provider not in SPEECH_TO_TEXT_CLASS_MAP.keys(): +# raise ValueError( +# f"Provider {model.provider} not compatible with Embedding Models" +# ) +# return SPEECH_TO_TEXT_CLASS_MAP[model.provider](model_name=model.name) + + +# # Map provider names to classes +# PROVIDER_CLASS_MAP = { +# "ollama": OllamaLanguageModel, +# "openrouter": OpenRouterLanguageModel, +# "vertexai-anthropic": VertexAnthropicLanguageModel, +# "litellm": LiteLLMLanguageModel, +# "vertexai": VertexAILanguageModel, +# "anthropic": AnthropicLanguageModel, +# "openai": OpenAILanguageModel, +# "gemini": GeminiLanguageModel, +# } + + +# # todo: make the provider check type specific +# def get_langchain_model(model_id, json=False): +# model = Model.get(model_id) +# if not model: +# raise ValueError(f"Model {model_id} not found") +# if model.provider not in PROVIDER_CLASS_MAP.keys(): +# raise ValueError(f"Provider {model.provider} not found") +# return PROVIDER_CLASS_MAP[model.provider]( +# model_name=model.name, json=json +# ).to_langchain() + + +# EMBEDDING_CLASS_MAP = { +# "openai": OpenAIEmbeddingModel, +# } + + +# def get_embedding_model(model_id): +# assert model_id, "Model ID cannot be empty" +# model = Model.get(model_id) +# if not model: +# raise ValueError(f"Model with ID {model_id} not found") +# if model.provider not in EMBEDDING_CLASS_MAP.keys(): +# raise ValueError( +# f"Provider {model.provider} not compatible with Embedding Models" +# ) +# return EMBEDDING_CLASS_MAP[model.provider](model_name=model.name) diff --git a/open_notebook/models/embedding_models.py b/open_notebook/models/embedding_models.py index aaaec65..ad8f141 100644 --- a/open_notebook/models/embedding_models.py +++ b/open_notebook/models/embedding_models.py @@ -8,8 +8,6 @@ from typing import List, Optional from openai import OpenAI -from open_notebook.domain.models import Model - @dataclass class EmbeddingModel(ABC): @@ -43,20 +41,3 @@ class OpenAIEmbeddingModel(EmbeddingModel): .data[0] .embedding ) - - -EMBEDDING_CLASS_MAP = { - "openai": OpenAIEmbeddingModel, -} - - -def get_embedding_model(model_id): - assert model_id, "Model ID cannot be empty" - model = Model.get(model_id) - if not model: - raise ValueError(f"Model with ID {model_id} not found") - if model.provider not in EMBEDDING_CLASS_MAP.keys(): - raise ValueError( - f"Provider {model.provider} not compatible with Embedding Models" - ) - return EMBEDDING_CLASS_MAP[model.provider](model_name=model.name) diff --git a/open_notebook/models/llms.py b/open_notebook/models/llms.py index 64c0213..ac95181 100644 --- a/open_notebook/models/llms.py +++ b/open_notebook/models/llms.py @@ -17,8 +17,6 @@ from langchain_ollama.chat_models import ChatOllama from langchain_openai.chat_models import ChatOpenAI from pydantic import SecretStr -from open_notebook.domain.models import Model - @dataclass class LanguageModel(ABC): @@ -238,28 +236,3 @@ class OpenAILanguageModel(LanguageModel): streaming=self.streaming, top_p=self.top_p, ) - - -# Map provider names to classes -PROVIDER_CLASS_MAP = { - "ollama": OllamaLanguageModel, - "openrouter": OpenRouterLanguageModel, - "vertexai-anthropic": VertexAnthropicLanguageModel, - "litellm": LiteLLMLanguageModel, - "vertexai": VertexAILanguageModel, - "anthropic": AnthropicLanguageModel, - "openai": OpenAILanguageModel, - "gemini": GeminiLanguageModel, -} - - -# todo: make the provider check type specific -def get_langchain_model(model_id, json=False): - model = Model.get(model_id) - if not model: - raise ValueError(f"Model {model_id} not found") - if model.provider not in PROVIDER_CLASS_MAP.keys(): - raise ValueError(f"Provider {model.provider} not found") - return PROVIDER_CLASS_MAP[model.provider]( - model_name=model.name, json=json - ).to_langchain() diff --git a/open_notebook/models/speech_to_text_models.py b/open_notebook/models/speech_to_text_models.py index 4191812..214bb45 100644 --- a/open_notebook/models/speech_to_text_models.py +++ b/open_notebook/models/speech_to_text_models.py @@ -8,8 +8,6 @@ from typing import Optional from openai import OpenAI -from open_notebook.domain.models import Model - @dataclass class SpeechToTextModel(ABC): @@ -42,21 +40,3 @@ class OpenAISpeechToTextModel(SpeechToTextModel): model=self.model_name, file=audio ) return transcription.text - - -SPEECH_TO_TEXT_CLASS_MAP = { - "openai": OpenAISpeechToTextModel, -} - - -# todo: acho que dá pra juntar todos os get models em uma coisa só -def get_speech_to_text_model(model_id): - assert model_id, "Model ID cannot be empty" - model = Model.get(model_id) - if not model: - raise ValueError(f"Model with ID {model_id} not found") - if model.provider not in SPEECH_TO_TEXT_CLASS_MAP.keys(): - raise ValueError( - f"Provider {model.provider} not compatible with Embedding Models" - ) - return SPEECH_TO_TEXT_CLASS_MAP[model.provider](model_name=model.name)