From 223f1bdaf5b4aa32b6de181632128cd2756a3c7e Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Fri, 1 Nov 2024 22:38:21 -0300 Subject: [PATCH] improve default_models --- open_notebook/domain/models.py | 16 ++++++++++++++++ open_notebook/graphs/content_processing/audio.py | 2 +- open_notebook/graphs/utils.py | 4 ++-- pages/3_🔍_Search.py | 3 +-- 4 files changed, 20 insertions(+), 5 deletions(-) diff --git a/open_notebook/domain/models.py b/open_notebook/domain/models.py index cf6c8f9..d66ad19 100644 --- a/open_notebook/domain/models.py +++ b/open_notebook/domain/models.py @@ -105,6 +105,22 @@ class ModelManager: raise RuntimeError("Failed to initialize default models configuration") return self._default_models + @property + def speech_to_text(self, **kwargs) -> SpeechToTextModel: + """Get the default speech-to-text model""" + model = self.get_default_model("speech_to_text", **kwargs) + if not isinstance(model, SpeechToTextModel): + raise TypeError(f"Expected SpeechToTextModel but got {type(model)}") + return model + + @property + def text_to_speech(self, **kwargs) -> TextToSpeechModel: + """Get the default text-to-speech model""" + model = self.get_default_model("text_to_speech", **kwargs) + if not isinstance(model, TextToSpeechModel): + raise TypeError(f"Expected TextToSpeechModel but got {type(model)}") + return model + @property def embedding_model(self, **kwargs) -> EmbeddingModel: """Get the default embedding model""" diff --git a/open_notebook/graphs/content_processing/audio.py b/open_notebook/graphs/content_processing/audio.py index be8c441..3f99277 100644 --- a/open_notebook/graphs/content_processing/audio.py +++ b/open_notebook/graphs/content_processing/audio.py @@ -73,7 +73,7 @@ def split_audio(input_file, segment_length_minutes=15, output_prefix=None): def extract_audio(data: SourceState): - SPEECH_TO_TEXT_MODEL = model_manager.get_default_model("speech_to_text") + SPEECH_TO_TEXT_MODEL = model_manager.speech_to_text input_audio_path = data.get("file_path") audio_files = [] diff --git a/open_notebook/graphs/utils.py b/open_notebook/graphs/utils.py index d67d5a6..9429bca 100644 --- a/open_notebook/graphs/utils.py +++ b/open_notebook/graphs/utils.py @@ -20,11 +20,11 @@ def provision_model(content, config, default_type): logger.debug( f"Using large context model because the content has {tokens} tokens" ) - return model_manager.get_default_model("large_context") + return model_manager.get_default_model("large_context").to_langchain() elif config.get("configurable", {}).get("model_id"): return model_manager.get_model(config.get("configurable", {}).get("model_id")) else: - return model_manager.get_default_model(default_type) + return model_manager.get_default_model(default_type).to_langchain() # todo: turn into a graph diff --git a/pages/3_🔍_Search.py b/pages/3_🔍_Search.py index a9abad5..4438af5 100644 --- a/pages/3_🔍_Search.py +++ b/pages/3_🔍_Search.py @@ -2,14 +2,13 @@ import streamlit as st from open_notebook.domain.models import model_manager from open_notebook.domain.notebook import text_search, vector_search -from open_notebook.models import EmbeddingModel from pages.stream_app.note import note_list_item from pages.stream_app.source import source_list_item from pages.stream_app.utils import setup_page setup_page("🔍 Search") -EMBEDDING_MODEL: EmbeddingModel = model_manager.get_default_model("embedding") +EMBEDDING_MODEL = model_manager.embedding_model # search_tab, ask_tab = st.tabs(["Search", "Ask"]) # notebooks = Notebook.get_all()