add concept of tools model

This commit is contained in:
LUIS NOVO 2024-11-04 09:50:15 -03:00
parent 1f23ba5490
commit 62cd5a9dfb
2 changed files with 20 additions and 5 deletions

View file

@ -36,6 +36,7 @@ class DefaultModels(RecordModel):
default_speech_to_text_model: Optional[str] = None
# default_vision_model: Optional[str] = None
default_embedding_model: Optional[str] = None
default_tools_model: Optional[str] = None
class ModelManager:
@ -94,7 +95,7 @@ class ModelManager:
def refresh_defaults(self):
"""Refresh the default models from the database"""
self._default_models = DefaultModels.load()
self._default_models = DefaultModels()
@property
def defaults(self) -> DefaultModels:
@ -146,6 +147,10 @@ class ModelManager:
self.defaults.default_transformation_model
or self.defaults.default_chat_model
)
elif model_type == "tools":
model_id = (
self.defaults.default_tools_model or self.defaults.default_chat_model
)
elif model_type == "embedding":
model_id = self.defaults.default_embedding_model
elif model_type == "text_to_speech":

View file

@ -2,7 +2,7 @@ import os
import streamlit as st
from open_notebook.domain.models import DefaultModels, Model
from open_notebook.domain.models import DefaultModels, Model, model_manager
from open_notebook.models import MODEL_CLASS_MAP
from pages.stream_app.utils import setup_page
@ -118,7 +118,7 @@ def get_selected_index(models, model_id, default=0):
with model_defaults_tab:
default_models = DefaultModels.load().model_dump()
default_models = DefaultModels().model_dump()
all_models = Model.get_all()
text_generation_models = [model for model in all_models if model.type == "language"]
@ -154,7 +154,16 @@ with model_defaults_tab:
text_generation_models, default_models.get("default_transformation_model")
),
)
st.caption("You can override this model on individual transformations")
st.divider()
defs["default_tools_model"] = st.selectbox(
"Default Tools Model",
text_generation_models,
format_func=lambda x: x.name,
help="This model will be used for calling tools. Currently, it's best to use Open AI and Anthropic for this.",
index=get_selected_index(
text_generation_models, default_models.get("default_tools_model")
),
)
st.divider()
defs["large_context_model"] = st.selectbox(
"Large Context Model",
@ -216,4 +225,5 @@ with model_defaults_tab:
for k, v in defs.items():
if v:
defs[k] = v.id
DefaultModels.update(defs)
DefaultModels().update(defs)
model_manager.refresh_defaults()