remove defaultmodel from config file

This commit is contained in:
LUIS NOVO 2024-11-01 19:56:27 -03:00
parent a525fba1d2
commit feabfaed01
13 changed files with 89 additions and 80 deletions

View file

@ -3,9 +3,6 @@ import os
import yaml
from loguru import logger
from open_notebook.domain.models import DefaultModels
from open_notebook.models import get_model
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
config_path = os.path.join(project_root, "open_notebook_config.yaml")
@ -34,22 +31,3 @@ os.makedirs(UPLOADS_FOLDER, exist_ok=True)
# PODCASTS FOLDER
PODCASTS_FOLDER = f"{DATA_FOLDER}/podcasts"
os.makedirs(PODCASTS_FOLDER, exist_ok=True)
def load_default_models():
default_models = DefaultModels.load()
embedding_model = (
get_model(default_models.default_embedding_model, model_type="embedding")
if default_models.default_embedding_model
else None
)
speech_to_text_model = (
get_model(
default_models.default_speech_to_text_model, model_type="speech_to_text"
)
if default_models.default_speech_to_text_model
else None
)
return default_models, embedding_model, speech_to_text_model

View file

@ -18,8 +18,16 @@ class MigrationManager:
database=os.environ["SURREAL_DATABASE"],
encrypted=False, # Set to True if using SSL
)
self.up_migrations = [Migration.from_file("migrations/1.surrealql")]
self.down_migrations = [Migration.from_file("migrations/1_down.surrealql")]
self.up_migrations = [
Migration.from_file("migrations/1.surrealql"),
Migration.from_file("migrations/2.surrealql"),
]
self.down_migrations = [
Migration.from_file(
"migrations/1_down.surrealql",
),
Migration.from_file("migrations/2_down.surrealql"),
]
self.runner = MigrationRunner(
up_migrations=self.up_migrations,
down_migrations=self.down_migrations,

View file

@ -68,9 +68,11 @@ class ObjectModel(BaseModel):
return None
def save(self) -> None:
from open_notebook.config import load_default_models
from open_notebook.domain.models import DefaultModels
from open_notebook.models import get_model
DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
DEFAULT_MODELS = DefaultModels.load()
EMBEDDING_MODEL = get_model(DEFAULT_MODELS.default_embedding_model)
try:
logger.debug(f"Validating {self.__class__.__name__}")
@ -88,7 +90,11 @@ class ObjectModel(BaseModel):
logger.debug("Creating new record")
repo_result = repo_create(self.__class__.table_name, data)
else:
data["created"] = self.created.strftime("%Y-%m-%d %H:%M:%S")
data["created"] = (
self.created.strftime("%Y-%m-%d %H:%M:%S")
if type(self.created) == datetime
else self.created
)
logger.debug(f"Updating record with id {self.id}")
repo_result = repo_update(self.id, data)
@ -148,3 +154,20 @@ class ObjectModel(BaseModel):
if isinstance(value, str):
return datetime.fromisoformat(value.replace("Z", "+00:00"))
return value
class RecordModel(BaseModel):
record_id: ClassVar[str] = "open_notebook:default_models"
@classmethod
def load(cls):
result = repo_query(f"SELECT * FROM {cls.record_id};")
if result:
result = result[0]
dm = cls(**result)
return dm
return cls()
@classmethod
def update(self, data):
repo_update(self.record_id, data)

View file

@ -1,12 +1,7 @@
from typing import ClassVar, Optional
from pydantic import BaseModel
from open_notebook.database.repository import (
repo_query,
repo_update,
)
from open_notebook.domain.base import ObjectModel
from open_notebook.database.repository import repo_query
from open_notebook.domain.base import ObjectModel, RecordModel
class Model(ObjectModel):
@ -23,7 +18,10 @@ class Model(ObjectModel):
return [Model(**model) for model in models]
class DefaultModels(BaseModel):
# todo: future: colocar um cache aqui
class DefaultModels(RecordModel):
record_id: ClassVar[str] = "open_notebook:default_models"
default_chat_model: Optional[str] = None
default_transformation_model: Optional[str] = None
large_context_model: Optional[str] = None
@ -32,15 +30,21 @@ class DefaultModels(BaseModel):
# default_vision_model: Optional[str] = None
default_embedding_model: Optional[str] = None
@classmethod
def load(self):
result = repo_query("SELECT * FROM open_notebook:default_models;")
if result:
result = result[0]
dm = DefaultModels(**result)
return dm
return DefaultModels()
@classmethod
def update(self, data):
repo_update("open_notebook:default_models", data)
# def load_default_models():
# default_models = DefaultModels.load()
# embedding_model = (
# get_model(default_models.default_embedding_model, model_type="embedding")
# if default_models.default_embedding_model
# else None
# )
# speech_to_text_model = (
# get_model(
# default_models.default_speech_to_text_model, model_type="speech_to_text"
# )
# if default_models.default_speech_to_text_model
# else None
# )
# return default_models, embedding_model, speech_to_text_model

View file

@ -4,18 +4,19 @@ from typing import Any, ClassVar, Dict, List, Literal, Optional
from loguru import logger
from pydantic import BaseModel, Field, field_validator
from open_notebook.config import load_default_models
from open_notebook.database.repository import (
repo_create,
repo_query,
)
from open_notebook.domain.base import ObjectModel
from open_notebook.domain.models import DefaultModels
from open_notebook.exceptions import (
DatabaseOperationError,
InvalidInputError,
)
# from temp.recursive_toc import graph as toc_graph
from open_notebook.models import get_model
from open_notebook.utils import split_text, surreal_clean
@ -139,7 +140,8 @@ class Source(ObjectModel):
raise DatabaseOperationError(e)
def vectorize(self) -> None:
DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
DEFAULT_MODELS = DefaultModels.load()
EMBEDDING_MODEL = get_model(DEFAULT_MODELS.default_embedding_model)
try:
if not self.full_text:
@ -190,7 +192,8 @@ class Source(ObjectModel):
raise DatabaseOperationError("Failed to search sources")
def add_insight(self, insight_type: str, content: str) -> Any:
DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
DEFAULT_MODELS = DefaultModels.load()
EMBEDDING_MODEL = get_model(DEFAULT_MODELS.default_embedding_model)
if not insight_type or not content:
raise InvalidInputError("Insight type and content must be provided")
@ -214,7 +217,7 @@ class Source(ObjectModel):
class Note(ObjectModel):
table_name: ClassVar[str] = "note"
title: Optional[str] = None
note_type: Optional[Literal["human", "ai"]] = "human"
note_type: Optional[Literal["human", "ai"]] = None
content: Optional[str] = None
@field_validator("content")

View file

@ -9,11 +9,12 @@ from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict
from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE, load_default_models
from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE
from open_notebook.domain.models import DefaultModels
from open_notebook.domain.notebook import Notebook
from open_notebook.graphs.utils import run_pattern
DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
DEFAULT_MODELS = DefaultModels.load()
class ThreadState(TypedDict):

View file

@ -4,8 +4,9 @@ from math import ceil
from loguru import logger
from pydub import AudioSegment
from open_notebook.config import load_default_models
from open_notebook.domain.models import DefaultModels
from open_notebook.graphs.content_processing.state import SourceState
from open_notebook.models import get_model
# future: parallelize the transcription process
@ -72,7 +73,8 @@ def split_audio(input_file, segment_length_minutes=15, output_prefix=None):
def extract_audio(data: SourceState):
DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
DEFAULT_MODELS = DefaultModels.load()
SPEECH_TO_TEXT_MODEL = get_model(DEFAULT_MODELS.default_speech_to_text_model)
input_audio_path = data.get("file_path")
audio_files = []

View file

@ -7,10 +7,10 @@ from langchain_core.runnables import (
from langgraph.graph import END, START, StateGraph
from typing_extensions import Annotated, TypedDict
from open_notebook.config import load_default_models
from open_notebook.domain.models import DefaultModels
from open_notebook.graphs.utils import run_pattern
DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
DEFAULT_MODELS = DefaultModels.load()
class PatternChainState(TypedDict):

View file

@ -1,7 +1,7 @@
from langchain.output_parsers import OutputFixingParser
from loguru import logger
from open_notebook.config import load_default_models
from open_notebook.domain.models import DefaultModels
from open_notebook.models import get_model
from open_notebook.prompter import Prompter
from open_notebook.utils import token_count
@ -18,7 +18,7 @@ def run_pattern(
system_prompt = Prompter(prompt_template=pattern_name, parser=parser).render(
data=state
)
DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
DEFAULT_MODELS = DefaultModels.load()
tokens = token_count(str(system_prompt) + str(messages))
if tokens > 105_000:
@ -33,12 +33,12 @@ def run_pattern(
or DEFAULT_MODELS.default_chat_model
)
chain = get_model(model_id, model_type="language")
chain = get_model(model_id)
if parser:
chain = chain | parser
if output_fixing_model_id and parser:
output_fix_model = get_model(output_fixing_model_id, model_type="language")
output_fix_model = get_model(output_fixing_model_id)
chain = chain | OutputFixingParser.from_llm(
parser=parser,
llm=output_fix_model,

View file

@ -49,13 +49,12 @@ MODEL_CLASS_MAP = {
}
def get_model(model_id, model_type="language", **kwargs):
def get_model(model_id, **kwargs):
"""
Get a model instance based on model_id and type.
Args:
model_id: The ID of the model to retrieve
model_type: Type of model ('language', 'embedding', or 'speech_to_text')
**kwargs: Additional arguments to pass to the model constructor
"""
assert model_id, "Model ID cannot be empty"
@ -64,20 +63,20 @@ def get_model(model_id, model_type="language", **kwargs):
if not model:
raise ValueError(f"Model with ID {model_id} not found")
if model_type not in MODEL_CLASS_MAP:
raise ValueError(f"Invalid model type: {model_type}")
if not model.type or model.type not in MODEL_CLASS_MAP:
raise ValueError(f"Invalid model type: {model.type}")
provider_map = MODEL_CLASS_MAP[model_type]
provider_map = MODEL_CLASS_MAP[model.type]
if model.provider not in provider_map:
raise ValueError(
f"Provider {model.provider} not compatible with {model_type} models"
f"Provider {model.provider} not compatible with {model.type} models"
)
model_class = provider_map[model.provider]
model_instance = model_class(model_name=model.name, **kwargs)
# Special handling for language models that need langchain conversion
if model_type == "language":
if model.type == "language":
return model_instance.to_langchain()
return model_instance

View file

@ -1,7 +1,6 @@
import streamlit as st
from humanize import naturaltime
from open_notebook.config import load_default_models
from open_notebook.domain.notebook import Notebook
from stream_app.chat import chat_sidebar
from stream_app.note import add_note, note_card
@ -71,9 +70,6 @@ def notebook_page(current_notebook_id):
sources = current_notebook.sources
notes = current_notebook.notes
# Load the default models dynamically
DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
notebook_header(current_notebook)
work_tab, chat_tab = st.columns([4, 2])

View file

@ -111,6 +111,5 @@ def chat_sidebar(session_id):
make_note_from_chat(
content=msg.content,
notebook_id=st.session_state[session_id]["notebook"].id,
type="ai",
)
st.rerun()

View file

@ -7,7 +7,7 @@ import yaml
from humanize import naturaltime
from loguru import logger
from open_notebook.config import UPLOADS_FOLDER, load_default_models
from open_notebook.config import UPLOADS_FOLDER
from open_notebook.domain.notebook import Asset, Source
from open_notebook.exceptions import UnsupportedTypeException
from open_notebook.graphs.content_processing import graph
@ -16,8 +16,6 @@ from open_notebook.utils import surreal_clean
from .consts import context_icons
DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
def run_patterns(input_text, patterns):
output = transform_graph.invoke(dict(content_stack=[input_text], patterns=patterns))
@ -26,18 +24,16 @@ def run_patterns(input_text, patterns):
# moved it here to replace it with the pipeline on 0.1.0
def generate_toc_and_title(source) -> "Source":
DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
try:
patterns = ["patterns/default/toc"]
result = run_patterns(source.full_text, patterns=patterns)
source.add_insight("Table of Contents", surreal_clean(result))
if not source.title:
transformations = [
patterns = [
"Based on the Table of Contents below, please provide a Title for this content, with max 15 words"
]
output = run_patterns(result["toc"], transformations=transformations)
source.title = surreal_clean(output["output"])
output = run_patterns(result, patterns=patterns)
source.title = surreal_clean(output)
source.save()
return source
except Exception as e: