implement model config
This commit is contained in:
parent
63a568490e
commit
8bb5db158f
19 changed files with 434 additions and 105 deletions
|
|
@ -3,6 +3,10 @@ import os
|
|||
import yaml
|
||||
from loguru import logger
|
||||
|
||||
from open_notebook.domain.models import DefaultModels
|
||||
from open_notebook.models.embedding_models import get_embedding_model
|
||||
from open_notebook.models.speech_to_text_models import get_speech_to_text_model
|
||||
|
||||
# todo: enable config file overwrite with env vars
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.dirname(current_dir)
|
||||
|
|
@ -32,3 +36,12 @@ os.makedirs(UPLOADS_FOLDER, exist_ok=True)
|
|||
# PODCASTS FOLDER
|
||||
PODCASTS_FOLDER = f"{DATA_FOLDER}/podcasts"
|
||||
os.makedirs(PODCASTS_FOLDER, exist_ok=True)
|
||||
|
||||
|
||||
DEFAULT_MODELS = DefaultModels.load()
|
||||
|
||||
EMBEDDING_MODEL = get_embedding_model(DEFAULT_MODELS.default_embedding_model)
|
||||
|
||||
SPEECH_TO_TEXT_MODEL = get_speech_to_text_model(
|
||||
DEFAULT_MODELS.default_speech_to_text_model
|
||||
)
|
||||
|
|
|
|||
38
open_notebook/domain/models.py
Normal file
38
open_notebook/domain/models.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
from typing import ClassVar, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from open_notebook.domain.base import ObjectModel
|
||||
from open_notebook.repository import (
|
||||
repo_query,
|
||||
repo_update,
|
||||
)
|
||||
|
||||
|
||||
class Model(ObjectModel):
|
||||
table_name: ClassVar[str] = "model"
|
||||
name: str
|
||||
provider: str
|
||||
type: str
|
||||
|
||||
|
||||
class DefaultModels(BaseModel):
|
||||
default_chat_model: Optional[str] = None
|
||||
default_transformation_model: Optional[str] = None
|
||||
large_context_model: Optional[str] = None
|
||||
default_text_to_speech_model: Optional[str] = None
|
||||
default_speech_to_text_model: Optional[str] = None
|
||||
# default_vision_model: Optional[str] = None
|
||||
default_embedding_model: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def load(self):
|
||||
result = repo_query("SELECT * FROM open_notebook:default_models;")
|
||||
if result:
|
||||
logger.debug(result)
|
||||
return DefaultModels(**result[0])
|
||||
|
||||
@classmethod
|
||||
def update(self, data):
|
||||
repo_update("open_notebook:default_models", data)
|
||||
|
|
@ -9,8 +9,8 @@ from langgraph.graph import END, START, StateGraph
|
|||
from langgraph.graph.message import add_messages
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE
|
||||
from open_notebook.domain import Notebook
|
||||
from open_notebook.config import DEFAULT_MODELS, LANGGRAPH_CHECKPOINT_FILE
|
||||
from open_notebook.domain.notebook import Notebook
|
||||
from open_notebook.graphs.utils import run_pattern
|
||||
|
||||
|
||||
|
|
@ -22,7 +22,9 @@ class ThreadState(TypedDict):
|
|||
|
||||
|
||||
def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict:
|
||||
model_name = config.get("configurable", {}).get("model_name", None)
|
||||
model_name = config.get("configurable", {}).get(
|
||||
"model_name", DEFAULT_MODELS.default_chat_model
|
||||
)
|
||||
ai_message = run_pattern(
|
||||
"chat",
|
||||
model_name,
|
||||
|
|
|
|||
|
|
@ -4,9 +4,9 @@ from math import ceil
|
|||
from loguru import logger
|
||||
from pydub import AudioSegment
|
||||
|
||||
from open_notebook.config import SPEECH_TO_TEXT_MODEL
|
||||
from open_notebook.graphs.content_processing.state import SourceState
|
||||
|
||||
# todo: add a speechtotext model to the config
|
||||
# future: parallelize the transcription process
|
||||
|
||||
|
||||
|
|
@ -73,9 +73,6 @@ def split_audio(input_file, segment_length_minutes=15, output_prefix=None):
|
|||
|
||||
def extract_audio(data: SourceState):
|
||||
input_audio_path = data.get("file_path")
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI()
|
||||
audio_files = []
|
||||
|
||||
try:
|
||||
|
|
@ -83,11 +80,7 @@ def extract_audio(data: SourceState):
|
|||
transcriptions = []
|
||||
|
||||
for audio_file in audio_files:
|
||||
with open(audio_file, "rb") as audio:
|
||||
transcription = client.audio.transcriptions.create(
|
||||
model="whisper-1", file=audio
|
||||
)
|
||||
transcriptions.append(transcription.text)
|
||||
transcriptions.append(SPEECH_TO_TEXT_MODEL.transcribe(audio_file))
|
||||
|
||||
return {"content": " ".join(transcriptions)}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from langchain_core.runnables import (
|
|||
from langgraph.graph import END, START, StateGraph
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from open_notebook.domain import Note, Notebook, Source
|
||||
from open_notebook.domain.notebook import Note, Notebook, Source
|
||||
from open_notebook.graphs.utils import run_pattern
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import operator
|
||||
import os
|
||||
from typing import List, Literal, Sequence
|
||||
|
||||
from langchain_core.runnables import (
|
||||
|
|
@ -8,6 +7,7 @@ from langchain_core.runnables import (
|
|||
from langgraph.graph import END, START, StateGraph
|
||||
from typing_extensions import Annotated, TypedDict
|
||||
|
||||
from open_notebook.config import DEFAULT_MODELS
|
||||
from open_notebook.graphs.utils import run_pattern
|
||||
|
||||
|
||||
|
|
@ -19,7 +19,7 @@ class PatternChainState(TypedDict):
|
|||
|
||||
def call_model(state: dict, config: RunnableConfig) -> dict:
|
||||
model_name = config.get("configurable", {}).get(
|
||||
"model_name", os.environ.get("DEFAULT_MODEL")
|
||||
"model_name", DEFAULT_MODELS.default_transformation_model
|
||||
)
|
||||
transformations = state["transformations"]
|
||||
current_transformation = transformations.pop(0)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
import os
|
||||
|
||||
from langchain_core.runnables import (
|
||||
RunnableConfig,
|
||||
)
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from open_notebook.config import DEFAULT_MODELS
|
||||
from open_notebook.graphs.utils import run_pattern
|
||||
|
||||
|
||||
|
|
@ -17,7 +16,7 @@ class PatternState(TypedDict):
|
|||
|
||||
def call_model(state: dict, config: RunnableConfig) -> dict:
|
||||
model_name = config.get("configurable", {}).get(
|
||||
"model_name", os.environ.get("DEFAULT_MODEL")
|
||||
"model_name", DEFAULT_MODELS.default_transformation_model
|
||||
)
|
||||
return {
|
||||
"output": run_pattern(
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from langchain_core.runnables import (
|
|||
from langgraph.graph import END, START, StateGraph
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from open_notebook.config import DEFAULT_MODELS
|
||||
from open_notebook.graphs.utils import run_pattern
|
||||
from open_notebook.utils import split_text
|
||||
|
||||
|
|
@ -49,7 +50,7 @@ def chunk_condition(state: TocState) -> Literal["get_chunk", END]: # type: igno
|
|||
|
||||
def call_model(state: TocState, config: RunnableConfig) -> dict:
|
||||
model_name = config.get("configurable", {}).get(
|
||||
"model_name", os.environ.get("SUMMARIZATION_MODEL")
|
||||
"model_name", DEFAULT_MODELS.default_transformation_model
|
||||
)
|
||||
return {
|
||||
"toc": run_pattern(
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from langgraph.graph import END, START, StateGraph
|
|||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from open_notebook.config import DEFAULT_MODELS
|
||||
from open_notebook.graphs.utils import run_pattern
|
||||
from open_notebook.utils import split_text
|
||||
|
||||
|
|
@ -57,9 +58,9 @@ def chunk_condition(state: SummaryState) -> Literal["get_chunk", END]: # type:
|
|||
return END
|
||||
|
||||
|
||||
def call_model(state: SummaryState, config: RunnableConfig) -> dict:
|
||||
def call_model(state: dict, config: RunnableConfig) -> dict:
|
||||
model_name = config.get("configurable", {}).get(
|
||||
"model_name", os.environ.get("SUMMARIZATION_MODEL")
|
||||
"model_name", DEFAULT_MODELS.default_transformation_model
|
||||
)
|
||||
parser = PydanticOutputParser(pydantic_object=SummaryResponse)
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
import os
|
||||
|
||||
from langchain.output_parsers import OutputFixingParser
|
||||
|
||||
from open_notebook.llm_router import get_langchain_model
|
||||
from open_notebook.config import DEFAULT_MODELS
|
||||
from open_notebook.models.llms import get_langchain_model
|
||||
from open_notebook.prompter import Prompter
|
||||
|
||||
|
||||
|
|
@ -15,7 +14,7 @@ def run_pattern(
|
|||
output_fixing_model_name=None,
|
||||
) -> dict:
|
||||
if not model_name:
|
||||
model_name = os.environ["DEFAULT_MODEL"]
|
||||
model_name = DEFAULT_MODELS.default_transformation_model
|
||||
|
||||
chain = get_langchain_model(model_name)
|
||||
if parser:
|
||||
|
|
|
|||
|
|
@ -1,35 +0,0 @@
|
|||
from open_notebook.llms import (
|
||||
AnthropicLanguageModel,
|
||||
GeminiLanguageModel,
|
||||
LiteLLMLanguageModel,
|
||||
OllamaLanguageModel,
|
||||
OpenAILanguageModel,
|
||||
OpenRouterLanguageModel,
|
||||
VertexAILanguageModel,
|
||||
VertexAnthropicLanguageModel,
|
||||
)
|
||||
|
||||
# Map provider names to classes
|
||||
PROVIDER_CLASS_MAP = {
|
||||
"ollama": OllamaLanguageModel,
|
||||
"openrouter": OpenRouterLanguageModel,
|
||||
"vertexai-anthropic": VertexAnthropicLanguageModel,
|
||||
"litellm": LiteLLMLanguageModel,
|
||||
"vertexai": VertexAILanguageModel,
|
||||
"anthropic": AnthropicLanguageModel,
|
||||
"openai": OpenAILanguageModel,
|
||||
"gemini": GeminiLanguageModel,
|
||||
}
|
||||
|
||||
|
||||
def get_langchain_model(model_name, json=False):
|
||||
parts = model_name.split("/")
|
||||
provider = parts[0]
|
||||
model_name_wihout_provider = "/".join(parts[1:])
|
||||
if provider not in PROVIDER_CLASS_MAP.keys():
|
||||
raise ValueError(
|
||||
f"Provider {provider} not found in config. Make sure you use the correct format for model names, example: openai/gpt-4o-mini"
|
||||
)
|
||||
return PROVIDER_CLASS_MAP[provider](
|
||||
model_name=model_name_wihout_provider, json=json
|
||||
).to_langchain()
|
||||
0
open_notebook/models/__init__.py
Normal file
0
open_notebook/models/__init__.py
Normal file
62
open_notebook/models/embedding_models.py
Normal file
62
open_notebook/models/embedding_models.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
"""
|
||||
Classes for supporting different embedding models
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from open_notebook.domain.models import Model
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingModel(ABC):
|
||||
"""
|
||||
Abstract base class for language models.
|
||||
"""
|
||||
|
||||
model_name: Optional[str] = None
|
||||
|
||||
@abstractmethod
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""
|
||||
Generates an embedding
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAIEmbeddingModel(EmbeddingModel):
|
||||
model_name: str
|
||||
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""
|
||||
Embeds the content using Open AI embedding
|
||||
"""
|
||||
# todo: make this Singleton
|
||||
client = OpenAI()
|
||||
text = text.replace("\n", " ")
|
||||
return (
|
||||
client.embeddings.create(input=[text], model=self.model_name)
|
||||
.data[0]
|
||||
.embedding
|
||||
)
|
||||
|
||||
|
||||
EMBEDDING_CLASS_MAP = {
|
||||
"openai": OpenAIEmbeddingModel,
|
||||
}
|
||||
|
||||
|
||||
def get_embedding_model(model_id):
|
||||
assert model_id, "Model ID cannot be empty"
|
||||
model = Model.get(model_id)
|
||||
if not model:
|
||||
raise ValueError(f"Model with ID {model_id} not found")
|
||||
if model.provider not in EMBEDDING_CLASS_MAP.keys():
|
||||
raise ValueError(
|
||||
f"Provider {model.provider} not compatible with Embedding Models"
|
||||
)
|
||||
return EMBEDDING_CLASS_MAP[model.provider](model_name=model.name)
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
Classes for supporting different language and vector models
|
||||
Classes for supporting different language models
|
||||
"""
|
||||
|
||||
import os
|
||||
|
|
@ -15,9 +15,9 @@ from langchain_google_vertexai import ChatVertexAI
|
|||
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
|
||||
from langchain_ollama.chat_models import ChatOllama
|
||||
from langchain_openai.chat_models import ChatOpenAI
|
||||
from pydantic import SecretStr
|
||||
|
||||
# from redisvl.utils.vectorize import BaseVectorizer
|
||||
# from redisvl.utils.vectorize.text.openai import OpenAITextVectorizer
|
||||
from open_notebook.domain.models import Model
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -186,7 +186,7 @@ class OpenRouterLanguageModel(LanguageModel):
|
|||
max_tokens=self.max_tokens,
|
||||
model_kwargs=kwargs,
|
||||
streaming=self.streaming,
|
||||
api_key=os.environ.get("OPENROUTER_API_KEY", "openrouter"),
|
||||
api_key=SecretStr(os.environ.get("OPENROUTER_API_KEY", "openrouter")),
|
||||
top_p=self.top_p,
|
||||
)
|
||||
|
||||
|
|
@ -240,26 +240,26 @@ class OpenAILanguageModel(LanguageModel):
|
|||
)
|
||||
|
||||
|
||||
# @dataclass
|
||||
# class EmbeddingModel(ABC):
|
||||
# model_name: str
|
||||
# dimensions: int
|
||||
|
||||
# def to_redis_vectorizer(self) -> BaseVectorizer:
|
||||
# raise NotImplementedError
|
||||
# Map provider names to classes
|
||||
PROVIDER_CLASS_MAP = {
|
||||
"ollama": OllamaLanguageModel,
|
||||
"openrouter": OpenRouterLanguageModel,
|
||||
"vertexai-anthropic": VertexAnthropicLanguageModel,
|
||||
"litellm": LiteLLMLanguageModel,
|
||||
"vertexai": VertexAILanguageModel,
|
||||
"anthropic": AnthropicLanguageModel,
|
||||
"openai": OpenAILanguageModel,
|
||||
"gemini": GeminiLanguageModel,
|
||||
}
|
||||
|
||||
|
||||
# @dataclass
|
||||
# class OpenAIEmbeddingModel(EmbeddingModel):
|
||||
# """
|
||||
# Embedding model that uses the OpenAI text embedding model.
|
||||
# """
|
||||
|
||||
# model_name: str
|
||||
# dimensions: int
|
||||
|
||||
# def to_redis_vectorizer(self) -> OpenAITextVectorizer:
|
||||
# """
|
||||
# Convert the embedding model to a Redis vectorizer.
|
||||
# """
|
||||
# return OpenAITextVectorizer(model=self.model_name)
|
||||
# todo: make the provider check type specific
|
||||
def get_langchain_model(model_id, json=False):
|
||||
model = Model.get(model_id)
|
||||
if not model:
|
||||
raise ValueError(f"Model {model_id} not found")
|
||||
if model.provider not in PROVIDER_CLASS_MAP.keys():
|
||||
raise ValueError(f"Provider {model.provider} not found")
|
||||
return PROVIDER_CLASS_MAP[model.provider](
|
||||
model_name=model.name, json=json
|
||||
).to_langchain()
|
||||
62
open_notebook/models/speech_to_text_models.py
Normal file
62
open_notebook/models/speech_to_text_models.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
"""
|
||||
Classes for supporting different transcription models
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from open_notebook.domain.models import Model
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpeechToTextModel(ABC):
|
||||
"""
|
||||
Abstract base class for speech to text models.
|
||||
"""
|
||||
|
||||
model_name: Optional[str] = None
|
||||
|
||||
@abstractmethod
|
||||
def transcribe(self, audio_file_path: str) -> str:
|
||||
"""
|
||||
Generates a text transcription from audio
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAISpeechToTextModel(SpeechToTextModel):
|
||||
model_name: str
|
||||
|
||||
def transcribe(self, audio_file_path: str) -> str:
|
||||
"""
|
||||
Transcribes an audio file into text
|
||||
"""
|
||||
# todo: make this Singleton
|
||||
client = OpenAI()
|
||||
with open(audio_file_path, "rb") as audio:
|
||||
transcription = client.audio.transcriptions.create(
|
||||
model=self.model_name, file=audio
|
||||
)
|
||||
return transcription.text
|
||||
|
||||
|
||||
SPEECH_TO_TEXT_CLASS_MAP = {
|
||||
"openai": OpenAISpeechToTextModel,
|
||||
}
|
||||
|
||||
|
||||
# todo: acho que dá pra juntar todos os get models em uma coisa só
|
||||
def get_speech_to_text_model(model_id):
|
||||
assert model_id, "Model ID cannot be empty"
|
||||
model = Model.get(model_id)
|
||||
if not model:
|
||||
raise ValueError(f"Model with ID {model_id} not found")
|
||||
if model.provider not in SPEECH_TO_TEXT_CLASS_MAP.keys():
|
||||
raise ValueError(
|
||||
f"Provider {model.provider} not compatible with Embedding Models"
|
||||
)
|
||||
return SPEECH_TO_TEXT_CLASS_MAP[model.provider](model_name=model.name)
|
||||
|
|
@ -4,7 +4,7 @@ from loguru import logger
|
|||
from podcastfy.client import generate_podcast
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from open_notebook.domain import ObjectModel
|
||||
from open_notebook.domain.notebook import ObjectModel
|
||||
|
||||
|
||||
class PodcastEpisode(ObjectModel):
|
||||
|
|
|
|||
|
|
@ -6,11 +6,8 @@ from urllib.parse import urlparse
|
|||
import requests
|
||||
import tomli
|
||||
from langchain_text_splitters import CharacterTextSplitter
|
||||
from openai import OpenAI
|
||||
from packaging.version import parse as parse_version
|
||||
|
||||
client = OpenAI()
|
||||
|
||||
|
||||
def split_text(txt: str, chunk=1000, overlap=0, separator=" "):
|
||||
"""
|
||||
|
|
@ -63,21 +60,6 @@ def token_cost(token_count, cost_per_million=0.150):
|
|||
return cost_per_million * (token_count / 1_000_000)
|
||||
|
||||
|
||||
def get_embedding(text, model="text-embedding-3-small"):
|
||||
"""
|
||||
Get the embedding for the input text using the specified model.
|
||||
|
||||
Args:
|
||||
text (str): The input text to get the embedding for.
|
||||
model (str): The name of the embedding model to use. Default is "text-embedding-3-small".
|
||||
|
||||
Returns:
|
||||
list: The embedding vector for the input text.
|
||||
"""
|
||||
text = text.replace("\n", " ")
|
||||
return client.embeddings.create(input=[text], model=model).data[0].embedding
|
||||
|
||||
|
||||
def remove_non_ascii(text):
|
||||
return re.sub(r"[^\x00-\x7F]+", "", text)
|
||||
|
||||
|
|
|
|||
212
pages/9_⚙️_Settings.py
Normal file
212
pages/9_⚙️_Settings.py
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
import os
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from open_notebook.domain.models import DefaultModels, Model
|
||||
from stream_app.utils import version_sidebar
|
||||
|
||||
st.set_page_config(
|
||||
layout="wide", page_title="⚙️ Settings", initial_sidebar_state="expanded"
|
||||
)
|
||||
version_sidebar()
|
||||
|
||||
|
||||
st.title("Settings")
|
||||
|
||||
model_tab, model_defaults_tab = st.tabs(["Models", "Model Defaults"])
|
||||
|
||||
provider_status = {}
|
||||
|
||||
model_types = [
|
||||
# "vision",
|
||||
"text generation",
|
||||
"embedding",
|
||||
"text to speech",
|
||||
"speech to text",
|
||||
]
|
||||
|
||||
provider_status["ollama"] = os.environ.get("OLLAMA_API_BASE") is not None
|
||||
provider_status["openai"] = os.environ.get("OPENAI_API_KEY") is not None
|
||||
provider_status["vertexai"] = (
|
||||
os.environ.get("VERTEX_PROJECT") is not None
|
||||
and os.environ.get("VERTEX_LOCATION") is not None
|
||||
and os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") is not None
|
||||
)
|
||||
provider_status["vertexai-anthropic"] = (
|
||||
os.environ.get("VERTEX_PROJECT") is not None
|
||||
and os.environ.get("VERTEX_LOCATION") is not None
|
||||
and os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") is not None
|
||||
)
|
||||
provider_status["gemini"] = os.environ.get("GEMINI_API_KEY") is not None
|
||||
provider_status["openrouter"] = (
|
||||
os.environ.get("OPENROUTER_API_KEY") is not None
|
||||
and os.environ.get("OPENAI_API_KEY") is not None
|
||||
and os.environ.get("OPENROUTER_BASE_URL") is not None
|
||||
)
|
||||
provider_status["anthropic"] = os.environ.get("ANTHROPIC_API_KEY") is not None
|
||||
provider_status["eleven_labs"] = os.environ.get("ELEVENLABS_API_KEY") is not None
|
||||
provider_status["litellm"] = (
|
||||
provider_status["ollama"]
|
||||
or provider_status["vertexai"]
|
||||
or provider_status["vertexai-anthropic"]
|
||||
or provider_status["anthropic"]
|
||||
or provider_status["openai"]
|
||||
or provider_status["gemini"]
|
||||
)
|
||||
|
||||
available_providers = [k for k, v in provider_status.items() if v]
|
||||
unavailable_providers = [k for k, v in provider_status.items() if not v]
|
||||
|
||||
with model_tab:
|
||||
st.subheader("Add Model")
|
||||
provider = st.selectbox("Provider", available_providers)
|
||||
if len(unavailable_providers) > 0:
|
||||
st.caption(
|
||||
f"Unavailable Providers: {', '.join(unavailable_providers)}. Please check docs page if you wish to enable them."
|
||||
)
|
||||
model_name = st.text_input("Model Name", "")
|
||||
model_type = st.selectbox("Model Type", model_types)
|
||||
if st.button("Save"):
|
||||
model = Model(name=model_name, provider=provider, type=model_type)
|
||||
model.save()
|
||||
st.success("Saved")
|
||||
st.divider()
|
||||
all_models = Model.get_all()
|
||||
st.subheader("Configured Models")
|
||||
model_types_available = {
|
||||
# "vision": False,
|
||||
"text generation": False,
|
||||
"embedding": False,
|
||||
"text to speech": False,
|
||||
"speech to text": False,
|
||||
}
|
||||
for model in all_models:
|
||||
model_types_available[model.type] = True
|
||||
with st.container(border=True):
|
||||
st.markdown(f"{model.name} ({model.provider}, {model.type})")
|
||||
if st.button("Delete", key=f"delete_{model.id}"):
|
||||
model.delete()
|
||||
st.rerun()
|
||||
|
||||
for model_type, available in model_types_available.items():
|
||||
if not available:
|
||||
st.warning(f"No models available for {model_type}")
|
||||
|
||||
|
||||
# todo: check for each type of model
|
||||
def get_selected_index(models, model_id, default=0):
|
||||
"""Returns the index of the selected model in the list of models"""
|
||||
if not model_id or not models:
|
||||
return default
|
||||
for i, model in enumerate(models):
|
||||
if model.id == model_id:
|
||||
return i
|
||||
return default
|
||||
|
||||
|
||||
with model_defaults_tab:
|
||||
default_models = DefaultModels.load().model_dump()
|
||||
all_models = Model.get_all()
|
||||
text_generation_models = [
|
||||
model for model in all_models if model.type == "text generation"
|
||||
]
|
||||
|
||||
text_to_speech_models = [
|
||||
model for model in all_models if model.type == "text to speech"
|
||||
]
|
||||
|
||||
speech_to_text_models = [
|
||||
model for model in all_models if model.type == "speech to text"
|
||||
]
|
||||
vision_models = [model for model in all_models if model.type == "vision"]
|
||||
embedding_models = [model for model in all_models if model.type == "embedding"]
|
||||
st.write(
|
||||
"In this section, you can select the default models to be used on the various content operations done by Open Notebook. Some of these can be overriden in the different modules."
|
||||
)
|
||||
defs = {}
|
||||
defs["default_chat_model"] = st.selectbox(
|
||||
"Default Chat Model",
|
||||
text_generation_models,
|
||||
format_func=lambda x: x.name,
|
||||
help="This model will be used for chat.",
|
||||
index=get_selected_index(
|
||||
text_generation_models, default_models.get("default_chat_model")
|
||||
),
|
||||
)
|
||||
st.divider()
|
||||
defs["default_transformation_model"] = st.selectbox(
|
||||
"Default Transformation Model",
|
||||
text_generation_models,
|
||||
format_func=lambda x: x.name,
|
||||
help="This model will be used for text transformations such as summaries, insights, etc.",
|
||||
index=get_selected_index(
|
||||
text_generation_models, default_models.get("default_transformation_model")
|
||||
),
|
||||
)
|
||||
st.caption("You can override this model on individual transformations")
|
||||
st.divider()
|
||||
defs["large_context_model"] = st.selectbox(
|
||||
"Large Context Model",
|
||||
text_generation_models,
|
||||
format_func=lambda x: x.name,
|
||||
help="This model will be used for larger context generation -- recommended: Gemini",
|
||||
index=get_selected_index(
|
||||
text_generation_models, default_models.get("large_context_model")
|
||||
),
|
||||
)
|
||||
st.caption("Recommended to use Gemini models for larger context processing")
|
||||
st.divider()
|
||||
defs["default_text_to_speech_model"] = st.selectbox(
|
||||
"Default Text to Speech Model",
|
||||
text_to_speech_models,
|
||||
format_func=lambda x: x.name,
|
||||
help="This is the default model for converting text to speech (podcasts, etc)",
|
||||
index=get_selected_index(
|
||||
text_to_speech_models, default_models.get("default_text_to_speech_model")
|
||||
),
|
||||
)
|
||||
st.caption("You can override this model on different podcasts")
|
||||
st.divider()
|
||||
defs["default_speech_to_text_model"] = st.selectbox(
|
||||
"Default Speech to Text Model",
|
||||
speech_to_text_models,
|
||||
format_func=lambda x: x.name,
|
||||
help="This is the default model for converting speech to text (audio transcriptions, etc)",
|
||||
index=get_selected_index(
|
||||
speech_to_text_models, default_models.get("default_speech_to_text_model")
|
||||
),
|
||||
)
|
||||
st.divider()
|
||||
# defs["default_vision_model"] = st.selectbox(
|
||||
# "Default Vision Model",
|
||||
# vision_models,
|
||||
# format_func=lambda x: x.name,
|
||||
# help="This is the default model for vision tasks (image recognition, PDF recognition, etc)",
|
||||
# index=get_selected_index(
|
||||
# vision_models, default_models.get("default_vision_model")
|
||||
# ),
|
||||
# )
|
||||
# st.divider()
|
||||
|
||||
defs["default_embedding_model"] = st.selectbox(
|
||||
"Default Embedding Model",
|
||||
embedding_models,
|
||||
format_func=lambda x: x.name,
|
||||
help="This is the default model for embeddings (semantic search, etc)",
|
||||
index=get_selected_index(
|
||||
embedding_models, default_models.get("default_embedding_model")
|
||||
),
|
||||
)
|
||||
st.caption(
|
||||
"Caution: you cannot change the embedding model once there is embeddings or they will need to be regenerated"
|
||||
)
|
||||
|
||||
for k, v in defs.items():
|
||||
defs[k] = v.id
|
||||
|
||||
if st.button("Save Defaults", key="save_defaults"):
|
||||
DefaultModels.update(defs)
|
||||
st.rerun()
|
||||
|
||||
# todo: return an error if a selected model is no longer supported
|
||||
# todo: do this check on the app homepage as well
|
||||
|
|
@ -8,7 +8,7 @@ from humanize import naturaltime
|
|||
from loguru import logger
|
||||
|
||||
from open_notebook.config import UPLOADS_FOLDER
|
||||
from open_notebook.domain import Asset, Source
|
||||
from open_notebook.domain.notebook import Asset, Source
|
||||
from open_notebook.exceptions import UnsupportedTypeException
|
||||
from open_notebook.graphs.content_processing import graph
|
||||
from open_notebook.graphs.multipattern import graph as transform_graph
|
||||
|
|
|
|||
Loading…
Reference in a new issue