prevent users from picking a model type not supported by the provider
This commit is contained in:
parent
aca75221a7
commit
99f177044f
2 changed files with 26 additions and 16 deletions
|
|
@ -30,6 +30,7 @@ MODEL_CLASS_MAP = {
|
|||
"speech_to_text": {
|
||||
"openai": OpenAISpeechToTextModel,
|
||||
},
|
||||
"text_to_speech": {"openai": None, "elevenlabs": None},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import os
|
|||
import streamlit as st
|
||||
|
||||
from open_notebook.domain.models import DefaultModels, Model
|
||||
from open_notebook.models import MODEL_CLASS_MAP
|
||||
from stream_app.utils import version_sidebar
|
||||
|
||||
st.set_page_config(
|
||||
|
|
@ -19,10 +20,10 @@ provider_status = {}
|
|||
|
||||
model_types = [
|
||||
# "vision",
|
||||
"text generation",
|
||||
"language",
|
||||
"embedding",
|
||||
"text to speech",
|
||||
"speech to text",
|
||||
"text_to_speech",
|
||||
"speech_to_text",
|
||||
]
|
||||
|
||||
provider_status["ollama"] = os.environ.get("OLLAMA_API_BASE") is not None
|
||||
|
|
@ -65,20 +66,30 @@ with model_tab:
|
|||
f"Unavailable Providers: {', '.join(unavailable_providers)}. Please check docs page if you wish to enable them."
|
||||
)
|
||||
model_name = st.text_input("Model Name", "")
|
||||
model_type = st.selectbox("Model Type", model_types)
|
||||
if st.button("Save"):
|
||||
model = Model(name=model_name, provider=provider, type=model_type)
|
||||
model.save()
|
||||
st.success("Saved")
|
||||
|
||||
# Filter model types based on provider availability in MODEL_CLASS_MAP
|
||||
available_model_types = []
|
||||
for model_type in model_types:
|
||||
if model_type in MODEL_CLASS_MAP and provider in MODEL_CLASS_MAP[model_type]:
|
||||
available_model_types.append(model_type)
|
||||
|
||||
if not available_model_types:
|
||||
st.error(f"No compatible model types available for provider: {provider}")
|
||||
else:
|
||||
model_type = st.selectbox("Model Type", available_model_types)
|
||||
if st.button("Save"):
|
||||
model = Model(name=model_name, provider=provider, type=model_type)
|
||||
model.save()
|
||||
st.success("Saved")
|
||||
st.divider()
|
||||
all_models = Model.get_all()
|
||||
st.subheader("Configured Models")
|
||||
model_types_available = {
|
||||
# "vision": False,
|
||||
"text generation": False,
|
||||
"language": False,
|
||||
"embedding": False,
|
||||
"text to speech": False,
|
||||
"speech to text": False,
|
||||
"text_to_speech": False,
|
||||
"speech_to_text": False,
|
||||
}
|
||||
for model in all_models:
|
||||
model_types_available[model.type] = True
|
||||
|
|
@ -107,16 +118,14 @@ def get_selected_index(models, model_id, default=0):
|
|||
with model_defaults_tab:
|
||||
default_models = DefaultModels.load().model_dump()
|
||||
all_models = Model.get_all()
|
||||
text_generation_models = [
|
||||
model for model in all_models if model.type == "text generation"
|
||||
]
|
||||
text_generation_models = [model for model in all_models if model.type == "language"]
|
||||
|
||||
text_to_speech_models = [
|
||||
model for model in all_models if model.type == "text to speech"
|
||||
model for model in all_models if model.type == "text_to_speech"
|
||||
]
|
||||
|
||||
speech_to_text_models = [
|
||||
model for model in all_models if model.type == "speech to text"
|
||||
model for model in all_models if model.type == "speech_to_text"
|
||||
]
|
||||
vision_models = [model for model in all_models if model.type == "vision"]
|
||||
embedding_models = [model for model in all_models if model.type == "embedding"]
|
||||
|
|
|
|||
Loading…
Reference in a new issue