implement model config

This commit is contained in:
LUIS NOVO 2024-10-30 14:09:24 -03:00
parent 63a568490e
commit 8bb5db158f
19 changed files with 434 additions and 105 deletions

View file

@ -3,6 +3,10 @@ import os
import yaml
from loguru import logger
from open_notebook.domain.models import DefaultModels
from open_notebook.models.embedding_models import get_embedding_model
from open_notebook.models.speech_to_text_models import get_speech_to_text_model
# todo: enable config file overwrite with env vars
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
@ -32,3 +36,12 @@ os.makedirs(UPLOADS_FOLDER, exist_ok=True)
# PODCASTS FOLDER
PODCASTS_FOLDER = f"{DATA_FOLDER}/podcasts"
os.makedirs(PODCASTS_FOLDER, exist_ok=True)
DEFAULT_MODELS = DefaultModels.load()
EMBEDDING_MODEL = get_embedding_model(DEFAULT_MODELS.default_embedding_model)
SPEECH_TO_TEXT_MODEL = get_speech_to_text_model(
DEFAULT_MODELS.default_speech_to_text_model
)

View file

@ -0,0 +1,38 @@
from typing import ClassVar, Optional
from loguru import logger
from pydantic import BaseModel
from open_notebook.domain.base import ObjectModel
from open_notebook.repository import (
repo_query,
repo_update,
)
class Model(ObjectModel):
table_name: ClassVar[str] = "model"
name: str
provider: str
type: str
class DefaultModels(BaseModel):
default_chat_model: Optional[str] = None
default_transformation_model: Optional[str] = None
large_context_model: Optional[str] = None
default_text_to_speech_model: Optional[str] = None
default_speech_to_text_model: Optional[str] = None
# default_vision_model: Optional[str] = None
default_embedding_model: Optional[str] = None
@classmethod
def load(self):
result = repo_query("SELECT * FROM open_notebook:default_models;")
if result:
logger.debug(result)
return DefaultModels(**result[0])
@classmethod
def update(self, data):
repo_update("open_notebook:default_models", data)

View file

@ -9,8 +9,8 @@ from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict
from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE
from open_notebook.domain import Notebook
from open_notebook.config import DEFAULT_MODELS, LANGGRAPH_CHECKPOINT_FILE
from open_notebook.domain.notebook import Notebook
from open_notebook.graphs.utils import run_pattern
@ -22,7 +22,9 @@ class ThreadState(TypedDict):
def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict:
model_name = config.get("configurable", {}).get("model_name", None)
model_name = config.get("configurable", {}).get(
"model_name", DEFAULT_MODELS.default_chat_model
)
ai_message = run_pattern(
"chat",
model_name,

View file

@ -4,9 +4,9 @@ from math import ceil
from loguru import logger
from pydub import AudioSegment
from open_notebook.config import SPEECH_TO_TEXT_MODEL
from open_notebook.graphs.content_processing.state import SourceState
# todo: add a speechtotext model to the config
# future: parallelize the transcription process
@ -73,9 +73,6 @@ def split_audio(input_file, segment_length_minutes=15, output_prefix=None):
def extract_audio(data: SourceState):
input_audio_path = data.get("file_path")
from openai import OpenAI
client = OpenAI()
audio_files = []
try:
@ -83,11 +80,7 @@ def extract_audio(data: SourceState):
transcriptions = []
for audio_file in audio_files:
with open(audio_file, "rb") as audio:
transcription = client.audio.transcriptions.create(
model="whisper-1", file=audio
)
transcriptions.append(transcription.text)
transcriptions.append(SPEECH_TO_TEXT_MODEL.transcribe(audio_file))
return {"content": " ".join(transcriptions)}

View file

@ -6,7 +6,7 @@ from langchain_core.runnables import (
from langgraph.graph import END, START, StateGraph
from typing_extensions import TypedDict
from open_notebook.domain import Note, Notebook, Source
from open_notebook.domain.notebook import Note, Notebook, Source
from open_notebook.graphs.utils import run_pattern

View file

@ -1,5 +1,4 @@
import operator
import os
from typing import List, Literal, Sequence
from langchain_core.runnables import (
@ -8,6 +7,7 @@ from langchain_core.runnables import (
from langgraph.graph import END, START, StateGraph
from typing_extensions import Annotated, TypedDict
from open_notebook.config import DEFAULT_MODELS
from open_notebook.graphs.utils import run_pattern
@ -19,7 +19,7 @@ class PatternChainState(TypedDict):
def call_model(state: dict, config: RunnableConfig) -> dict:
model_name = config.get("configurable", {}).get(
"model_name", os.environ.get("DEFAULT_MODEL")
"model_name", DEFAULT_MODELS.default_transformation_model
)
transformations = state["transformations"]
current_transformation = transformations.pop(0)

View file

@ -1,11 +1,10 @@
import os
from langchain_core.runnables import (
RunnableConfig,
)
from langgraph.graph import END, START, StateGraph
from typing_extensions import TypedDict
from open_notebook.config import DEFAULT_MODELS
from open_notebook.graphs.utils import run_pattern
@ -17,7 +16,7 @@ class PatternState(TypedDict):
def call_model(state: dict, config: RunnableConfig) -> dict:
model_name = config.get("configurable", {}).get(
"model_name", os.environ.get("DEFAULT_MODEL")
"model_name", DEFAULT_MODELS.default_transformation_model
)
return {
"output": run_pattern(

View file

@ -7,6 +7,7 @@ from langchain_core.runnables import (
from langgraph.graph import END, START, StateGraph
from typing_extensions import TypedDict
from open_notebook.config import DEFAULT_MODELS
from open_notebook.graphs.utils import run_pattern
from open_notebook.utils import split_text
@ -49,7 +50,7 @@ def chunk_condition(state: TocState) -> Literal["get_chunk", END]: # type: igno
def call_model(state: TocState, config: RunnableConfig) -> dict:
model_name = config.get("configurable", {}).get(
"model_name", os.environ.get("SUMMARIZATION_MODEL")
"model_name", DEFAULT_MODELS.default_transformation_model
)
return {
"toc": run_pattern(

View file

@ -9,6 +9,7 @@ from langgraph.graph import END, START, StateGraph
from pydantic import BaseModel
from typing_extensions import TypedDict
from open_notebook.config import DEFAULT_MODELS
from open_notebook.graphs.utils import run_pattern
from open_notebook.utils import split_text
@ -57,9 +58,9 @@ def chunk_condition(state: SummaryState) -> Literal["get_chunk", END]: # type:
return END
def call_model(state: SummaryState, config: RunnableConfig) -> dict:
def call_model(state: dict, config: RunnableConfig) -> dict:
model_name = config.get("configurable", {}).get(
"model_name", os.environ.get("SUMMARIZATION_MODEL")
"model_name", DEFAULT_MODELS.default_transformation_model
)
parser = PydanticOutputParser(pydantic_object=SummaryResponse)
return {

View file

@ -1,8 +1,7 @@
import os
from langchain.output_parsers import OutputFixingParser
from open_notebook.llm_router import get_langchain_model
from open_notebook.config import DEFAULT_MODELS
from open_notebook.models.llms import get_langchain_model
from open_notebook.prompter import Prompter
@ -15,7 +14,7 @@ def run_pattern(
output_fixing_model_name=None,
) -> dict:
if not model_name:
model_name = os.environ["DEFAULT_MODEL"]
model_name = DEFAULT_MODELS.default_transformation_model
chain = get_langchain_model(model_name)
if parser:

View file

@ -1,35 +0,0 @@
from open_notebook.llms import (
AnthropicLanguageModel,
GeminiLanguageModel,
LiteLLMLanguageModel,
OllamaLanguageModel,
OpenAILanguageModel,
OpenRouterLanguageModel,
VertexAILanguageModel,
VertexAnthropicLanguageModel,
)
# Map provider names to classes
PROVIDER_CLASS_MAP = {
"ollama": OllamaLanguageModel,
"openrouter": OpenRouterLanguageModel,
"vertexai-anthropic": VertexAnthropicLanguageModel,
"litellm": LiteLLMLanguageModel,
"vertexai": VertexAILanguageModel,
"anthropic": AnthropicLanguageModel,
"openai": OpenAILanguageModel,
"gemini": GeminiLanguageModel,
}
def get_langchain_model(model_name, json=False):
parts = model_name.split("/")
provider = parts[0]
model_name_wihout_provider = "/".join(parts[1:])
if provider not in PROVIDER_CLASS_MAP.keys():
raise ValueError(
f"Provider {provider} not found in config. Make sure you use the correct format for model names, example: openai/gpt-4o-mini"
)
return PROVIDER_CLASS_MAP[provider](
model_name=model_name_wihout_provider, json=json
).to_langchain()

View file

View file

@ -0,0 +1,62 @@
"""
Classes for supporting different embedding models
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional
from openai import OpenAI
from open_notebook.domain.models import Model
@dataclass
class EmbeddingModel(ABC):
"""
Abstract base class for language models.
"""
model_name: Optional[str] = None
@abstractmethod
def embed(self, text: str) -> List[float]:
"""
Generates an embedding
"""
raise NotImplementedError
@dataclass
class OpenAIEmbeddingModel(EmbeddingModel):
model_name: str
def embed(self, text: str) -> List[float]:
"""
Embeds the content using Open AI embedding
"""
# todo: make this Singleton
client = OpenAI()
text = text.replace("\n", " ")
return (
client.embeddings.create(input=[text], model=self.model_name)
.data[0]
.embedding
)
EMBEDDING_CLASS_MAP = {
"openai": OpenAIEmbeddingModel,
}
def get_embedding_model(model_id):
assert model_id, "Model ID cannot be empty"
model = Model.get(model_id)
if not model:
raise ValueError(f"Model with ID {model_id} not found")
if model.provider not in EMBEDDING_CLASS_MAP.keys():
raise ValueError(
f"Provider {model.provider} not compatible with Embedding Models"
)
return EMBEDDING_CLASS_MAP[model.provider](model_name=model.name)

View file

@ -1,5 +1,5 @@
"""
Classes for supporting different language and vector models
Classes for supporting different language models
"""
import os
@ -15,9 +15,9 @@ from langchain_google_vertexai import ChatVertexAI
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
from langchain_ollama.chat_models import ChatOllama
from langchain_openai.chat_models import ChatOpenAI
from pydantic import SecretStr
# from redisvl.utils.vectorize import BaseVectorizer
# from redisvl.utils.vectorize.text.openai import OpenAITextVectorizer
from open_notebook.domain.models import Model
@dataclass
@ -186,7 +186,7 @@ class OpenRouterLanguageModel(LanguageModel):
max_tokens=self.max_tokens,
model_kwargs=kwargs,
streaming=self.streaming,
api_key=os.environ.get("OPENROUTER_API_KEY", "openrouter"),
api_key=SecretStr(os.environ.get("OPENROUTER_API_KEY", "openrouter")),
top_p=self.top_p,
)
@ -240,26 +240,26 @@ class OpenAILanguageModel(LanguageModel):
)
# @dataclass
# class EmbeddingModel(ABC):
# model_name: str
# dimensions: int
# def to_redis_vectorizer(self) -> BaseVectorizer:
# raise NotImplementedError
# Map provider names to classes
PROVIDER_CLASS_MAP = {
"ollama": OllamaLanguageModel,
"openrouter": OpenRouterLanguageModel,
"vertexai-anthropic": VertexAnthropicLanguageModel,
"litellm": LiteLLMLanguageModel,
"vertexai": VertexAILanguageModel,
"anthropic": AnthropicLanguageModel,
"openai": OpenAILanguageModel,
"gemini": GeminiLanguageModel,
}
# @dataclass
# class OpenAIEmbeddingModel(EmbeddingModel):
# """
# Embedding model that uses the OpenAI text embedding model.
# """
# model_name: str
# dimensions: int
# def to_redis_vectorizer(self) -> OpenAITextVectorizer:
# """
# Convert the embedding model to a Redis vectorizer.
# """
# return OpenAITextVectorizer(model=self.model_name)
# todo: make the provider check type specific
def get_langchain_model(model_id, json=False):
model = Model.get(model_id)
if not model:
raise ValueError(f"Model {model_id} not found")
if model.provider not in PROVIDER_CLASS_MAP.keys():
raise ValueError(f"Provider {model.provider} not found")
return PROVIDER_CLASS_MAP[model.provider](
model_name=model.name, json=json
).to_langchain()

View file

@ -0,0 +1,62 @@
"""
Classes for supporting different transcription models
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional
from openai import OpenAI
from open_notebook.domain.models import Model
@dataclass
class SpeechToTextModel(ABC):
"""
Abstract base class for speech to text models.
"""
model_name: Optional[str] = None
@abstractmethod
def transcribe(self, audio_file_path: str) -> str:
"""
Generates a text transcription from audio
"""
raise NotImplementedError
@dataclass
class OpenAISpeechToTextModel(SpeechToTextModel):
model_name: str
def transcribe(self, audio_file_path: str) -> str:
"""
Transcribes an audio file into text
"""
# todo: make this Singleton
client = OpenAI()
with open(audio_file_path, "rb") as audio:
transcription = client.audio.transcriptions.create(
model=self.model_name, file=audio
)
return transcription.text
SPEECH_TO_TEXT_CLASS_MAP = {
"openai": OpenAISpeechToTextModel,
}
# todo: acho que dá pra juntar todos os get models em uma coisa só
def get_speech_to_text_model(model_id):
assert model_id, "Model ID cannot be empty"
model = Model.get(model_id)
if not model:
raise ValueError(f"Model with ID {model_id} not found")
if model.provider not in SPEECH_TO_TEXT_CLASS_MAP.keys():
raise ValueError(
f"Provider {model.provider} not compatible with Embedding Models"
)
return SPEECH_TO_TEXT_CLASS_MAP[model.provider](model_name=model.name)

View file

@ -4,7 +4,7 @@ from loguru import logger
from podcastfy.client import generate_podcast
from pydantic import Field, field_validator
from open_notebook.domain import ObjectModel
from open_notebook.domain.notebook import ObjectModel
class PodcastEpisode(ObjectModel):

View file

@ -6,11 +6,8 @@ from urllib.parse import urlparse
import requests
import tomli
from langchain_text_splitters import CharacterTextSplitter
from openai import OpenAI
from packaging.version import parse as parse_version
client = OpenAI()
def split_text(txt: str, chunk=1000, overlap=0, separator=" "):
"""
@ -63,21 +60,6 @@ def token_cost(token_count, cost_per_million=0.150):
return cost_per_million * (token_count / 1_000_000)
def get_embedding(text, model="text-embedding-3-small"):
"""
Get the embedding for the input text using the specified model.
Args:
text (str): The input text to get the embedding for.
model (str): The name of the embedding model to use. Default is "text-embedding-3-small".
Returns:
list: The embedding vector for the input text.
"""
text = text.replace("\n", " ")
return client.embeddings.create(input=[text], model=model).data[0].embedding
def remove_non_ascii(text):
return re.sub(r"[^\x00-\x7F]+", "", text)

212
pages/9_⚙️_Settings.py Normal file
View file

@ -0,0 +1,212 @@
import os
import streamlit as st
from open_notebook.domain.models import DefaultModels, Model
from stream_app.utils import version_sidebar
st.set_page_config(
layout="wide", page_title="⚙️ Settings", initial_sidebar_state="expanded"
)
version_sidebar()
st.title("Settings")
model_tab, model_defaults_tab = st.tabs(["Models", "Model Defaults"])
provider_status = {}
model_types = [
# "vision",
"text generation",
"embedding",
"text to speech",
"speech to text",
]
provider_status["ollama"] = os.environ.get("OLLAMA_API_BASE") is not None
provider_status["openai"] = os.environ.get("OPENAI_API_KEY") is not None
provider_status["vertexai"] = (
os.environ.get("VERTEX_PROJECT") is not None
and os.environ.get("VERTEX_LOCATION") is not None
and os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") is not None
)
provider_status["vertexai-anthropic"] = (
os.environ.get("VERTEX_PROJECT") is not None
and os.environ.get("VERTEX_LOCATION") is not None
and os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") is not None
)
provider_status["gemini"] = os.environ.get("GEMINI_API_KEY") is not None
provider_status["openrouter"] = (
os.environ.get("OPENROUTER_API_KEY") is not None
and os.environ.get("OPENAI_API_KEY") is not None
and os.environ.get("OPENROUTER_BASE_URL") is not None
)
provider_status["anthropic"] = os.environ.get("ANTHROPIC_API_KEY") is not None
provider_status["eleven_labs"] = os.environ.get("ELEVENLABS_API_KEY") is not None
provider_status["litellm"] = (
provider_status["ollama"]
or provider_status["vertexai"]
or provider_status["vertexai-anthropic"]
or provider_status["anthropic"]
or provider_status["openai"]
or provider_status["gemini"]
)
available_providers = [k for k, v in provider_status.items() if v]
unavailable_providers = [k for k, v in provider_status.items() if not v]
with model_tab:
st.subheader("Add Model")
provider = st.selectbox("Provider", available_providers)
if len(unavailable_providers) > 0:
st.caption(
f"Unavailable Providers: {', '.join(unavailable_providers)}. Please check docs page if you wish to enable them."
)
model_name = st.text_input("Model Name", "")
model_type = st.selectbox("Model Type", model_types)
if st.button("Save"):
model = Model(name=model_name, provider=provider, type=model_type)
model.save()
st.success("Saved")
st.divider()
all_models = Model.get_all()
st.subheader("Configured Models")
model_types_available = {
# "vision": False,
"text generation": False,
"embedding": False,
"text to speech": False,
"speech to text": False,
}
for model in all_models:
model_types_available[model.type] = True
with st.container(border=True):
st.markdown(f"{model.name} ({model.provider}, {model.type})")
if st.button("Delete", key=f"delete_{model.id}"):
model.delete()
st.rerun()
for model_type, available in model_types_available.items():
if not available:
st.warning(f"No models available for {model_type}")
# todo: check for each type of model
def get_selected_index(models, model_id, default=0):
"""Returns the index of the selected model in the list of models"""
if not model_id or not models:
return default
for i, model in enumerate(models):
if model.id == model_id:
return i
return default
with model_defaults_tab:
default_models = DefaultModels.load().model_dump()
all_models = Model.get_all()
text_generation_models = [
model for model in all_models if model.type == "text generation"
]
text_to_speech_models = [
model for model in all_models if model.type == "text to speech"
]
speech_to_text_models = [
model for model in all_models if model.type == "speech to text"
]
vision_models = [model for model in all_models if model.type == "vision"]
embedding_models = [model for model in all_models if model.type == "embedding"]
st.write(
"In this section, you can select the default models to be used on the various content operations done by Open Notebook. Some of these can be overriden in the different modules."
)
defs = {}
defs["default_chat_model"] = st.selectbox(
"Default Chat Model",
text_generation_models,
format_func=lambda x: x.name,
help="This model will be used for chat.",
index=get_selected_index(
text_generation_models, default_models.get("default_chat_model")
),
)
st.divider()
defs["default_transformation_model"] = st.selectbox(
"Default Transformation Model",
text_generation_models,
format_func=lambda x: x.name,
help="This model will be used for text transformations such as summaries, insights, etc.",
index=get_selected_index(
text_generation_models, default_models.get("default_transformation_model")
),
)
st.caption("You can override this model on individual transformations")
st.divider()
defs["large_context_model"] = st.selectbox(
"Large Context Model",
text_generation_models,
format_func=lambda x: x.name,
help="This model will be used for larger context generation -- recommended: Gemini",
index=get_selected_index(
text_generation_models, default_models.get("large_context_model")
),
)
st.caption("Recommended to use Gemini models for larger context processing")
st.divider()
defs["default_text_to_speech_model"] = st.selectbox(
"Default Text to Speech Model",
text_to_speech_models,
format_func=lambda x: x.name,
help="This is the default model for converting text to speech (podcasts, etc)",
index=get_selected_index(
text_to_speech_models, default_models.get("default_text_to_speech_model")
),
)
st.caption("You can override this model on different podcasts")
st.divider()
defs["default_speech_to_text_model"] = st.selectbox(
"Default Speech to Text Model",
speech_to_text_models,
format_func=lambda x: x.name,
help="This is the default model for converting speech to text (audio transcriptions, etc)",
index=get_selected_index(
speech_to_text_models, default_models.get("default_speech_to_text_model")
),
)
st.divider()
# defs["default_vision_model"] = st.selectbox(
# "Default Vision Model",
# vision_models,
# format_func=lambda x: x.name,
# help="This is the default model for vision tasks (image recognition, PDF recognition, etc)",
# index=get_selected_index(
# vision_models, default_models.get("default_vision_model")
# ),
# )
# st.divider()
defs["default_embedding_model"] = st.selectbox(
"Default Embedding Model",
embedding_models,
format_func=lambda x: x.name,
help="This is the default model for embeddings (semantic search, etc)",
index=get_selected_index(
embedding_models, default_models.get("default_embedding_model")
),
)
st.caption(
"Caution: you cannot change the embedding model once there is embeddings or they will need to be regenerated"
)
for k, v in defs.items():
defs[k] = v.id
if st.button("Save Defaults", key="save_defaults"):
DefaultModels.update(defs)
st.rerun()
# todo: return an error if a selected model is no longer supported
# todo: do this check on the app homepage as well

View file

@ -8,7 +8,7 @@ from humanize import naturaltime
from loguru import logger
from open_notebook.config import UPLOADS_FOLDER
from open_notebook.domain import Asset, Source
from open_notebook.domain.notebook import Asset, Source
from open_notebook.exceptions import UnsupportedTypeException
from open_notebook.graphs.content_processing import graph
from open_notebook.graphs.multipattern import graph as transform_graph