diff --git a/open_notebook/domain/models.py b/open_notebook/domain/models.py index d66ad19..9cb84ca 100644 --- a/open_notebook/domain/models.py +++ b/open_notebook/domain/models.py @@ -36,6 +36,7 @@ class DefaultModels(RecordModel): default_speech_to_text_model: Optional[str] = None # default_vision_model: Optional[str] = None default_embedding_model: Optional[str] = None + default_tools_model: Optional[str] = None class ModelManager: @@ -94,7 +95,7 @@ class ModelManager: def refresh_defaults(self): """Refresh the default models from the database""" - self._default_models = DefaultModels.load() + self._default_models = DefaultModels() @property def defaults(self) -> DefaultModels: @@ -146,6 +147,10 @@ class ModelManager: self.defaults.default_transformation_model or self.defaults.default_chat_model ) + elif model_type == "tools": + model_id = ( + self.defaults.default_tools_model or self.defaults.default_chat_model + ) elif model_type == "embedding": model_id = self.defaults.default_embedding_model elif model_type == "text_to_speech": diff --git a/pages/7_⚙️_Settings.py b/pages/7_⚙️_Settings.py index 0d93594..aad2132 100644 --- a/pages/7_⚙️_Settings.py +++ b/pages/7_⚙️_Settings.py @@ -2,7 +2,7 @@ import os import streamlit as st -from open_notebook.domain.models import DefaultModels, Model +from open_notebook.domain.models import DefaultModels, Model, model_manager from open_notebook.models import MODEL_CLASS_MAP from pages.stream_app.utils import setup_page @@ -118,7 +118,7 @@ def get_selected_index(models, model_id, default=0): with model_defaults_tab: - default_models = DefaultModels.load().model_dump() + default_models = DefaultModels().model_dump() all_models = Model.get_all() text_generation_models = [model for model in all_models if model.type == "language"] @@ -154,7 +154,16 @@ with model_defaults_tab: text_generation_models, default_models.get("default_transformation_model") ), ) - st.caption("You can override this model on individual transformations") + st.divider() + defs["default_tools_model"] = st.selectbox( + "Default Tools Model", + text_generation_models, + format_func=lambda x: x.name, + help="This model will be used for calling tools. Currently, it's best to use Open AI and Anthropic for this.", + index=get_selected_index( + text_generation_models, default_models.get("default_tools_model") + ), + ) st.divider() defs["large_context_model"] = st.selectbox( "Large Context Model", @@ -216,4 +225,5 @@ with model_defaults_tab: for k, v in defs.items(): if v: defs[k] = v.id - DefaultModels.update(defs) + DefaultModels().update(defs) + model_manager.refresh_defaults()