prevent users from picking a model type not supported by the provider

This commit is contained in:
LUIS NOVO 2024-10-30 14:44:48 -03:00
parent aca75221a7
commit 99f177044f
2 changed files with 26 additions and 16 deletions

View file

@ -30,6 +30,7 @@ MODEL_CLASS_MAP = {
"speech_to_text": {
"openai": OpenAISpeechToTextModel,
},
"text_to_speech": {"openai": None, "elevenlabs": None},
}

View file

@ -3,6 +3,7 @@ import os
import streamlit as st
from open_notebook.domain.models import DefaultModels, Model
from open_notebook.models import MODEL_CLASS_MAP
from stream_app.utils import version_sidebar
st.set_page_config(
@ -19,10 +20,10 @@ provider_status = {}
model_types = [
# "vision",
"text generation",
"language",
"embedding",
"text to speech",
"speech to text",
"text_to_speech",
"speech_to_text",
]
provider_status["ollama"] = os.environ.get("OLLAMA_API_BASE") is not None
@ -65,20 +66,30 @@ with model_tab:
f"Unavailable Providers: {', '.join(unavailable_providers)}. Please check docs page if you wish to enable them."
)
model_name = st.text_input("Model Name", "")
model_type = st.selectbox("Model Type", model_types)
if st.button("Save"):
model = Model(name=model_name, provider=provider, type=model_type)
model.save()
st.success("Saved")
# Filter model types based on provider availability in MODEL_CLASS_MAP
available_model_types = []
for model_type in model_types:
if model_type in MODEL_CLASS_MAP and provider in MODEL_CLASS_MAP[model_type]:
available_model_types.append(model_type)
if not available_model_types:
st.error(f"No compatible model types available for provider: {provider}")
else:
model_type = st.selectbox("Model Type", available_model_types)
if st.button("Save"):
model = Model(name=model_name, provider=provider, type=model_type)
model.save()
st.success("Saved")
st.divider()
all_models = Model.get_all()
st.subheader("Configured Models")
model_types_available = {
# "vision": False,
"text generation": False,
"language": False,
"embedding": False,
"text to speech": False,
"speech to text": False,
"text_to_speech": False,
"speech_to_text": False,
}
for model in all_models:
model_types_available[model.type] = True
@ -107,16 +118,14 @@ def get_selected_index(models, model_id, default=0):
with model_defaults_tab:
default_models = DefaultModels.load().model_dump()
all_models = Model.get_all()
text_generation_models = [
model for model in all_models if model.type == "text generation"
]
text_generation_models = [model for model in all_models if model.type == "language"]
text_to_speech_models = [
model for model in all_models if model.type == "text to speech"
model for model in all_models if model.type == "text_to_speech"
]
speech_to_text_models = [
model for model in all_models if model.type == "speech to text"
model for model in all_models if model.type == "speech_to_text"
]
vision_models = [model for model in all_models if model.type == "vision"]
embedding_models = [model for model in all_models if model.type == "embedding"]