improve default_models

This commit is contained in:
LUIS NOVO 2024-11-01 22:38:21 -03:00
parent 4f586ad513
commit 223f1bdaf5
4 changed files with 20 additions and 5 deletions

View file

@ -105,6 +105,22 @@ class ModelManager:
raise RuntimeError("Failed to initialize default models configuration")
return self._default_models
@property
def speech_to_text(self, **kwargs) -> SpeechToTextModel:
"""Get the default speech-to-text model"""
model = self.get_default_model("speech_to_text", **kwargs)
if not isinstance(model, SpeechToTextModel):
raise TypeError(f"Expected SpeechToTextModel but got {type(model)}")
return model
@property
def text_to_speech(self, **kwargs) -> TextToSpeechModel:
"""Get the default text-to-speech model"""
model = self.get_default_model("text_to_speech", **kwargs)
if not isinstance(model, TextToSpeechModel):
raise TypeError(f"Expected TextToSpeechModel but got {type(model)}")
return model
@property
def embedding_model(self, **kwargs) -> EmbeddingModel:
"""Get the default embedding model"""

View file

@ -73,7 +73,7 @@ def split_audio(input_file, segment_length_minutes=15, output_prefix=None):
def extract_audio(data: SourceState):
SPEECH_TO_TEXT_MODEL = model_manager.get_default_model("speech_to_text")
SPEECH_TO_TEXT_MODEL = model_manager.speech_to_text
input_audio_path = data.get("file_path")
audio_files = []

View file

@ -20,11 +20,11 @@ def provision_model(content, config, default_type):
logger.debug(
f"Using large context model because the content has {tokens} tokens"
)
return model_manager.get_default_model("large_context")
return model_manager.get_default_model("large_context").to_langchain()
elif config.get("configurable", {}).get("model_id"):
return model_manager.get_model(config.get("configurable", {}).get("model_id"))
else:
return model_manager.get_default_model(default_type)
return model_manager.get_default_model(default_type).to_langchain()
# todo: turn into a graph

View file

@ -2,14 +2,13 @@ import streamlit as st
from open_notebook.domain.models import model_manager
from open_notebook.domain.notebook import text_search, vector_search
from open_notebook.models import EmbeddingModel
from pages.stream_app.note import note_list_item
from pages.stream_app.source import source_list_item
from pages.stream_app.utils import setup_page
setup_page("🔍 Search")
EMBEDDING_MODEL: EmbeddingModel = model_manager.get_default_model("embedding")
EMBEDDING_MODEL = model_manager.embedding_model
# search_tab, ask_tab = st.tabs(["Search", "Ask"])
# notebooks = Notebook.get_all()