model selector and model suggestions
This commit is contained in:
parent
80353a97c9
commit
9ba5709a3c
8 changed files with 182 additions and 115 deletions
|
|
@ -112,8 +112,6 @@ class ObjectModel(BaseModel):
|
|||
from open_notebook.domain.models import model_manager
|
||||
from open_notebook.models import EmbeddingModel
|
||||
|
||||
EMBEDDING_MODEL: EmbeddingModel = model_manager.embedding_model
|
||||
|
||||
try:
|
||||
self.model_validate(self.model_dump(), strict=True)
|
||||
data = self._prepare_save_data()
|
||||
|
|
@ -122,6 +120,7 @@ class ObjectModel(BaseModel):
|
|||
if self.needs_embedding():
|
||||
embedding_content = self.get_embedding_content()
|
||||
if embedding_content:
|
||||
EMBEDDING_MODEL: EmbeddingModel = model_manager.embedding_model
|
||||
data["embedding"] = EMBEDDING_MODEL.embed(embedding_content)
|
||||
|
||||
if self.id is None:
|
||||
|
|
|
|||
|
|
@ -68,7 +68,9 @@ class ModelManager:
|
|||
)
|
||||
return cached_model
|
||||
|
||||
assert model_id, "Model ID cannot be empty"
|
||||
if not model_id:
|
||||
return None
|
||||
|
||||
model: Model = Model.get(model_id)
|
||||
|
||||
if not model:
|
||||
|
|
@ -160,9 +162,6 @@ class ModelManager:
|
|||
elif model_type == "large_context":
|
||||
model_id = self.defaults.large_context_model
|
||||
|
||||
if not model_id:
|
||||
raise ValueError(f"No default model configured for type: {model_type}")
|
||||
|
||||
return self.get_model(model_id, **kwargs)
|
||||
|
||||
def clear_cache(self):
|
||||
|
|
|
|||
|
|
@ -10,4 +10,28 @@ youtube_transcripts:
|
|||
- fr
|
||||
- de
|
||||
- hi
|
||||
- ja
|
||||
- ja
|
||||
suggested_models:
|
||||
openai:
|
||||
language:
|
||||
- gpt-4o-mini
|
||||
embedding:
|
||||
- text-embedding-3-small
|
||||
text_to_speech:
|
||||
- tts-1-hd
|
||||
speech_to_text:
|
||||
- whisper-1
|
||||
gemini:
|
||||
language:
|
||||
- gemini-1.5-flash
|
||||
text_to_speech:
|
||||
- default
|
||||
xai:
|
||||
language:
|
||||
- grok-beta
|
||||
anthropic:
|
||||
language:
|
||||
- claude-3-5-sonnet-20241022
|
||||
elevenlabs:
|
||||
text_to_speech:
|
||||
- eleven_turbo_v2_5
|
||||
|
|
@ -2,9 +2,10 @@ import asyncio
|
|||
|
||||
import streamlit as st
|
||||
|
||||
from open_notebook.domain.models import Model
|
||||
from open_notebook.domain.models import DefaultModels
|
||||
from open_notebook.domain.notebook import Note, Notebook, text_search, vector_search
|
||||
from open_notebook.graphs.ask import graph as ask_graph
|
||||
from pages.components.model_selector import model_selector
|
||||
from pages.stream_app.utils import convert_source_references, setup_page
|
||||
|
||||
setup_page("🔍 Search")
|
||||
|
|
@ -52,23 +53,26 @@ with ask_tab:
|
|||
"The LLM will answer your query based on the documents in your knowledge base. "
|
||||
)
|
||||
question = st.text_input("Question", "")
|
||||
models = Model.get_models_by_type("language")
|
||||
strategy_model: Model = st.selectbox(
|
||||
default_model = DefaultModels().load().default_chat_model
|
||||
strategy_model = model_selector(
|
||||
"Query Strategy Model",
|
||||
models,
|
||||
format_func=lambda x: x.name,
|
||||
"strategy_model",
|
||||
selected_id=default_model,
|
||||
model_type="language",
|
||||
help="This is the LLM that will be responsible for strategizing the search",
|
||||
)
|
||||
answer_model: Model = st.selectbox(
|
||||
"Indivual Answer Model",
|
||||
models,
|
||||
format_func=lambda x: x.name,
|
||||
answer_model = model_selector(
|
||||
"Individual Answer Model",
|
||||
"answer_model",
|
||||
model_type="language",
|
||||
selected_id=default_model,
|
||||
help="This is the LLM that will be responsible for processing individual subqueries",
|
||||
)
|
||||
final_answer_model: Model = st.selectbox(
|
||||
final_answer_model = model_selector(
|
||||
"Final Answer Model",
|
||||
models,
|
||||
format_func=lambda x: x.name,
|
||||
"final_answer_model",
|
||||
model_type="language",
|
||||
selected_id=default_model,
|
||||
help="This is the LLM that will be responsible for processing the final answer",
|
||||
)
|
||||
ask_bt = st.button("Ask")
|
||||
|
|
|
|||
|
|
@ -2,9 +2,11 @@ import os
|
|||
|
||||
import streamlit as st
|
||||
|
||||
from open_notebook.config import CONFIG
|
||||
from open_notebook.domain.models import DefaultModels, Model, model_manager
|
||||
from open_notebook.domain.transformation import DefaultTransformations, Transformation
|
||||
from open_notebook.models import MODEL_CLASS_MAP
|
||||
from pages.components.model_selector import model_selector
|
||||
from pages.stream_app.utils import setup_page
|
||||
|
||||
setup_page("⚙️ Settings")
|
||||
|
|
@ -59,8 +61,41 @@ provider_status["litellm"] = (
|
|||
available_providers = [k for k, v in provider_status.items() if v]
|
||||
unavailable_providers = [k for k, v in provider_status.items() if not v]
|
||||
|
||||
|
||||
def generate_new_models(models, suggested_models):
|
||||
# Create a set of existing model keys for efficient lookup
|
||||
existing_model_keys = {
|
||||
f"{model.provider}-{model.name}-{model.type}" for model in models
|
||||
}
|
||||
|
||||
new_models = []
|
||||
|
||||
# Iterate through suggested models by provider
|
||||
for provider, types in suggested_models.items():
|
||||
# Iterate through each type (language, embedding, etc.)
|
||||
for type_, model_list in types.items():
|
||||
for model_name in model_list:
|
||||
model_key = f"{provider}-{model_name}-{type_}"
|
||||
|
||||
# Check if model already exists
|
||||
if model_key not in existing_model_keys:
|
||||
new_models.append(
|
||||
{
|
||||
"name": model_name,
|
||||
"type": type_,
|
||||
"provider": provider,
|
||||
}
|
||||
)
|
||||
|
||||
return new_models
|
||||
|
||||
|
||||
default_models = DefaultModels().model_dump()
|
||||
all_models = Model.get_all()
|
||||
|
||||
with model_tab:
|
||||
st.subheader("Add Model")
|
||||
|
||||
provider = st.selectbox("Provider", available_providers)
|
||||
if len(unavailable_providers) > 0:
|
||||
st.caption(
|
||||
|
|
@ -92,8 +127,20 @@ with model_tab:
|
|||
model = Model(name=model_name, provider=provider, type=model_type)
|
||||
model.save()
|
||||
st.success("Saved")
|
||||
|
||||
st.divider()
|
||||
all_models = Model.get_all()
|
||||
suggested_models = CONFIG.get("suggested_models", [])
|
||||
recommendations = generate_new_models(all_models, suggested_models)
|
||||
if len(recommendations) > 0:
|
||||
with st.expander("💁♂️ Recommended models to get you started.."):
|
||||
for recommendation in recommendations:
|
||||
st.markdown(
|
||||
f"**{recommendation['name']}** ({recommendation['provider']}, {recommendation['type']})"
|
||||
)
|
||||
if st.button("Add", key=f"add_{recommendation['name']}"):
|
||||
new_model = Model(**recommendation)
|
||||
new_model.save()
|
||||
st.rerun()
|
||||
st.subheader("Configured Models")
|
||||
model_types_available = {
|
||||
# "vision": False,
|
||||
|
|
@ -114,20 +161,7 @@ with model_tab:
|
|||
if not available:
|
||||
st.warning(f"No models available for {model_type}")
|
||||
|
||||
|
||||
def get_selected_index(models, model_id, default=0):
|
||||
"""Returns the index of the selected model in the list of models"""
|
||||
if not model_id or not models:
|
||||
return default
|
||||
for i, model in enumerate(models):
|
||||
if model.id == model_id:
|
||||
return i
|
||||
return default
|
||||
|
||||
|
||||
with model_defaults_tab:
|
||||
default_models = DefaultModels().model_dump()
|
||||
all_models = Model.get_all()
|
||||
text_generation_models = [model for model in all_models if model.type == "language"]
|
||||
|
||||
text_to_speech_models = [
|
||||
|
|
@ -143,93 +177,80 @@ with model_defaults_tab:
|
|||
"In this section, you can select the default models to be used on the various content operations done by Open Notebook. Some of these can be overriden in the different modules."
|
||||
)
|
||||
defs = {}
|
||||
defs["default_chat_model"] = st.selectbox(
|
||||
defs["default_chat_model"] = model_selector(
|
||||
"Default Chat Model",
|
||||
text_generation_models,
|
||||
format_func=lambda x: x.name,
|
||||
"default_chat_model",
|
||||
selected_id=default_models.get("default_chat_model"),
|
||||
help="This model will be used for chat.",
|
||||
index=get_selected_index(
|
||||
text_generation_models, default_models.get("default_chat_model")
|
||||
),
|
||||
model_type="language",
|
||||
)
|
||||
st.divider()
|
||||
defs["default_transformation_model"] = st.selectbox(
|
||||
defs["default_transformation_model"] = model_selector(
|
||||
"Default Transformation Model",
|
||||
text_generation_models,
|
||||
format_func=lambda x: x.name,
|
||||
"default_transformation_model",
|
||||
selected_id=default_models.get("default_transformation_model"),
|
||||
help="This model will be used for text transformations such as summaries, insights, etc.",
|
||||
index=get_selected_index(
|
||||
text_generation_models, default_models.get("default_transformation_model")
|
||||
),
|
||||
model_type="language",
|
||||
)
|
||||
st.caption("You can use a cheap model here like gpt-4o-mini, llama3, etc.")
|
||||
st.divider()
|
||||
defs["default_tools_model"] = st.selectbox(
|
||||
defs["default_tools_model"] = model_selector(
|
||||
"Default Tools Model",
|
||||
text_generation_models,
|
||||
format_func=lambda x: x.name,
|
||||
"default_tools_model",
|
||||
selected_id=default_models.get("default_tools_model"),
|
||||
help="This model will be used for calling tools. Currently, it's best to use Open AI and Anthropic for this.",
|
||||
index=get_selected_index(
|
||||
text_generation_models, default_models.get("default_tools_model")
|
||||
),
|
||||
model_type="language",
|
||||
)
|
||||
st.caption("Recommended to use a capable model here, like gpt-4o, claude, etc.")
|
||||
st.divider()
|
||||
defs["large_context_model"] = st.selectbox(
|
||||
defs["large_context_model"] = model_selector(
|
||||
"Large Context Model",
|
||||
text_generation_models,
|
||||
format_func=lambda x: x.name,
|
||||
"large_context_model",
|
||||
selected_id=default_models.get("large_context_model"),
|
||||
help="This model will be used for larger context generation -- recommended: Gemini",
|
||||
index=get_selected_index(
|
||||
text_generation_models, default_models.get("large_context_model")
|
||||
),
|
||||
model_type="language",
|
||||
)
|
||||
st.caption("Recommended to use Gemini models for larger context processing")
|
||||
st.divider()
|
||||
defs["default_text_to_speech_model"] = st.selectbox(
|
||||
defs["default_text_to_speech_model"] = model_selector(
|
||||
"Default Text to Speech Model",
|
||||
text_to_speech_models,
|
||||
format_func=lambda x: x.name,
|
||||
"default_text_to_speech_model",
|
||||
selected_id=default_models.get("default_text_to_speech_model"),
|
||||
help="This is the default model for converting text to speech (podcasts, etc)",
|
||||
index=get_selected_index(
|
||||
text_to_speech_models, default_models.get("default_text_to_speech_model")
|
||||
),
|
||||
model_type="text_to_speech",
|
||||
)
|
||||
st.caption("You can override this model on different podcasts")
|
||||
st.divider()
|
||||
defs["default_speech_to_text_model"] = st.selectbox(
|
||||
defs["default_speech_to_text_model"] = model_selector(
|
||||
"Default Speech to Text Model",
|
||||
speech_to_text_models,
|
||||
format_func=lambda x: x.name,
|
||||
"default_speech_to_text_model",
|
||||
selected_id=default_models.get("default_speech_to_text_model"),
|
||||
help="This is the default model for converting speech to text (audio transcriptions, etc)",
|
||||
index=get_selected_index(
|
||||
speech_to_text_models, default_models.get("default_speech_to_text_model")
|
||||
),
|
||||
model_type="speech_to_text",
|
||||
)
|
||||
|
||||
st.divider()
|
||||
# defs["default_vision_model"] = st.selectbox(
|
||||
# "Default Vision Model",
|
||||
# vision_models,
|
||||
# format_func=lambda x: x.name,
|
||||
# help="This is the default model for vision tasks (image recognition, PDF recognition, etc)",
|
||||
# index=get_selected_index(
|
||||
# vision_models, default_models.get("default_vision_model")
|
||||
# defs["default_vision_model"] = (
|
||||
# model_selector(
|
||||
# "Default Speech to Text Model",
|
||||
# "default_vision_model",
|
||||
# selected_id=default_models.get("default_vision_model"),
|
||||
# help="This is the default model for vision tasks",
|
||||
# model_type="vision",
|
||||
# ),
|
||||
# )
|
||||
# st.divider()
|
||||
|
||||
defs["default_embedding_model"] = st.selectbox(
|
||||
"Default Embedding Model",
|
||||
embedding_models,
|
||||
format_func=lambda x: x.name,
|
||||
defs["default_embedding_model"] = model_selector(
|
||||
"Default Speech to Text Model",
|
||||
"default_embedding_model",
|
||||
selected_id=default_models.get("default_embedding_model"),
|
||||
help="This is the default model for embeddings (semantic search, etc)",
|
||||
index=get_selected_index(
|
||||
embedding_models, default_models.get("default_embedding_model")
|
||||
),
|
||||
model_type="embedding",
|
||||
)
|
||||
st.caption(
|
||||
"Caution: you cannot change the embedding model once there is embeddings or they will need to be regenerated"
|
||||
)
|
||||
|
||||
# if st.button("Save Defaults", key="save_defaults"):
|
||||
for k, v in defs.items():
|
||||
if v:
|
||||
defs[k] = v.id
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import streamlit as st
|
||||
import yaml
|
||||
|
||||
from open_notebook.domain.models import Model
|
||||
from open_notebook.graphs.multipattern import graph as pattern_graph
|
||||
from pages.components.model_selector import model_selector
|
||||
from pages.stream_app.utils import setup_page
|
||||
|
||||
setup_page("🛝 Playground")
|
||||
|
|
@ -22,12 +22,13 @@ transformation = st.selectbox(
|
|||
with st.expander("Details"):
|
||||
st.json(transformation)
|
||||
|
||||
models = Model.get_models_by_type("language")
|
||||
model = st.selectbox(
|
||||
model = model_selector(
|
||||
"Pick a pattern model",
|
||||
models,
|
||||
format_func=lambda x: x.name,
|
||||
key="model",
|
||||
help="This is the model that will be used to run the transformation",
|
||||
model_type="language",
|
||||
)
|
||||
|
||||
input_text = st.text_area("Enter some text", height=200)
|
||||
|
||||
if st.button("Run"):
|
||||
|
|
|
|||
35
pages/components/model_selector.py
Normal file
35
pages/components/model_selector.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
from typing import Literal
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from open_notebook.domain.models import Model
|
||||
|
||||
|
||||
def model_selector(
|
||||
label,
|
||||
key,
|
||||
selected_id=None,
|
||||
help=None,
|
||||
model_type: Literal[
|
||||
"language", "embedding", "speech_to_text", "text_to_speech"
|
||||
] = "language",
|
||||
) -> Model:
|
||||
models = Model.get_models_by_type(model_type)
|
||||
models.sort(key=lambda x: (x.provider, x.name))
|
||||
try:
|
||||
index = (
|
||||
next((i for i, m in enumerate(models) if m.id == selected_id), 0)
|
||||
if selected_id
|
||||
else 0
|
||||
)
|
||||
except Exception:
|
||||
index = 0
|
||||
|
||||
return st.selectbox(
|
||||
label,
|
||||
models,
|
||||
format_func=lambda x: f"{x.provider} - {x.name}",
|
||||
help=help,
|
||||
index=index,
|
||||
key=key,
|
||||
)
|
||||
|
|
@ -116,34 +116,18 @@ def check_migration():
|
|||
|
||||
def check_models():
|
||||
default_models = model_manager.defaults
|
||||
if (
|
||||
not default_models.default_chat_model
|
||||
or not default_models.default_transformation_model
|
||||
if not all(
|
||||
[
|
||||
default_models.default_chat_model,
|
||||
default_models.default_transformation_model,
|
||||
default_models.default_embedding_model,
|
||||
default_models.default_speech_to_text_model,
|
||||
default_models.large_context_model,
|
||||
]
|
||||
):
|
||||
st.warning(
|
||||
"You don't have default chat and transformation models selected. Please, select them on the settings page."
|
||||
"You are missing some default models and the app might not work as expected. Please, select them on the settings page."
|
||||
)
|
||||
st.stop()
|
||||
elif not default_models.default_embedding_model:
|
||||
st.warning(
|
||||
"You don't have a default embedding model selected. Vector search will not be possible and your assistant will be less able to answer your queries. Please, select one on the settings page."
|
||||
)
|
||||
st.stop()
|
||||
elif not default_models.default_speech_to_text_model:
|
||||
st.warning(
|
||||
"You don't have a default speech to text model selected. Your assistant will not be able to transcribe audio. Please, select one on the settings page."
|
||||
)
|
||||
st.stop()
|
||||
elif not default_models.default_text_to_speech_model:
|
||||
st.warning(
|
||||
"You don't have a default text to speech model selected. Your assistant will not be able to generate audio and podcasts. Please, select one on the settings page."
|
||||
)
|
||||
st.stop()
|
||||
elif not default_models.large_context_model:
|
||||
st.warning(
|
||||
"You don't have a large context model selected. Your assistant will not be able to process large documents. Please, select one on the settings page."
|
||||
)
|
||||
st.stop()
|
||||
|
||||
|
||||
def handle_error(func):
|
||||
|
|
|
|||
Loading…
Reference in a new issue