simplify the model selector
This commit is contained in:
parent
8bb5db158f
commit
859b7f6e7e
6 changed files with 151 additions and 75 deletions
|
|
@ -4,8 +4,7 @@ import yaml
|
|||
from loguru import logger
|
||||
|
||||
from open_notebook.domain.models import DefaultModels
|
||||
from open_notebook.models.embedding_models import get_embedding_model
|
||||
from open_notebook.models.speech_to_text_models import get_speech_to_text_model
|
||||
from open_notebook.models import get_model
|
||||
|
||||
# todo: enable config file overwrite with env vars
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
|
@ -40,8 +39,10 @@ os.makedirs(PODCASTS_FOLDER, exist_ok=True)
|
|||
|
||||
DEFAULT_MODELS = DefaultModels.load()
|
||||
|
||||
EMBEDDING_MODEL = get_embedding_model(DEFAULT_MODELS.default_embedding_model)
|
||||
|
||||
SPEECH_TO_TEXT_MODEL = get_speech_to_text_model(
|
||||
DEFAULT_MODELS.default_speech_to_text_model
|
||||
EMBEDDING_MODEL = get_model(
|
||||
DEFAULT_MODELS.default_embedding_model, model_type="embedding"
|
||||
)
|
||||
|
||||
SPEECH_TO_TEXT_MODEL = get_model(
|
||||
DEFAULT_MODELS.default_speech_to_text_model, model_type="speech_to_text"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from langchain.output_parsers import OutputFixingParser
|
||||
|
||||
from open_notebook.config import DEFAULT_MODELS
|
||||
from open_notebook.models.llms import get_langchain_model
|
||||
from open_notebook.models import get_model
|
||||
from open_notebook.prompter import Prompter
|
||||
|
||||
|
||||
|
|
@ -16,12 +16,12 @@ def run_pattern(
|
|||
if not model_name:
|
||||
model_name = DEFAULT_MODELS.default_transformation_model
|
||||
|
||||
chain = get_langchain_model(model_name)
|
||||
chain = get_model(model_name, model_type="language")
|
||||
if parser:
|
||||
chain = chain | parser
|
||||
|
||||
if output_fixing_model_name and parser:
|
||||
output_fix_model = get_langchain_model(output_fixing_model_name)
|
||||
output_fix_model = get_model(output_fixing_model_name, model_type="language")
|
||||
chain = chain | OutputFixingParser.from_llm(
|
||||
parser=parser,
|
||||
llm=output_fix_model,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,141 @@
|
|||
from open_notebook.domain.models import Model
|
||||
from open_notebook.models.embedding_models import OpenAIEmbeddingModel
|
||||
from open_notebook.models.llms import (
|
||||
AnthropicLanguageModel,
|
||||
GeminiLanguageModel,
|
||||
LiteLLMLanguageModel,
|
||||
OllamaLanguageModel,
|
||||
OpenAILanguageModel,
|
||||
OpenRouterLanguageModel,
|
||||
VertexAILanguageModel,
|
||||
VertexAnthropicLanguageModel,
|
||||
)
|
||||
from open_notebook.models.speech_to_text_models import OpenAISpeechToTextModel
|
||||
|
||||
# Unified model class map with type information.
# Outer keys are the model types accepted by get_model(); each inner dict maps
# a provider name (compared against ``model.provider``) to the class that
# get_model() instantiates for that provider.
MODEL_CLASS_MAP = {
    # Language model classes; get_model() calls ``to_langchain()`` on these.
    "language": {
        "ollama": OllamaLanguageModel,
        "openrouter": OpenRouterLanguageModel,
        "vertexai-anthropic": VertexAnthropicLanguageModel,
        "litellm": LiteLLMLanguageModel,
        "vertexai": VertexAILanguageModel,
        "anthropic": AnthropicLanguageModel,
        "openai": OpenAILanguageModel,
        "gemini": GeminiLanguageModel,
    },
    # Embedding model classes; returned by get_model() as plain instances.
    "embedding": {
        "openai": OpenAIEmbeddingModel,
    },
    # Speech-to-text model classes; returned by get_model() as plain instances.
    "speech_to_text": {
        "openai": OpenAISpeechToTextModel,
    },
}
|
||||
|
||||
|
||||
def get_model(model_id, model_type="language", **kwargs):
    """
    Get a model instance based on model_id and type.

    Args:
        model_id: The ID of the model to retrieve; looked up with ``Model.get``.
        model_type: Type of model ('language', 'embedding', or 'speech_to_text');
            must be a key of ``MODEL_CLASS_MAP``.
        **kwargs: Additional arguments to pass to the model constructor

    Returns:
        For ``model_type == "language"``, the result of calling
        ``to_langchain()`` on the instantiated model class; for any other
        type, the instantiated model class itself.

    Raises:
        ValueError: If ``model_id`` is empty, if no model is found for it, if
            ``model_type`` is not a key of ``MODEL_CLASS_MAP``, or if the
            model's provider has no class registered for ``model_type``.
    """
    # Raise explicitly rather than ``assert``: assertions are stripped when
    # Python runs with -O, which would let an empty ID reach the lookup. This
    # also keeps every validation failure in this function a ValueError.
    if not model_id:
        raise ValueError("Model ID cannot be empty")

    model = Model.get(model_id)
    if not model:
        raise ValueError(f"Model with ID {model_id} not found")

    if model_type not in MODEL_CLASS_MAP:
        raise ValueError(f"Invalid model type: {model_type}")

    provider_map = MODEL_CLASS_MAP[model_type]
    if model.provider not in provider_map:
        raise ValueError(
            f"Provider {model.provider} not compatible with {model_type} models"
        )

    model_class = provider_map[model.provider]
    model_instance = model_class(model_name=model.name, **kwargs)

    # Special handling for language models that need langchain conversion
    if model_type == "language":
        return model_instance.to_langchain()

    return model_instance
|
||||
|
||||
|
||||
# from open_notebook.domain.models import Model
|
||||
# from open_notebook.models.embedding_models import OpenAIEmbeddingModel
|
||||
# from open_notebook.models.llms import (
|
||||
# AnthropicLanguageModel,
|
||||
# GeminiLanguageModel,
|
||||
# LiteLLMLanguageModel,
|
||||
# OllamaLanguageModel,
|
||||
# OpenAILanguageModel,
|
||||
# OpenRouterLanguageModel,
|
||||
# VertexAILanguageModel,
|
||||
# VertexAnthropicLanguageModel,
|
||||
# )
|
||||
# from open_notebook.models.speech_to_text_models import OpenAISpeechToTextModel
|
||||
|
||||
# SPEECH_TO_TEXT_CLASS_MAP = {
|
||||
# "openai": OpenAISpeechToTextModel,
|
||||
# }
|
||||
|
||||
|
||||
# # todo: I think all the get-model functions can be merged into a single one
|
||||
# def get_speech_to_text_model(model_id):
|
||||
# assert model_id, "Model ID cannot be empty"
|
||||
# model = Model.get(model_id)
|
||||
# if not model:
|
||||
# raise ValueError(f"Model with ID {model_id} not found")
|
||||
# if model.provider not in SPEECH_TO_TEXT_CLASS_MAP.keys():
|
||||
# raise ValueError(
|
||||
# f"Provider {model.provider} not compatible with Embedding Models"
|
||||
# )
|
||||
# return SPEECH_TO_TEXT_CLASS_MAP[model.provider](model_name=model.name)
|
||||
|
||||
|
||||
# # Map provider names to classes
|
||||
# PROVIDER_CLASS_MAP = {
|
||||
# "ollama": OllamaLanguageModel,
|
||||
# "openrouter": OpenRouterLanguageModel,
|
||||
# "vertexai-anthropic": VertexAnthropicLanguageModel,
|
||||
# "litellm": LiteLLMLanguageModel,
|
||||
# "vertexai": VertexAILanguageModel,
|
||||
# "anthropic": AnthropicLanguageModel,
|
||||
# "openai": OpenAILanguageModel,
|
||||
# "gemini": GeminiLanguageModel,
|
||||
# }
|
||||
|
||||
|
||||
# # todo: make the provider check type specific
|
||||
# def get_langchain_model(model_id, json=False):
|
||||
# model = Model.get(model_id)
|
||||
# if not model:
|
||||
# raise ValueError(f"Model {model_id} not found")
|
||||
# if model.provider not in PROVIDER_CLASS_MAP.keys():
|
||||
# raise ValueError(f"Provider {model.provider} not found")
|
||||
# return PROVIDER_CLASS_MAP[model.provider](
|
||||
# model_name=model.name, json=json
|
||||
# ).to_langchain()
|
||||
|
||||
|
||||
# EMBEDDING_CLASS_MAP = {
|
||||
# "openai": OpenAIEmbeddingModel,
|
||||
# }
|
||||
|
||||
|
||||
# def get_embedding_model(model_id):
|
||||
# assert model_id, "Model ID cannot be empty"
|
||||
# model = Model.get(model_id)
|
||||
# if not model:
|
||||
# raise ValueError(f"Model with ID {model_id} not found")
|
||||
# if model.provider not in EMBEDDING_CLASS_MAP.keys():
|
||||
# raise ValueError(
|
||||
# f"Provider {model.provider} not compatible with Embedding Models"
|
||||
# )
|
||||
# return EMBEDDING_CLASS_MAP[model.provider](model_name=model.name)
|
||||
|
|
@ -8,8 +8,6 @@ from typing import List, Optional
|
|||
|
||||
from openai import OpenAI
|
||||
|
||||
from open_notebook.domain.models import Model
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingModel(ABC):
|
||||
|
|
@ -43,20 +41,3 @@ class OpenAIEmbeddingModel(EmbeddingModel):
|
|||
.data[0]
|
||||
.embedding
|
||||
)
|
||||
|
||||
|
||||
# Maps a provider name to its embedding model class. get_embedding_model()
# checks a model's provider against these keys and instantiates the mapped
# class.
EMBEDDING_CLASS_MAP = {
    "openai": OpenAIEmbeddingModel,
}
|
||||
|
||||
|
||||
def get_embedding_model(model_id):
    """
    Build the embedding model registered under ``model_id``.

    Args:
        model_id: The ID of the model to retrieve; looked up with ``Model.get``.

    Returns:
        An instance of the class mapped to the model's provider in
        ``EMBEDDING_CLASS_MAP``, constructed with ``model_name=model.name``.

    Raises:
        ValueError: If ``model_id`` is empty, if no model is found for it, or
            if the model's provider is not in ``EMBEDDING_CLASS_MAP``.
    """
    # Explicit raise instead of ``assert``: assertions vanish under
    # ``python -O``, and the other failures here already use ValueError.
    if not model_id:
        raise ValueError("Model ID cannot be empty")

    model = Model.get(model_id)
    if not model:
        raise ValueError(f"Model with ID {model_id} not found")

    if model.provider not in EMBEDDING_CLASS_MAP:
        raise ValueError(
            f"Provider {model.provider} not compatible with Embedding Models"
        )

    return EMBEDDING_CLASS_MAP[model.provider](model_name=model.name)
|
||||
|
|
|
|||
|
|
@ -17,8 +17,6 @@ from langchain_ollama.chat_models import ChatOllama
|
|||
from langchain_openai.chat_models import ChatOpenAI
|
||||
from pydantic import SecretStr
|
||||
|
||||
from open_notebook.domain.models import Model
|
||||
|
||||
|
||||
@dataclass
|
||||
class LanguageModel(ABC):
|
||||
|
|
@ -238,28 +236,3 @@ class OpenAILanguageModel(LanguageModel):
|
|||
streaming=self.streaming,
|
||||
top_p=self.top_p,
|
||||
)
|
||||
|
||||
|
||||
# Map provider names to classes.
# get_langchain_model() checks a model's provider against these keys,
# instantiates the mapped class and calls ``to_langchain()`` on it.
PROVIDER_CLASS_MAP = {
    "ollama": OllamaLanguageModel,
    "openrouter": OpenRouterLanguageModel,
    "vertexai-anthropic": VertexAnthropicLanguageModel,
    "litellm": LiteLLMLanguageModel,
    "vertexai": VertexAILanguageModel,
    "anthropic": AnthropicLanguageModel,
    "openai": OpenAILanguageModel,
    "gemini": GeminiLanguageModel,
}
|
||||
|
||||
|
||||
# todo: make the provider check type specific
def get_langchain_model(model_id, json=False):
    """
    Return the langchain object for the language model stored under ``model_id``.

    Args:
        model_id: ID passed to ``Model.get`` to fetch the model record.
        json: Forwarded unchanged to the provider class constructor.

    Returns:
        The result of ``to_langchain()`` on the provider class instance.

    Raises:
        ValueError: If no model is found for ``model_id`` or its provider is
            not a key of ``PROVIDER_CLASS_MAP``.
    """
    record = Model.get(model_id)
    if not record:
        raise ValueError(f"Model {model_id} not found")

    # Single lookup: a missing provider comes back as None.
    provider_cls = PROVIDER_CLASS_MAP.get(record.provider)
    if provider_cls is None:
        raise ValueError(f"Provider {record.provider} not found")

    language_model = provider_cls(model_name=record.name, json=json)
    return language_model.to_langchain()
|
||||
|
|
|
|||
|
|
@ -8,8 +8,6 @@ from typing import Optional
|
|||
|
||||
from openai import OpenAI
|
||||
|
||||
from open_notebook.domain.models import Model
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpeechToTextModel(ABC):
|
||||
|
|
@ -42,21 +40,3 @@ class OpenAISpeechToTextModel(SpeechToTextModel):
|
|||
model=self.model_name, file=audio
|
||||
)
|
||||
return transcription.text
|
||||
|
||||
|
||||
# Maps a provider name to its speech-to-text model class.
# get_speech_to_text_model() checks a model's provider against these keys and
# instantiates the mapped class.
SPEECH_TO_TEXT_CLASS_MAP = {
    "openai": OpenAISpeechToTextModel,
}
|
||||
|
||||
|
||||
# todo: I think all the get-model functions can be merged into a single one
def get_speech_to_text_model(model_id):
    """
    Build the speech-to-text model registered under ``model_id``.

    Args:
        model_id: The ID of the model to retrieve; looked up with ``Model.get``.

    Returns:
        An instance of the class mapped to the model's provider in
        ``SPEECH_TO_TEXT_CLASS_MAP``, constructed with ``model_name=model.name``.

    Raises:
        ValueError: If ``model_id`` is empty, if no model is found for it, or
            if the model's provider is not in ``SPEECH_TO_TEXT_CLASS_MAP``.
    """
    # Explicit raise instead of ``assert``: assertions vanish under
    # ``python -O``, and the other failures here already use ValueError.
    if not model_id:
        raise ValueError("Model ID cannot be empty")

    model = Model.get(model_id)
    if not model:
        raise ValueError(f"Model with ID {model_id} not found")

    if model.provider not in SPEECH_TO_TEXT_CLASS_MAP:
        # The message previously said "Embedding Models" (copy-paste slip),
        # which pointed at the wrong model type.
        raise ValueError(
            f"Provider {model.provider} not compatible with Speech to Text Models"
        )

    return SPEECH_TO_TEXT_CLASS_MAP[model.provider](model_name=model.name)
|
||||
|
|
|
|||
Loading…
Reference in a new issue