enable use without optional models, adds warning

This commit is contained in:
LUIS NOVO 2024-11-14 12:12:11 -03:00
parent cd3ad2e9fa
commit c8b5d422ae
12 changed files with 108 additions and 44 deletions

View file

@ -110,7 +110,6 @@ class ObjectModel(BaseModel):
def save(self) -> None:
from open_notebook.domain.models import model_manager
from open_notebook.models import EmbeddingModel
try:
self.model_validate(self.model_dump(), strict=True)
@ -120,8 +119,16 @@ class ObjectModel(BaseModel):
if self.needs_embedding():
embedding_content = self.get_embedding_content()
if embedding_content:
EMBEDDING_MODEL: EmbeddingModel = model_manager.embedding_model
data["embedding"] = EMBEDDING_MODEL.embed(embedding_content)
EMBEDDING_MODEL = model_manager.embedding_model
if not EMBEDDING_MODEL:
logger.warning(
"No embedding model found. Content will not be searchable."
)
data["embedding"] = (
EMBEDDING_MODEL.embed(embedding_content)
if EMBEDDING_MODEL
else []
)
if self.id is None:
data["created"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

View file

@ -109,27 +109,30 @@ class ModelManager:
return self._default_models
@property
def speech_to_text(self, **kwargs) -> Optional[SpeechToTextModel]:
    """Get the default speech-to-text model.

    Returns:
        The object returned by ``get_default_model("speech_to_text")``,
        or ``None`` when that lookup yields no model.

    Raises:
        TypeError: If a model is returned but it is not a
            ``SpeechToTextModel``.
    """
    # NOTE: a property getter is only ever invoked with ``self``, so
    # ``kwargs`` is always empty; it is kept to leave the signature untouched.
    model = self.get_default_model("speech_to_text", **kwargs)
    if model is None:
        return None
    # Raise explicitly instead of using ``assert``: assertions are removed
    # when Python runs with -O, which would silently disable this check.
    if not isinstance(model, SpeechToTextModel):
        raise TypeError(f"Expected SpeechToTextModel but got {type(model)}")
    return model
@property
def text_to_speech(self, **kwargs) -> Optional[TextToSpeechModel]:
    """Get the default text-to-speech model.

    Returns:
        The object returned by ``get_default_model("text_to_speech")``,
        or ``None`` when that lookup yields no model.

    Raises:
        TypeError: If a model is returned but it is not a
            ``TextToSpeechModel``.
    """
    # NOTE: a property getter is only ever invoked with ``self``, so
    # ``kwargs`` is always empty; it is kept to leave the signature untouched.
    model = self.get_default_model("text_to_speech", **kwargs)
    if model is None:
        return None
    # Raise explicitly instead of using ``assert``: assertions are removed
    # when Python runs with -O, which would silently disable this check.
    if not isinstance(model, TextToSpeechModel):
        raise TypeError(f"Expected TextToSpeechModel but got {type(model)}")
    return model
@property
def embedding_model(self, **kwargs) -> Optional[EmbeddingModel]:
    """Get the default embedding model.

    Returns:
        The object returned by ``get_default_model("embedding")``, or
        ``None`` when that lookup yields no model.

    Raises:
        TypeError: If a model is returned but it is not an
            ``EmbeddingModel``.
    """
    # NOTE: a property getter is only ever invoked with ``self``, so
    # ``kwargs`` is always empty; it is kept to leave the signature untouched.
    model = self.get_default_model("embedding", **kwargs)
    if model is None:
        return None
    # Raise explicitly instead of using ``assert``: assertions are removed
    # when Python runs with -O, which would silently disable this check.
    if not isinstance(model, EmbeddingModel):
        raise TypeError(f"Expected EmbeddingModel but got {type(model)}")
    return model
def get_default_model(self, model_type: str, **kwargs) -> ModelType:

View file

@ -261,11 +261,13 @@ class Source(ObjectModel):
def add_insight(self, insight_type: str, content: str) -> Any:
EMBEDDING_MODEL = model_manager.embedding_model
if not EMBEDDING_MODEL:
logger.warning("No embedding model found. Insight will not be searchable.")
if not insight_type or not content:
raise InvalidInputError("Insight type and content must be provided")
try:
embedding = EMBEDDING_MODEL.embed(content)
embedding = EMBEDDING_MODEL.embed(content) if EMBEDDING_MODEL else []
return repo_query(
f"""
CREATE source_insight CONTENT {{
@ -278,7 +280,7 @@ class Source(ObjectModel):
)
except Exception as e:
logger.error(f"Error adding insight to source {self.id}: {str(e)}")
raise DatabaseOperationError(e)
raise # DatabaseOperationError(e)
class Note(ObjectModel):

View file

@ -7,7 +7,7 @@ from pages.stream_app.note import add_note, note_card
from pages.stream_app.source import add_source, source_card
from pages.stream_app.utils import setup_page, setup_stream_state
setup_page("📒 Open Notebook")
setup_page("📒 Open Notebook", only_check_mandatory_models=True)
def notebook_header(current_notebook: Notebook):

View file

@ -2,7 +2,7 @@ import asyncio
import streamlit as st
from open_notebook.domain.models import DefaultModels
from open_notebook.domain.models import DefaultModels, model_manager
from open_notebook.domain.notebook import Note, Notebook, text_search, vector_search
from open_notebook.graphs.ask import graph as ask_graph
from pages.components.model_selector import model_selector
@ -76,7 +76,11 @@ with ask_tab:
selected_id=default_model,
help="This is the LLM that will be responsible for processing the final answer",
)
ask_bt = st.button("Ask")
if not model_manager.embedding_model:
st.warning(
"You can't use this feature because you have no embedding model selected. Please set one up in the Settings page."
)
ask_bt = st.button("Ask") if model_manager.embedding_model else None
placeholder = st.container()
async def stream_results():
@ -133,7 +137,13 @@ with search_tab:
st.subheader("🔍 Search")
st.caption("Search your knowledge base for specific keywords or concepts")
search_term = st.text_input("Search", "")
search_type = st.radio("Search Type", ["Text Search", "Vector Search"])
if not model_manager.embedding_model:
st.warning(
"You can't use vector search because you have no embedding model selected. Only text search will be available."
)
search_type = "Text Search"
else:
search_type = st.radio("Search Type", ["Text Search", "Vector Search"])
search_sources = st.checkbox("Search Sources", value=True)
search_notes = st.checkbox("Search Notes", value=True)
if st.button("Search"):

View file

@ -14,7 +14,7 @@ from open_notebook.plugins.podcasts import (
)
from pages.stream_app.utils import setup_page
setup_page("🎙️ Podcasts")
setup_page("🎙️ Podcasts", only_check_mandatory_models=False)
text_to_speech_models = Model.get_models_by_type("text_to_speech")

View file

@ -9,7 +9,7 @@ from open_notebook.models import MODEL_CLASS_MAP
from pages.components.model_selector import model_selector
from pages.stream_app.utils import setup_page
setup_page("⚙️ Settings")
setup_page("⚙️ Settings", only_check_mandatory_models=False, stop_on_model_error=False)
st.title("⚙️ Settings")

View file

@ -2,11 +2,16 @@ import streamlit as st
from loguru import logger
from streamlit_monaco import st_monaco # type: ignore
from open_notebook.domain.models import model_manager
from open_notebook.domain.notebook import Note
from pages.stream_app.utils import convert_source_references
def note_panel(note_id, notebook_id=None):
if not model_manager.embedding_model:
st.warning(
"Since there is no embedding model selected, your note will be saved but not searchable."
)
note: Note = Note.get(note_id)
if not note:
raise ValueError(f"Note not fonud {note_id}")
@ -17,15 +22,16 @@ def note_panel(note_id, notebook_id=None):
with t_edit:
note.title = st.text_input("Title", value=note.title)
note.content = st_monaco(
value=note.content, height="600px", language="markdown"
value=note.content, height="300px", language="markdown"
)
if st.button("Save", key=f"pn_edit_note_{note.id or 'new'}"):
b1, b2 = st.columns(2)
if b1.button("Save", key=f"pn_edit_note_{note.id or 'new'}"):
logger.debug("Editing note")
note.save()
if not note.id and notebook_id:
note.add_to_notebook(notebook_id)
st.rerun()
if st.button("Delete", type="primary", key=f"delete_note_{note.id or 'new'}"):
if b2.button("Delete", type="primary", key=f"delete_note_{note.id or 'new'}"):
logger.debug("Deleting note")
note.delete()
st.rerun()

View file

@ -2,13 +2,15 @@ import streamlit as st
import streamlit_scrollable_textbox as stx # type: ignore
from humanize import naturaltime
from open_notebook.domain.models import model_manager
from open_notebook.domain.notebook import Source
from open_notebook.domain.transformation import Transformation
from open_notebook.utils import surreal_clean
from pages.stream_app.utils import run_patterns
from pages.stream_app.utils import check_models, run_patterns
def source_panel(source_id: str, notebook_id=None, modal=False):
check_models(only_mandatory=False)
source: Source = Source.get(source_id)
if not source:
raise ValueError(f"Source not found: {source_id}")
@ -41,7 +43,8 @@ def source_panel(source_id: str, notebook_id=None, modal=False):
"Delete", type="primary", key=f"delete_insight_{insight.id}"
):
insight.delete()
st.rerun(scope="fragment" if modal else "app")
# st.rerun(scope="fragment" if modal else "app")
st.toast("Source deleted")
if notebook_id:
if x2.button(
"Save as Note", icon="📝", key=f"save_note_{insight.id}"
@ -66,11 +69,18 @@ def source_panel(source_id: str, notebook_id=None, modal=False):
)
st.rerun(scope="fragment" if modal else "app")
if not model_manager.embedding_model:
help = (
"No embedding model found. Please, select one on the settings page."
)
else:
help = "This will generate your embedding vectors on the database for powerful search capabilities"
if source.embedded_chunks == 0 and st.button(
"Embed vectors",
icon="🦾",
disabled=source.embedded_chunks > 0,
help="This will generate your embedding vectors on the database for powerful search capabilities",
help=help,
disabled=model_manager.embedding_model is None,
):
source.vectorize()
st.success("Embedding complete")

View file

@ -2,8 +2,8 @@ from typing import Optional
import streamlit as st
from humanize import naturaltime
from loguru import logger
from open_notebook.domain.models import model_manager
from open_notebook.domain.notebook import Note
from open_notebook.graphs.multipattern import graph as pattern_graph
from open_notebook.utils import surreal_clean
@ -14,17 +14,20 @@ from .consts import note_context_icons
@st.dialog("Write a Note", width="large")
def add_note(notebook_id):
    """Dialog that collects a title and content and stores them as a note.

    A warning is shown first when ``model_manager.embedding_model`` is falsy.
    When the "Save" button is pressed, a ``Note`` with ``note_type="human"``
    is saved, added to the notebook, and the app is rerun.

    Args:
        notebook_id: Value forwarded to ``Note.add_to_notebook``.
    """
    has_embedding_model = bool(model_manager.embedding_model)
    if not has_embedding_model:
        st.warning(
            "Since there is no embedding model selected, your note will be saved but not searchable."
        )

    title = st.text_input("Title")
    body = st.text_area("Content")

    # Nothing more to do until the user presses Save.
    if not st.button("Save", key="add_note"):
        return

    logger.debug("Adding note")
    new_note = Note(title=title, content=body, note_type="human")
    new_note.save()
    new_note.add_to_notebook(notebook_id)
    st.rerun()
@st.dialog("Add a Source", width="large")
@st.dialog("Add a Note", width="large")
def note_panel_dialog(note: Optional[Note] = None, notebook_id=None):
note_panel(note_id=note.id, notebook_id=notebook_id)

View file

@ -7,6 +7,7 @@ from humanize import naturaltime
from loguru import logger
from open_notebook.config import UPLOADS_FOLDER
from open_notebook.domain.models import model_manager
from open_notebook.domain.notebook import Source
from open_notebook.domain.transformation import DefaultTransformations, Transformation
from open_notebook.exceptions import UnsupportedTypeException
@ -23,6 +24,10 @@ def source_panel_dialog(source_id, notebook_id=None):
@st.dialog("Add a Source", width="large")
def add_source(notebook_id):
if not model_manager.speech_to_text:
st.warning(
"Since there is no speech to text model selected, you can't upload audio/video files."
)
source_link = None
source_file = None
source_text = None

View file

@ -114,20 +114,30 @@ def check_migration():
st.stop()
def check_models(only_mandatory=True, stop_on_error=True):
    """Report missing default models to the user.

    The chat and transformation defaults are treated as mandatory; the
    embedding, speech-to-text and large-context defaults are optional.

    Args:
        only_mandatory: When true, only the mandatory models are checked.
            When false, a warning is also shown if any optional model is
            missing.
        stop_on_error: When true, ``st.stop()`` is called after reporting
            a missing mandatory model.
    """
    default_models = model_manager.defaults
    mandatory_models = (
        default_models.default_chat_model,
        default_models.default_transformation_model,
    )
    # Kept separate from the mandatory models so the "optional models"
    # warning is only shown when an optional model is actually missing.
    optional_models = (
        default_models.default_embedding_model,
        default_models.default_speech_to_text_model,
        default_models.large_context_model,
    )

    if not all(mandatory_models):
        st.error(
            "You are missing some default models and the app will not work as expected. Please, select them on the settings page."
        )
        if stop_on_error:
            st.stop()

    if not only_mandatory and not all(optional_models):
        st.warning(
            "You are missing some important optional models. The app might not work as expected. Please, select them on the settings page."
        )
def handle_error(func):
@ -144,13 +154,21 @@ def handle_error(func):
return wrapper
def setup_page(title: str, layout="wide", sidebar_state="expanded"):
def setup_page(
title: str,
layout="wide",
sidebar_state="expanded",
only_check_mandatory_models=True,
stop_on_model_error=True,
):
"""Common page setup for all pages"""
st.set_page_config(
page_title=title, layout=layout, initial_sidebar_state=sidebar_state
)
check_migration()
check_models()
check_models(
only_mandatory=only_check_mandatory_models, stop_on_error=stop_on_model_error
)
version_sidebar()