diff --git a/open_notebook/models/__init__.py b/open_notebook/models/__init__.py index dc2cf36..f0a1e27 100644 --- a/open_notebook/models/__init__.py +++ b/open_notebook/models/__init__.py @@ -30,6 +30,7 @@ MODEL_CLASS_MAP = { "speech_to_text": { "openai": OpenAISpeechToTextModel, }, + "text_to_speech": {"openai": None, "elevenlabs": None}, } diff --git a/pages/9_⚙️_Settings.py b/pages/9_⚙️_Settings.py index 110bf7c..aa76b50 100644 --- a/pages/9_⚙️_Settings.py +++ b/pages/9_⚙️_Settings.py @@ -3,6 +3,7 @@ import os import streamlit as st from open_notebook.domain.models import DefaultModels, Model +from open_notebook.models import MODEL_CLASS_MAP from stream_app.utils import version_sidebar st.set_page_config( @@ -19,10 +20,10 @@ provider_status = {} model_types = [ # "vision", - "text generation", + "language", "embedding", - "text to speech", - "speech to text", + "text_to_speech", + "speech_to_text", ] provider_status["ollama"] = os.environ.get("OLLAMA_API_BASE") is not None @@ -65,20 +66,30 @@ with model_tab: f"Unavailable Providers: {', '.join(unavailable_providers)}. Please check docs page if you wish to enable them." ) model_name = st.text_input("Model Name", "") - model_type = st.selectbox("Model Type", model_types) - if st.button("Save"): - model = Model(name=model_name, provider=provider, type=model_type) - model.save() - st.success("Saved") + + # Filter model types based on provider availability in MODEL_CLASS_MAP + available_model_types = [] + for model_type in model_types: + if model_type in MODEL_CLASS_MAP and provider in MODEL_CLASS_MAP[model_type]: + available_model_types.append(model_type) + + if not available_model_types: + st.error(f"No compatible model types available for provider: {provider}") + else: + model_type = st.selectbox("Model Type", available_model_types) + if st.button("Save"): + model = Model(name=model_name, provider=provider, type=model_type) + model.save() + st.success("Saved") st.divider() all_models = Model.get_all() st.subheader("Configured Models") model_types_available = { # "vision": False, - "text generation": False, + "language": False, "embedding": False, - "text to speech": False, - "speech to text": False, + "text_to_speech": False, + "speech_to_text": False, } for model in all_models: model_types_available[model.type] = True @@ -107,16 +118,14 @@ def get_selected_index(models, model_id, default=0): with model_defaults_tab: default_models = DefaultModels.load().model_dump() all_models = Model.get_all() - text_generation_models = [ - model for model in all_models if model.type == "text generation" - ] + text_generation_models = [model for model in all_models if model.type == "language"] text_to_speech_models = [ - model for model in all_models if model.type == "text to speech" + model for model in all_models if model.type == "text_to_speech" ] speech_to_text_models = [ - model for model in all_models if model.type == "speech to text" + model for model in all_models if model.type == "speech_to_text" ] vision_models = [model for model in all_models if model.type == "vision"] embedding_models = [model for model in all_models if model.type == "embedding"]