improve default_models
This commit is contained in:
parent
4f586ad513
commit
223f1bdaf5
4 changed files with 20 additions and 5 deletions
|
|
@ -105,6 +105,22 @@ class ModelManager:
|
|||
raise RuntimeError("Failed to initialize default models configuration")
|
||||
return self._default_models
|
||||
|
||||
@property
|
||||
def speech_to_text(self, **kwargs) -> SpeechToTextModel:
|
||||
"""Get the default speech-to-text model"""
|
||||
model = self.get_default_model("speech_to_text", **kwargs)
|
||||
if not isinstance(model, SpeechToTextModel):
|
||||
raise TypeError(f"Expected SpeechToTextModel but got {type(model)}")
|
||||
return model
|
||||
|
||||
@property
|
||||
def text_to_speech(self, **kwargs) -> TextToSpeechModel:
|
||||
"""Get the default text-to-speech model"""
|
||||
model = self.get_default_model("text_to_speech", **kwargs)
|
||||
if not isinstance(model, TextToSpeechModel):
|
||||
raise TypeError(f"Expected TextToSpeechModel but got {type(model)}")
|
||||
return model
|
||||
|
||||
@property
|
||||
def embedding_model(self, **kwargs) -> EmbeddingModel:
|
||||
"""Get the default embedding model"""
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ def split_audio(input_file, segment_length_minutes=15, output_prefix=None):
|
|||
|
||||
|
||||
def extract_audio(data: SourceState):
|
||||
SPEECH_TO_TEXT_MODEL = model_manager.get_default_model("speech_to_text")
|
||||
SPEECH_TO_TEXT_MODEL = model_manager.speech_to_text
|
||||
|
||||
input_audio_path = data.get("file_path")
|
||||
audio_files = []
|
||||
|
|
|
|||
|
|
@ -20,11 +20,11 @@ def provision_model(content, config, default_type):
|
|||
logger.debug(
|
||||
f"Using large context model because the content has {tokens} tokens"
|
||||
)
|
||||
return model_manager.get_default_model("large_context")
|
||||
return model_manager.get_default_model("large_context").to_langchain()
|
||||
elif config.get("configurable", {}).get("model_id"):
|
||||
return model_manager.get_model(config.get("configurable", {}).get("model_id"))
|
||||
else:
|
||||
return model_manager.get_default_model(default_type)
|
||||
return model_manager.get_default_model(default_type).to_langchain()
|
||||
|
||||
|
||||
# todo: turn into a graph
|
||||
|
|
|
|||
|
|
@ -2,14 +2,13 @@ import streamlit as st
|
|||
|
||||
from open_notebook.domain.models import model_manager
|
||||
from open_notebook.domain.notebook import text_search, vector_search
|
||||
from open_notebook.models import EmbeddingModel
|
||||
from pages.stream_app.note import note_list_item
|
||||
from pages.stream_app.source import source_list_item
|
||||
from pages.stream_app.utils import setup_page
|
||||
|
||||
setup_page("🔍 Search")
|
||||
|
||||
EMBEDDING_MODEL: EmbeddingModel = model_manager.get_default_model("embedding")
|
||||
EMBEDDING_MODEL = model_manager.embedding_model
|
||||
|
||||
# search_tab, ask_tab = st.tabs(["Search", "Ask"])
|
||||
# notebooks = Notebook.get_all()
|
||||
|
|
|
|||
Loading…
Reference in a new issue