From 9ba5709a3cbe76f8f8d2989a4d2551d1bbba038f Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Wed, 13 Nov 2024 14:48:00 -0300 Subject: [PATCH] model selector and model suggestions --- open_notebook/domain/base.py | 3 +- open_notebook/domain/models.py | 7 +- open_notebook_config.yaml | 26 ++++- pages/3_🔍_Ask_and_Search.py | 28 +++--- pages/7_⚙️_Settings.py | 153 ++++++++++++++++------------- pages/8_🛝_Playground.py | 11 ++- pages/components/model_selector.py | 35 +++++++ pages/stream_app/utils.py | 34 ++----- 8 files changed, 182 insertions(+), 115 deletions(-) create mode 100644 pages/components/model_selector.py diff --git a/open_notebook/domain/base.py b/open_notebook/domain/base.py index 1b46704..4757fdf 100644 --- a/open_notebook/domain/base.py +++ b/open_notebook/domain/base.py @@ -112,8 +112,6 @@ class ObjectModel(BaseModel): from open_notebook.domain.models import model_manager from open_notebook.models import EmbeddingModel - EMBEDDING_MODEL: EmbeddingModel = model_manager.embedding_model - try: self.model_validate(self.model_dump(), strict=True) data = self._prepare_save_data() @@ -122,6 +120,7 @@ class ObjectModel(BaseModel): if self.needs_embedding(): embedding_content = self.get_embedding_content() if embedding_content: + EMBEDDING_MODEL: EmbeddingModel = model_manager.embedding_model data["embedding"] = EMBEDDING_MODEL.embed(embedding_content) if self.id is None: diff --git a/open_notebook/domain/models.py b/open_notebook/domain/models.py index 9cb84ca..ff43ee7 100644 --- a/open_notebook/domain/models.py +++ b/open_notebook/domain/models.py @@ -68,7 +68,9 @@ class ModelManager: ) return cached_model - assert model_id, "Model ID cannot be empty" + if not model_id: + return None + model: Model = Model.get(model_id) if not model: @@ -160,9 +162,6 @@ class ModelManager: elif model_type == "large_context": model_id = self.defaults.large_context_model - if not model_id: - raise ValueError(f"No default model configured for type: {model_type}") - return self.get_model(model_id, **kwargs) def clear_cache(self): diff --git a/open_notebook_config.yaml b/open_notebook_config.yaml index 280c2a4..e6d6f72 100644 --- a/open_notebook_config.yaml +++ b/open_notebook_config.yaml @@ -10,4 +10,28 @@ youtube_transcripts: - fr - de - hi - - ja \ No newline at end of file + - ja +suggested_models: + openai: + language: + - gpt-4o-mini + embedding: + - text-embedding-3-small + text_to_speech: + - tts-1-hd + speech_to_text: + - whisper-1 + gemini: + language: + - gemini-1.5-flash + text_to_speech: + - default + xai: + language: + - grok-beta + anthropic: + language: + - claude-3-5-sonnet-20241022 + elevenlabs: + text_to_speech: + - eleven_turbo_v2_5 \ No newline at end of file diff --git a/pages/3_🔍_Ask_and_Search.py b/pages/3_🔍_Ask_and_Search.py index d84e404..7956737 100644 --- a/pages/3_🔍_Ask_and_Search.py +++ b/pages/3_🔍_Ask_and_Search.py @@ -2,9 +2,10 @@ import asyncio import streamlit as st -from open_notebook.domain.models import Model +from open_notebook.domain.models import DefaultModels from open_notebook.domain.notebook import Note, Notebook, text_search, vector_search from open_notebook.graphs.ask import graph as ask_graph +from pages.components.model_selector import model_selector from pages.stream_app.utils import convert_source_references, setup_page setup_page("🔍 Search") @@ -52,23 +53,26 @@ with ask_tab: "The LLM will answer your query based on the documents in your knowledge base. " ) question = st.text_input("Question", "") - models = Model.get_models_by_type("language") - strategy_model: Model = st.selectbox( + default_model = DefaultModels().load().default_chat_model + strategy_model = model_selector( "Query Strategy Model", - models, - format_func=lambda x: x.name, + "strategy_model", + selected_id=default_model, + model_type="language", help="This is the LLM that will be responsible for strategizing the search", ) - answer_model: Model = st.selectbox( - "Indivual Answer Model", - models, - format_func=lambda x: x.name, + answer_model = model_selector( + "Individual Answer Model", + "answer_model", + model_type="language", + selected_id=default_model, help="This is the LLM that will be responsible for processing individual subqueries", ) - final_answer_model: Model = st.selectbox( + final_answer_model = model_selector( "Final Answer Model", - models, - format_func=lambda x: x.name, + "final_answer_model", + model_type="language", + selected_id=default_model, help="This is the LLM that will be responsible for processing the final answer", ) ask_bt = st.button("Ask") diff --git a/pages/7_⚙️_Settings.py b/pages/7_⚙️_Settings.py index e2b09f0..955c995 100644 --- a/pages/7_⚙️_Settings.py +++ b/pages/7_⚙️_Settings.py @@ -2,9 +2,11 @@ import os import streamlit as st +from open_notebook.config import CONFIG from open_notebook.domain.models import DefaultModels, Model, model_manager from open_notebook.domain.transformation import DefaultTransformations, Transformation from open_notebook.models import MODEL_CLASS_MAP +from pages.components.model_selector import model_selector from pages.stream_app.utils import setup_page setup_page("⚙️ Settings") @@ -59,8 +61,41 @@ provider_status["litellm"] = ( available_providers = [k for k, v in provider_status.items() if v] unavailable_providers = [k for k, v in provider_status.items() if not v] + +def generate_new_models(models, suggested_models): + # Create a set of existing model keys for efficient lookup + existing_model_keys = { + f"{model.provider}-{model.name}-{model.type}" for model in models + } + + new_models = [] + + # Iterate through suggested models by provider + for provider, types in suggested_models.items(): + # Iterate through each type (language, embedding, etc.) + for type_, model_list in types.items(): + for model_name in model_list: + model_key = f"{provider}-{model_name}-{type_}" + + # Check if model already exists + if model_key not in existing_model_keys: + new_models.append( + { + "name": model_name, + "type": type_, + "provider": provider, + } + ) + + return new_models + + +default_models = DefaultModels().model_dump() +all_models = Model.get_all() + with model_tab: st.subheader("Add Model") + provider = st.selectbox("Provider", available_providers) if len(unavailable_providers) > 0: st.caption( @@ -92,8 +127,20 @@ with model_tab: model = Model(name=model_name, provider=provider, type=model_type) model.save() st.success("Saved") + st.divider() - all_models = Model.get_all() + suggested_models = CONFIG.get("suggested_models", []) + recommendations = generate_new_models(all_models, suggested_models) + if len(recommendations) > 0: + with st.expander("💁‍♂️ Recommended models to get you started.."): + for recommendation in recommendations: + st.markdown( + f"**{recommendation['name']}** ({recommendation['provider']}, {recommendation['type']})" + ) + if st.button("Add", key=f"add_{recommendation['name']}"): + new_model = Model(**recommendation) + new_model.save() + st.rerun() st.subheader("Configured Models") model_types_available = { # "vision": False, @@ -114,20 +161,7 @@ with model_tab: if not available: st.warning(f"No models available for {model_type}") - -def get_selected_index(models, model_id, default=0): - """Returns the index of the selected model in the list of models""" - if not model_id or not models: - return default - for i, model in enumerate(models): - if model.id == model_id: - return i - return default - - with model_defaults_tab: - default_models = DefaultModels().model_dump() - all_models = Model.get_all() text_generation_models = [model for model in all_models if model.type == "language"] text_to_speech_models = [ @@ -143,93 +177,80 @@ with model_defaults_tab: "In this section, you can select the default models to be used on the various content operations done by Open Notebook. Some of these can be overriden in the different modules." ) defs = {} - defs["default_chat_model"] = st.selectbox( + defs["default_chat_model"] = model_selector( "Default Chat Model", - text_generation_models, - format_func=lambda x: x.name, + "default_chat_model", + selected_id=default_models.get("default_chat_model"), help="This model will be used for chat.", - index=get_selected_index( - text_generation_models, default_models.get("default_chat_model") - ), + model_type="language", ) st.divider() - defs["default_transformation_model"] = st.selectbox( + defs["default_transformation_model"] = model_selector( "Default Transformation Model", - text_generation_models, - format_func=lambda x: x.name, + "default_transformation_model", + selected_id=default_models.get("default_transformation_model"), help="This model will be used for text transformations such as summaries, insights, etc.", - index=get_selected_index( - text_generation_models, default_models.get("default_transformation_model") - ), + model_type="language", ) + st.caption("You can use a cheap model here like gpt-4o-mini, llama3, etc.") st.divider() - defs["default_tools_model"] = st.selectbox( + defs["default_tools_model"] = model_selector( "Default Tools Model", - text_generation_models, - format_func=lambda x: x.name, + "default_tools_model", + selected_id=default_models.get("default_tools_model"), help="This model will be used for calling tools. Currently, it's best to use Open AI and Anthropic for this.", - index=get_selected_index( - text_generation_models, default_models.get("default_tools_model") - ), + model_type="language", ) + st.caption("Recommended to use a capable model here, like gpt-4o, claude, etc.") st.divider() - defs["large_context_model"] = st.selectbox( + defs["large_context_model"] = model_selector( "Large Context Model", - text_generation_models, - format_func=lambda x: x.name, + "large_context_model", + selected_id=default_models.get("large_context_model"), help="This model will be used for larger context generation -- recommended: Gemini", - index=get_selected_index( - text_generation_models, default_models.get("large_context_model") - ), + model_type="language", ) st.caption("Recommended to use Gemini models for larger context processing") st.divider() - defs["default_text_to_speech_model"] = st.selectbox( + defs["default_text_to_speech_model"] = model_selector( "Default Text to Speech Model", - text_to_speech_models, - format_func=lambda x: x.name, + "default_text_to_speech_model", + selected_id=default_models.get("default_text_to_speech_model"), help="This is the default model for converting text to speech (podcasts, etc)", - index=get_selected_index( - text_to_speech_models, default_models.get("default_text_to_speech_model") - ), + model_type="text_to_speech", ) st.caption("You can override this model on different podcasts") st.divider() - defs["default_speech_to_text_model"] = st.selectbox( + defs["default_speech_to_text_model"] = model_selector( "Default Speech to Text Model", - speech_to_text_models, - format_func=lambda x: x.name, + "default_speech_to_text_model", + selected_id=default_models.get("default_speech_to_text_model"), help="This is the default model for converting speech to text (audio transcriptions, etc)", - index=get_selected_index( - speech_to_text_models, default_models.get("default_speech_to_text_model") - ), + model_type="speech_to_text", ) + st.divider() - # defs["default_vision_model"] = st.selectbox( - # "Default Vision Model", - # vision_models, - # format_func=lambda x: x.name, - # help="This is the default model for vision tasks (image recognition, PDF recognition, etc)", - # index=get_selected_index( - # vision_models, default_models.get("default_vision_model") + # defs["default_vision_model"] = ( + # model_selector( + # "Default Speech to Text Model", + # "default_vision_model", + # selected_id=default_models.get("default_vision_model"), + # help="This is the default model for vision tasks", + # model_type="vision", # ), # ) - # st.divider() - defs["default_embedding_model"] = st.selectbox( - "Default Embedding Model", - embedding_models, - format_func=lambda x: x.name, + defs["default_embedding_model"] = model_selector( + "Default Speech to Text Model", + "default_embedding_model", + selected_id=default_models.get("default_embedding_model"), help="This is the default model for embeddings (semantic search, etc)", - index=get_selected_index( - embedding_models, default_models.get("default_embedding_model") - ), + model_type="embedding", ) st.caption( "Caution: you cannot change the embedding model once there is embeddings or they will need to be regenerated" ) - # if st.button("Save Defaults", key="save_defaults"): for k, v in defs.items(): if v: defs[k] = v.id diff --git a/pages/8_🛝_Playground.py b/pages/8_🛝_Playground.py index 5bcac7a..355cf9b 100644 --- a/pages/8_🛝_Playground.py +++ b/pages/8_🛝_Playground.py @@ -1,8 +1,8 @@ import streamlit as st import yaml -from open_notebook.domain.models import Model from open_notebook.graphs.multipattern import graph as pattern_graph +from pages.components.model_selector import model_selector from pages.stream_app.utils import setup_page setup_page("🛝 Playground") @@ -22,12 +22,13 @@ transformation = st.selectbox( with st.expander("Details"): st.json(transformation) -models = Model.get_models_by_type("language") -model = st.selectbox( +model = model_selector( "Pick a pattern model", - models, - format_func=lambda x: x.name, + key="model", + help="This is the model that will be used to run the transformation", + model_type="language", ) + input_text = st.text_area("Enter some text", height=200) if st.button("Run"): diff --git a/pages/components/model_selector.py b/pages/components/model_selector.py new file mode 100644 index 0000000..832367f --- /dev/null +++ b/pages/components/model_selector.py @@ -0,0 +1,35 @@ +from typing import Literal + +import streamlit as st + +from open_notebook.domain.models import Model + + +def model_selector( + label, + key, + selected_id=None, + help=None, + model_type: Literal[ + "language", "embedding", "speech_to_text", "text_to_speech" + ] = "language", +) -> Model: + models = Model.get_models_by_type(model_type) + models.sort(key=lambda x: (x.provider, x.name)) + try: + index = ( + next((i for i, m in enumerate(models) if m.id == selected_id), 0) + if selected_id + else 0 + ) + except Exception: + index = 0 + + return st.selectbox( + label, + models, + format_func=lambda x: f"{x.provider} - {x.name}", + help=help, + index=index, + key=key, + ) diff --git a/pages/stream_app/utils.py b/pages/stream_app/utils.py index 2caae54..db5df6d 100644 --- a/pages/stream_app/utils.py +++ b/pages/stream_app/utils.py @@ -116,34 +116,18 @@ def check_migration(): def check_models(): default_models = model_manager.defaults - if ( - not default_models.default_chat_model - or not default_models.default_transformation_model + if not all( + [ + default_models.default_chat_model, + default_models.default_transformation_model, + default_models.default_embedding_model, + default_models.default_speech_to_text_model, + default_models.large_context_model, + ] ): st.warning( - "You don't have default chat and transformation models selected. Please, select them on the settings page." + "You are missing some default models and the app might not work as expected. Please, select them on the settings page." ) - st.stop() - elif not default_models.default_embedding_model: - st.warning( - "You don't have a default embedding model selected. Vector search will not be possible and your assistant will be less able to answer your queries. Please, select one on the settings page." - ) - st.stop() - elif not default_models.default_speech_to_text_model: - st.warning( - "You don't have a default speech to text model selected. Your assistant will not be able to transcribe audio. Please, select one on the settings page." - ) - st.stop() - elif not default_models.default_text_to_speech_model: - st.warning( - "You don't have a default text to speech model selected. Your assistant will not be able to generate audio and podcasts. Please, select one on the settings page." - ) - st.stop() - elif not default_models.large_context_model: - st.warning( - "You don't have a large context model selected. Your assistant will not be able to process large documents. Please, select one on the settings page." - ) - st.stop() def handle_error(func):