From 4f4abf7098730fc75317baac977d2cf9e72e2ee4 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Fri, 1 Nov 2024 18:46:43 -0300 Subject: [PATCH 01/31] change icon for ai generated notes --- stream_app/chat.py | 1 + 1 file changed, 1 insertion(+) diff --git a/stream_app/chat.py b/stream_app/chat.py index 00d8bd5..c883780 100644 --- a/stream_app/chat.py +++ b/stream_app/chat.py @@ -111,5 +111,6 @@ def chat_sidebar(session_id): make_note_from_chat( content=msg.content, notebook_id=st.session_state[session_id]["notebook"].id, + type="ai", ) st.rerun() From edf839cd1b898538ca8b6a2a14209997dae972e4 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Fri, 1 Nov 2024 19:08:33 -0300 Subject: [PATCH 02/31] unused graphs --- open_notebook/graphs/doc_query.py | 47 -------------- open_notebook/graphs/pattern.py | 36 ----------- open_notebook/graphs/recursive_toc.py | 81 ----------------------- open_notebook/graphs/summary.py | 93 --------------------------- open_notebook/graphs/tools.py | 22 +++---- 5 files changed, 11 insertions(+), 268 deletions(-) delete mode 100644 open_notebook/graphs/doc_query.py delete mode 100644 open_notebook/graphs/pattern.py delete mode 100644 open_notebook/graphs/recursive_toc.py delete mode 100644 open_notebook/graphs/summary.py diff --git a/open_notebook/graphs/doc_query.py b/open_notebook/graphs/doc_query.py deleted file mode 100644 index 7f0fb31..0000000 --- a/open_notebook/graphs/doc_query.py +++ /dev/null @@ -1,47 +0,0 @@ -from langchain_core.runnables import ( - RunnableConfig, -) -from langgraph.graph import END, START, StateGraph -from typing_extensions import TypedDict - -from open_notebook.config import load_default_models -from open_notebook.domain.notebook import Note, Notebook, Source -from open_notebook.graphs.utils import run_pattern - -DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() - - -class DocQueryState(TypedDict): - doc_id: str - doc_content: str - question: str - answer: str - notebook: Notebook - - -def call_model(state: dict, config: RunnableConfig) -> dict: - model_id = config.get("configurable", {}).get( - "model_id", DEFAULT_MODELS.default_transformation_model - ) - return {"answer": run_pattern("doc_query", model_id, state)} - - -# todo: there is probably a better way to do this and avoid repetition -def get_content(state: DocQueryState) -> dict: - doc_id = state["doc_id"] - if "note:" in doc_id: - doc: Note = Note.get(id=doc_id) - elif "source:" in doc_id: - doc: Source = Source.get(id=doc_id) - doc_content = doc.get_context("long") if doc else None - return {"doc_content": doc_content} - - -agent_state = StateGraph(DocQueryState) -agent_state.add_node("get_content", get_content) -agent_state.add_node("agent", call_model) -agent_state.add_edge(START, "get_content") -agent_state.add_edge("get_content", "agent") -agent_state.add_edge("agent", END) - -graph = agent_state.compile() diff --git a/open_notebook/graphs/pattern.py b/open_notebook/graphs/pattern.py deleted file mode 100644 index 65e9858..0000000 --- a/open_notebook/graphs/pattern.py +++ /dev/null @@ -1,36 +0,0 @@ -from langchain_core.runnables import ( - RunnableConfig, -) -from langgraph.graph import END, START, StateGraph -from typing_extensions import TypedDict - -from open_notebook.config import load_default_models -from open_notebook.graphs.utils import run_pattern - -DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() - - -class PatternState(TypedDict): - input_text: str - pattern: str - output: str - - -def call_model(state: dict, config: RunnableConfig) -> dict: - model_id = config.get("configurable", {}).get( - "model_id", DEFAULT_MODELS.default_transformation_model - ) - return { - "output": run_pattern( - pattern_name=state["pattern"], - model_id=model_id, - state=state, - ) - } - - -agent_state = StateGraph(PatternState) -agent_state.add_node("agent", call_model) -agent_state.add_edge(START, "agent") -agent_state.add_edge("agent", END) -graph = agent_state.compile() diff --git a/open_notebook/graphs/recursive_toc.py b/open_notebook/graphs/recursive_toc.py deleted file mode 100644 index 02c5326..0000000 --- a/open_notebook/graphs/recursive_toc.py +++ /dev/null @@ -1,81 +0,0 @@ -import os -from typing import List, Literal - -from langchain_core.runnables import ( - RunnableConfig, -) -from langgraph.graph import END, START, StateGraph -from typing_extensions import TypedDict - -from open_notebook.config import load_default_models -from open_notebook.graphs.utils import run_pattern -from open_notebook.utils import split_text - -DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() - - -class TocState(TypedDict): - chunks: List[str] - content: str - toc: str - - -def build_chunks(state: TocState) -> dict: - """ - Split the input text into chunks. - """ - return { - "chunks": split_text( - state["content"], - chunk=int(os.environ.get("SUMMARY_CHUNK_SIZE", 200000)), - overlap=int(os.environ.get("SUMMARY_CHUNK_OVERLAP", 1000)), - ) - } - - -def setup_next_chunk(state: TocState) -> dict: - """ - Move the next item in the chunk to the processing area - """ - state["content"] = state["chunks"].pop(0) - return {"chunks": state["chunks"], "content": state["content"]} - - -def chunk_condition(state: TocState) -> Literal["get_chunk", END]: # type: ignore - """ - Checks whether there are more chunks to process. - """ - if len(state["chunks"]) > 0: - return "get_chunk" - return END - - -def call_model(state: TocState, config: RunnableConfig) -> dict: - model_id = config.get("configurable", {}).get( - "model_id", DEFAULT_MODELS.default_transformation_model - ) - return { - "toc": run_pattern( - pattern_name="recursive_toc", - model_id=model_id, - state=state, - ).content - } - - -agent_state = StateGraph(TocState) -agent_state.add_node("setup_chunk", build_chunks) -agent_state.add_edge(START, "setup_chunk") -agent_state.add_conditional_edges( - "setup_chunk", - chunk_condition, -) -agent_state.add_node("get_chunk", setup_next_chunk) -agent_state.add_node("agent", call_model) -agent_state.add_edge("get_chunk", "agent") -agent_state.add_conditional_edges( - "agent", - chunk_condition, -) - -graph = agent_state.compile() diff --git a/open_notebook/graphs/summary.py b/open_notebook/graphs/summary.py deleted file mode 100644 index d4e0659..0000000 --- a/open_notebook/graphs/summary.py +++ /dev/null @@ -1,93 +0,0 @@ -import os -from typing import List, Literal - -from langchain_core.output_parsers import PydanticOutputParser -from langchain_core.runnables import ( - RunnableConfig, -) -from langgraph.graph import END, START, StateGraph -from pydantic import BaseModel -from typing_extensions import TypedDict - -from open_notebook.config import load_default_models -from open_notebook.graphs.utils import run_pattern -from open_notebook.utils import split_text - -DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() - - -class SummaryResponse(BaseModel): - """This is schema of your response. Please provide a JSON object with the enclosed keys""" - - summary: str - topics: List[str] - title: str - - -class SummaryState(TypedDict): - chunks: List[str] - content: str - output: SummaryResponse - - -def build_chunks(state: SummaryState) -> dict: - """ - Split the input text into chunks. - """ - return { - "chunks": split_text( - state["content"], - chunk=int(os.environ.get("SUMMARY_CHUNK_SIZE", 200000)), - overlap=int(os.environ.get("SUMMARY_CHUNK_OVERLAP", 1000)), - ) - } - - -def setup_next_chunk(state: SummaryState) -> dict: - """ - Move the next item in the chunk to the processing area - """ - state["content"] = state["chunks"].pop(0) - return {"chunks": state["chunks"], "content": state["content"]} - - -def chunk_condition(state: SummaryState) -> Literal["get_chunk", END]: # type: ignore - """ - Checks whether there are more chunks to process. - """ - if len(state["chunks"]) > 0: - return "get_chunk" - return END - - -def call_model(state: dict, config: RunnableConfig) -> dict: - model_id = config.get("configurable", {}).get( - "model_id", DEFAULT_MODELS.default_transformation_model - ) - parser = PydanticOutputParser(pydantic_object=SummaryResponse) - return { - "output": run_pattern( - pattern_name="summarize", - model_id=model_id, - state=state, - parser=parser, - ) - } - - -agent_state = StateGraph(SummaryState) -agent_state.add_node("setup_chunk", build_chunks) -agent_state.add_edge(START, "setup_chunk") -agent_state.add_conditional_edges( - "setup_chunk", - chunk_condition, -) -agent_state.add_node("get_chunk", setup_next_chunk) -agent_state.add_node("agent", call_model) -agent_state.add_edge("get_chunk", "agent") -agent_state.add_conditional_edges( - "agent", - chunk_condition, -) - -graph = agent_state.compile() diff --git a/open_notebook/graphs/tools.py b/open_notebook/graphs/tools.py index 636e25b..96aeacc 100644 --- a/open_notebook/graphs/tools.py +++ b/open_notebook/graphs/tools.py @@ -12,15 +12,15 @@ def get_current_timestamp() -> str: return datetime.now().strftime("%Y%m%d%H%M%S") -@tool -def doc_query(doc_id: str, question: str): - """ - name: doc_query - Use this tool if you need to investigate into a particular document. - Another LLM will read the document and answer the question that you might have. - Use this when the user question cannot be answered with the content you have in context. - """ - from open_notebook.graphs.doc_query import graph +# @tool +# def doc_query(doc_id: str, question: str): +# """ +# name: doc_query +# Use this tool if you need to investigate into a particular document. +# Another LLM will read the document and answer the question that you might have. +# Use this when the user question cannot be answered with the content you have in context. +# """ +# from temp.doc_query import graph - result = graph.invoke({"doc_id": doc_id, "question": question}) - return result["answer"] +# result = graph.invoke({"doc_id": doc_id, "question": question}) +# return result["answer"] From a24faaba44462df2ab8b357ba3be06971904527d Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Fri, 1 Nov 2024 19:08:47 -0300 Subject: [PATCH 03/31] moved toc to a pattern --- open_notebook/domain/notebook.py | 28 ++-------------------------- prompts/patterns/default/toc.jinja | 15 +++++++++++++++ stream_app/source.py | 24 +++++++++++++++++++++++- transformations.yaml | 5 +++++ 4 files changed, 45 insertions(+), 27 deletions(-) create mode 100644 prompts/patterns/default/toc.jinja diff --git a/open_notebook/domain/notebook.py b/open_notebook/domain/notebook.py index d23ee04..4321231 100644 --- a/open_notebook/domain/notebook.py +++ b/open_notebook/domain/notebook.py @@ -1,7 +1,6 @@ import os from typing import Any, ClassVar, Dict, List, Literal, Optional -from langchain_core.runnables.config import RunnableConfig from loguru import logger from pydantic import BaseModel, Field, field_validator @@ -15,8 +14,8 @@ from open_notebook.exceptions import ( DatabaseOperationError, InvalidInputError, ) -from open_notebook.graphs.multipattern import graph as pattern_graph -from open_notebook.graphs.recursive_toc import graph as toc_graph + +# from temp.recursive_toc import graph as toc_graph from open_notebook.utils import split_text, surreal_clean @@ -211,29 +210,6 @@ class Source(ObjectModel): logger.error(f"Error adding insight to source {self.id}: {str(e)}") raise DatabaseOperationError(e) - # todo: move this to content processing pipeline as a major graph - def generate_toc_and_title(self) -> "Source": - DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() - - try: - config = RunnableConfig(configurable=dict(thread_id=self.id)) - result = toc_graph.invoke({"content": self.full_text}, config=config) - self.add_insight("Table of Contents", surreal_clean(result["toc"])) - if not self.title: - transformations = [ - "Based on the Table of Contents below, please provide a Title for this content, with max 15 words" - ] - output = pattern_graph.invoke( - dict(content_stack=[result["toc"]], transformations=transformations) - ) - self.title = surreal_clean(output["output"]) - self.save() - return self - except Exception as e: - logger.error(f"Error summarizing source {self.id}: {str(e)}") - logger.exception(e) - raise DatabaseOperationError(e) - class Note(ObjectModel): table_name: ClassVar[str] = "note" diff --git a/prompts/patterns/default/toc.jinja b/prompts/patterns/default/toc.jinja new file mode 100644 index 0000000..c78f159 --- /dev/null +++ b/prompts/patterns/default/toc.jinja @@ -0,0 +1,15 @@ + +# SYSTEM ROLE +You are a content analysis assistant that reads through documents and provides a Table of Contents (ToC) to help users identify what the document covers more easily. +Your ToC should capture all major topics and transitions in the content and should mention them in the order theh appear. + +# TASK +Analyze the provided content and create a Table of Contents: +- Captures the core topics included in the text +- Gives a small description of what is covered + +# INPUT + +{{content}} + +# OUTPUT \ No newline at end of file diff --git a/stream_app/source.py b/stream_app/source.py index 63c25f8..64fd54f 100644 --- a/stream_app/source.py +++ b/stream_app/source.py @@ -24,6 +24,28 @@ def run_patterns(input_text, patterns): return output["output"] +# moved it here to replace it with the pipeline on 0.1.0 +def generate_toc_and_title(source) -> "Source": + DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() + + try: + patterns = ["patterns/default/toc"] + result = run_patterns(source.full_text, patterns=patterns) + source.add_insight("Table of Contents", surreal_clean(result)) + if not source.title: + transformations = [ + "Based on the Table of Contents below, please provide a Title for this content, with max 15 words" + ] + output = run_patterns(result["toc"], transformations=transformations) + source.title = surreal_clean(output["output"]) + source.save() + return source + except Exception as e: + logger.error(f"Error summarizing source {source.id}: {str(e)}") + logger.exception(e) + raise + + @st.dialog("Source", width="large") def source_panel(source_id): source: Source = Source.get(source_id) @@ -151,7 +173,7 @@ def add_source(session_id): source.save() source.add_to_notebook(st.session_state[session_id]["notebook"].id) st.write("Summarizing...") - source.generate_toc_and_title() + generate_toc_and_title(source) except UnsupportedTypeException as e: st.warning( "This type of content is not supported yet. If you think it should be, let us know on the project Issues's page" diff --git a/transformations.yaml b/transformations.yaml index b0d3a95..435fb3c 100644 --- a/transformations.yaml +++ b/transformations.yaml @@ -16,6 +16,11 @@ source_insights: description: "Create a dense representation of the content" patterns: - patterns/default/makeitdense + - name: "Table of Contents" + insight_type: "Table of Contents" + description: "Analyzes the content and returns a ToC" + patterns: + - patterns/default/analyze_paper - name: "Analyze Paper" insight_type: "Paper Analysis" description: "Analyze the paper and provide a quick summary" From a525fba1d2b726084648f6140c07fa03a377a0e0 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Fri, 1 Nov 2024 19:53:56 -0300 Subject: [PATCH 04/31] fix note_type icon --- migrations/2.surrealql | 1 + migrations/2_down.surrealql | 1 + stream_app/note.py | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 migrations/2.surrealql create mode 100644 migrations/2_down.surrealql diff --git a/migrations/2.surrealql b/migrations/2.surrealql new file mode 100644 index 0000000..66c1017 --- /dev/null +++ b/migrations/2.surrealql @@ -0,0 +1 @@ +DEFINE FIELD IF NOT EXISTS note_type ON TABLE note TYPE option; diff --git a/migrations/2_down.surrealql b/migrations/2_down.surrealql new file mode 100644 index 0000000..d0213a0 --- /dev/null +++ b/migrations/2_down.surrealql @@ -0,0 +1 @@ +REMOVE FIELD IF EXISTS note_type ON TABLE note; diff --git a/stream_app/note.py b/stream_app/note.py index 2cf063e..2d1f5a4 100644 --- a/stream_app/note.py +++ b/stream_app/note.py @@ -27,7 +27,7 @@ def note_panel(session_id=None, note_id=None): if note_id: note: Note = Note.get(note_id) else: - note: Note = Note() + note: Note = Note(note_type="human") t_preview, t_edit = st.tabs(["Preview", "Edit"]) with t_preview: From feabfaed0171ab008d29dcc6f898e94b29164b8c Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Fri, 1 Nov 2024 19:56:27 -0300 Subject: [PATCH 05/31] remove defaultmodel from config file --- open_notebook/config.py | 22 ---------- open_notebook/database/migrate.py | 12 +++++- open_notebook/domain/base.py | 29 +++++++++++-- open_notebook/domain/models.py | 42 ++++++++++--------- open_notebook/domain/notebook.py | 11 +++-- open_notebook/graphs/chat.py | 5 ++- .../graphs/content_processing/audio.py | 6 ++- open_notebook/graphs/multipattern.py | 4 +- open_notebook/graphs/utils.py | 8 ++-- open_notebook/models/__init__.py | 13 +++--- pages/2_πŸ“’_Notebooks.py | 4 -- stream_app/chat.py | 1 - stream_app/source.py | 12 ++---- 13 files changed, 89 insertions(+), 80 deletions(-) diff --git a/open_notebook/config.py b/open_notebook/config.py index c07ab03..096850c 100644 --- a/open_notebook/config.py +++ b/open_notebook/config.py @@ -3,9 +3,6 @@ import os import yaml from loguru import logger -from open_notebook.domain.models import DefaultModels -from open_notebook.models import get_model - current_dir = os.path.dirname(os.path.abspath(__file__)) project_root = os.path.dirname(current_dir) config_path = os.path.join(project_root, "open_notebook_config.yaml") @@ -34,22 +31,3 @@ os.makedirs(UPLOADS_FOLDER, exist_ok=True) # PODCASTS FOLDER PODCASTS_FOLDER = f"{DATA_FOLDER}/podcasts" os.makedirs(PODCASTS_FOLDER, exist_ok=True) - - -def load_default_models(): - default_models = DefaultModels.load() - embedding_model = ( - get_model(default_models.default_embedding_model, model_type="embedding") - if default_models.default_embedding_model - else None - ) - - speech_to_text_model = ( - get_model( - default_models.default_speech_to_text_model, model_type="speech_to_text" - ) - if default_models.default_speech_to_text_model - else None - ) - - return default_models, embedding_model, speech_to_text_model diff --git a/open_notebook/database/migrate.py b/open_notebook/database/migrate.py index 7d99fbd..f890091 100644 --- a/open_notebook/database/migrate.py +++ b/open_notebook/database/migrate.py @@ -18,8 +18,16 @@ class MigrationManager: database=os.environ["SURREAL_DATABASE"], encrypted=False, # Set to True if using SSL ) - self.up_migrations = [Migration.from_file("migrations/1.surrealql")] - self.down_migrations = [Migration.from_file("migrations/1_down.surrealql")] + self.up_migrations = [ + Migration.from_file("migrations/1.surrealql"), + Migration.from_file("migrations/2.surrealql"), + ] + self.down_migrations = [ + Migration.from_file( + "migrations/1_down.surrealql", + ), + Migration.from_file("migrations/2_down.surrealql"), + ] self.runner = MigrationRunner( up_migrations=self.up_migrations, down_migrations=self.down_migrations, diff --git a/open_notebook/domain/base.py b/open_notebook/domain/base.py index 2aa648e..76b1a32 100644 --- a/open_notebook/domain/base.py +++ b/open_notebook/domain/base.py @@ -68,9 +68,11 @@ class ObjectModel(BaseModel): return None def save(self) -> None: - from open_notebook.config import load_default_models + from open_notebook.domain.models import DefaultModels + from open_notebook.models import get_model - DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() + DEFAULT_MODELS = DefaultModels.load() + EMBEDDING_MODEL = get_model(DEFAULT_MODELS.default_embedding_model) try: logger.debug(f"Validating {self.__class__.__name__}") @@ -88,7 +90,11 @@ class ObjectModel(BaseModel): logger.debug("Creating new record") repo_result = repo_create(self.__class__.table_name, data) else: - data["created"] = self.created.strftime("%Y-%m-%d %H:%M:%S") + data["created"] = ( + self.created.strftime("%Y-%m-%d %H:%M:%S") + if type(self.created) == datetime + else self.created + ) logger.debug(f"Updating record with id {self.id}") repo_result = repo_update(self.id, data) @@ -148,3 +154,20 @@ class ObjectModel(BaseModel): if isinstance(value, str): return datetime.fromisoformat(value.replace("Z", "+00:00")) return value + + +class RecordModel(BaseModel): + record_id: ClassVar[str] = "open_notebook:default_models" + + @classmethod + def load(cls): + result = repo_query(f"SELECT * FROM {cls.record_id};") + if result: + result = result[0] + dm = cls(**result) + return dm + return cls() + + @classmethod + def update(self, data): + repo_update(self.record_id, data) diff --git a/open_notebook/domain/models.py b/open_notebook/domain/models.py index 5699147..de65c53 100644 --- a/open_notebook/domain/models.py +++ b/open_notebook/domain/models.py @@ -1,12 +1,7 @@ from typing import ClassVar, Optional -from pydantic import BaseModel - -from open_notebook.database.repository import ( - repo_query, - repo_update, -) -from open_notebook.domain.base import ObjectModel +from open_notebook.database.repository import repo_query +from open_notebook.domain.base import ObjectModel, RecordModel class Model(ObjectModel): @@ -23,7 +18,10 @@ class Model(ObjectModel): return [Model(**model) for model in models] -class DefaultModels(BaseModel): +# todo: future: colocar um cache aqui +class DefaultModels(RecordModel): + record_id: ClassVar[str] = "open_notebook:default_models" + default_chat_model: Optional[str] = None default_transformation_model: Optional[str] = None large_context_model: Optional[str] = None @@ -32,15 +30,21 @@ class DefaultModels(BaseModel): # default_vision_model: Optional[str] = None default_embedding_model: Optional[str] = None - @classmethod - def load(self): - result = repo_query("SELECT * FROM open_notebook:default_models;") - if result: - result = result[0] - dm = DefaultModels(**result) - return dm - return DefaultModels() - @classmethod - def update(self, data): - repo_update("open_notebook:default_models", data) +# def load_default_models(): +# default_models = DefaultModels.load() +# embedding_model = ( +# get_model(default_models.default_embedding_model, model_type="embedding") +# if default_models.default_embedding_model +# else None +# ) + +# speech_to_text_model = ( +# get_model( +# default_models.default_speech_to_text_model, model_type="speech_to_text" +# ) +# if default_models.default_speech_to_text_model +# else None +# ) + +# return default_models, embedding_model, speech_to_text_model diff --git a/open_notebook/domain/notebook.py b/open_notebook/domain/notebook.py index 4321231..d64ee08 100644 --- a/open_notebook/domain/notebook.py +++ b/open_notebook/domain/notebook.py @@ -4,18 +4,19 @@ from typing import Any, ClassVar, Dict, List, Literal, Optional from loguru import logger from pydantic import BaseModel, Field, field_validator -from open_notebook.config import load_default_models from open_notebook.database.repository import ( repo_create, repo_query, ) from open_notebook.domain.base import ObjectModel +from open_notebook.domain.models import DefaultModels from open_notebook.exceptions import ( DatabaseOperationError, InvalidInputError, ) # from temp.recursive_toc import graph as toc_graph +from open_notebook.models import get_model from open_notebook.utils import split_text, surreal_clean @@ -139,7 +140,8 @@ class Source(ObjectModel): raise DatabaseOperationError(e) def vectorize(self) -> None: - DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() + DEFAULT_MODELS = DefaultModels.load() + EMBEDDING_MODEL = get_model(DEFAULT_MODELS.default_embedding_model) try: if not self.full_text: @@ -190,7 +192,8 @@ class Source(ObjectModel): raise DatabaseOperationError("Failed to search sources") def add_insight(self, insight_type: str, content: str) -> Any: - DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() + DEFAULT_MODELS = DefaultModels.load() + EMBEDDING_MODEL = get_model(DEFAULT_MODELS.default_embedding_model) if not insight_type or not content: raise InvalidInputError("Insight type and content must be provided") @@ -214,7 +217,7 @@ class Source(ObjectModel): class Note(ObjectModel): table_name: ClassVar[str] = "note" title: Optional[str] = None - note_type: Optional[Literal["human", "ai"]] = "human" + note_type: Optional[Literal["human", "ai"]] = None content: Optional[str] = None @field_validator("content") diff --git a/open_notebook/graphs/chat.py b/open_notebook/graphs/chat.py index 7735cd8..5d75151 100644 --- a/open_notebook/graphs/chat.py +++ b/open_notebook/graphs/chat.py @@ -9,11 +9,12 @@ from langgraph.graph import END, START, StateGraph from langgraph.graph.message import add_messages from typing_extensions import TypedDict -from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE, load_default_models +from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE +from open_notebook.domain.models import DefaultModels from open_notebook.domain.notebook import Notebook from open_notebook.graphs.utils import run_pattern -DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() +DEFAULT_MODELS = DefaultModels.load() class ThreadState(TypedDict): diff --git a/open_notebook/graphs/content_processing/audio.py b/open_notebook/graphs/content_processing/audio.py index b7c31be..ad2d7d7 100644 --- a/open_notebook/graphs/content_processing/audio.py +++ b/open_notebook/graphs/content_processing/audio.py @@ -4,8 +4,9 @@ from math import ceil from loguru import logger from pydub import AudioSegment -from open_notebook.config import load_default_models +from open_notebook.domain.models import DefaultModels from open_notebook.graphs.content_processing.state import SourceState +from open_notebook.models import get_model # future: parallelize the transcription process @@ -72,7 +73,8 @@ def split_audio(input_file, segment_length_minutes=15, output_prefix=None): def extract_audio(data: SourceState): - DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() + DEFAULT_MODELS = DefaultModels.load() + SPEECH_TO_TEXT_MODEL = get_model(DEFAULT_MODELS.default_speech_to_text_model) input_audio_path = data.get("file_path") audio_files = [] diff --git a/open_notebook/graphs/multipattern.py b/open_notebook/graphs/multipattern.py index 75d499a..17febca 100644 --- a/open_notebook/graphs/multipattern.py +++ b/open_notebook/graphs/multipattern.py @@ -7,10 +7,10 @@ from langchain_core.runnables import ( from langgraph.graph import END, START, StateGraph from typing_extensions import Annotated, TypedDict -from open_notebook.config import load_default_models +from open_notebook.domain.models import DefaultModels from open_notebook.graphs.utils import run_pattern -DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() +DEFAULT_MODELS = DefaultModels.load() class PatternChainState(TypedDict): diff --git a/open_notebook/graphs/utils.py b/open_notebook/graphs/utils.py index 39871ee..0c0e137 100644 --- a/open_notebook/graphs/utils.py +++ b/open_notebook/graphs/utils.py @@ -1,7 +1,7 @@ from langchain.output_parsers import OutputFixingParser from loguru import logger -from open_notebook.config import load_default_models +from open_notebook.domain.models import DefaultModels from open_notebook.models import get_model from open_notebook.prompter import Prompter from open_notebook.utils import token_count @@ -18,7 +18,7 @@ def run_pattern( system_prompt = Prompter(prompt_template=pattern_name, parser=parser).render( data=state ) - DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() + DEFAULT_MODELS = DefaultModels.load() tokens = token_count(str(system_prompt) + str(messages)) if tokens > 105_000: @@ -33,12 +33,12 @@ def run_pattern( or DEFAULT_MODELS.default_chat_model ) - chain = get_model(model_id, model_type="language") + chain = get_model(model_id) if parser: chain = chain | parser if output_fixing_model_id and parser: - output_fix_model = get_model(output_fixing_model_id, model_type="language") + output_fix_model = get_model(output_fixing_model_id) chain = chain | OutputFixingParser.from_llm( parser=parser, llm=output_fix_model, diff --git a/open_notebook/models/__init__.py b/open_notebook/models/__init__.py index 9473c9e..8cd9067 100644 --- a/open_notebook/models/__init__.py +++ b/open_notebook/models/__init__.py @@ -49,13 +49,12 @@ MODEL_CLASS_MAP = { } -def get_model(model_id, model_type="language", **kwargs): +def get_model(model_id, **kwargs): """ Get a model instance based on model_id and type. Args: model_id: The ID of the model to retrieve - model_type: Type of model ('language', 'embedding', or 'speech_to_text') **kwargs: Additional arguments to pass to the model constructor """ assert model_id, "Model ID cannot be empty" @@ -64,20 +63,20 @@ def get_model(model_id, model_type="language", **kwargs): if not model: raise ValueError(f"Model with ID {model_id} not found") - if model_type not in MODEL_CLASS_MAP: - raise ValueError(f"Invalid model type: {model_type}") + if not model.type or model.type not in MODEL_CLASS_MAP: + raise ValueError(f"Invalid model type: {model.type}") - provider_map = MODEL_CLASS_MAP[model_type] + provider_map = MODEL_CLASS_MAP[model.type] if model.provider not in provider_map: raise ValueError( - f"Provider {model.provider} not compatible with {model_type} models" + f"Provider {model.provider} not compatible with {model.type} models" ) model_class = provider_map[model.provider] model_instance = model_class(model_name=model.name, **kwargs) # Special handling for language models that need langchain conversion - if model_type == "language": + if model.type == "language": return model_instance.to_langchain() return model_instance diff --git a/pages/2_πŸ“’_Notebooks.py b/pages/2_πŸ“’_Notebooks.py index b0a6c07..e6302a3 100644 --- a/pages/2_πŸ“’_Notebooks.py +++ b/pages/2_πŸ“’_Notebooks.py @@ -1,7 +1,6 @@ import streamlit as st from humanize import naturaltime -from open_notebook.config import load_default_models from open_notebook.domain.notebook import Notebook from stream_app.chat import chat_sidebar from stream_app.note import add_note, note_card @@ -71,9 +70,6 @@ def notebook_page(current_notebook_id): sources = current_notebook.sources notes = current_notebook.notes - # Load the default models dynamically - DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() - notebook_header(current_notebook) work_tab, chat_tab = st.columns([4, 2]) diff --git a/stream_app/chat.py b/stream_app/chat.py index c883780..00d8bd5 100644 --- a/stream_app/chat.py +++ b/stream_app/chat.py @@ -111,6 +111,5 @@ def chat_sidebar(session_id): make_note_from_chat( content=msg.content, notebook_id=st.session_state[session_id]["notebook"].id, - type="ai", ) st.rerun() diff --git a/stream_app/source.py b/stream_app/source.py index 64fd54f..25f880b 100644 --- a/stream_app/source.py +++ b/stream_app/source.py @@ -7,7 +7,7 @@ import yaml from humanize import naturaltime from loguru import logger -from open_notebook.config import UPLOADS_FOLDER, load_default_models +from open_notebook.config import UPLOADS_FOLDER from open_notebook.domain.notebook import Asset, Source from open_notebook.exceptions import UnsupportedTypeException from open_notebook.graphs.content_processing import graph @@ -16,8 +16,6 @@ from open_notebook.utils import surreal_clean from .consts import context_icons -DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() - def run_patterns(input_text, patterns): output = transform_graph.invoke(dict(content_stack=[input_text], patterns=patterns)) @@ -26,18 +24,16 @@ def run_patterns(input_text, patterns): # moved it here to replace it with the pipeline on 0.1.0 def generate_toc_and_title(source) -> "Source": - DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() - try: patterns = ["patterns/default/toc"] result = run_patterns(source.full_text, patterns=patterns) source.add_insight("Table of Contents", surreal_clean(result)) if not source.title: - transformations = [ + patterns = [ "Based on the Table of Contents below, please provide a Title for this content, with max 15 words" ] - output = run_patterns(result["toc"], transformations=transformations) - source.title = surreal_clean(output["output"]) + output = run_patterns(result, patterns=patterns) + source.title = surreal_clean(output) source.save() return source except Exception as e: From 3b7dd5f25f453345553f5d7882c5d81da74c7d3f Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Fri, 1 Nov 2024 20:16:34 -0300 Subject: [PATCH 06/31] remove load_default_models from models file --- open_notebook/domain/models.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/open_notebook/domain/models.py b/open_notebook/domain/models.py index de65c53..94549b1 100644 --- a/open_notebook/domain/models.py +++ b/open_notebook/domain/models.py @@ -29,22 +29,3 @@ class DefaultModels(RecordModel): default_speech_to_text_model: Optional[str] = None # default_vision_model: Optional[str] = None default_embedding_model: Optional[str] = None - - -# def load_default_models(): -# default_models = DefaultModels.load() -# embedding_model = ( -# get_model(default_models.default_embedding_model, model_type="embedding") -# if default_models.default_embedding_model -# else None -# ) - -# speech_to_text_model = ( -# get_model( -# default_models.default_speech_to_text_model, model_type="speech_to_text" -# ) -# if default_models.default_speech_to_text_model -# else None -# ) - -# return default_models, embedding_model, speech_to_text_model From a9ac4a6dc8e73715e0f2fa3a057ba7d0943b1ef5 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Fri, 1 Nov 2024 20:37:23 -0300 Subject: [PATCH 07/31] model manager --- open_notebook/domain/base.py | 6 +- open_notebook/domain/notebook.py | 9 +- .../graphs/content_processing/audio.py | 6 +- open_notebook/graphs/utils.py | 6 +- open_notebook/models/__init__.py | 155 +++++++++++++++--- prompts/patterns/default/toc.jinja | 2 +- 6 files changed, 141 insertions(+), 43 deletions(-) diff --git a/open_notebook/domain/base.py b/open_notebook/domain/base.py index 76b1a32..84a57c9 100644 --- a/open_notebook/domain/base.py +++ b/open_notebook/domain/base.py @@ -68,11 +68,9 @@ class ObjectModel(BaseModel): return None def save(self) -> None: - from open_notebook.domain.models import DefaultModels - from open_notebook.models import get_model + from open_notebook.models import model_manager - DEFAULT_MODELS = DefaultModels.load() - EMBEDDING_MODEL = get_model(DEFAULT_MODELS.default_embedding_model) + EMBEDDING_MODEL = model_manager.get_default_model("embedding") try: logger.debug(f"Validating {self.__class__.__name__}") diff --git a/open_notebook/domain/notebook.py b/open_notebook/domain/notebook.py index d64ee08..c15e8e2 100644 --- a/open_notebook/domain/notebook.py +++ b/open_notebook/domain/notebook.py @@ -9,14 +9,13 @@ from open_notebook.database.repository import ( repo_query, ) from open_notebook.domain.base import ObjectModel -from open_notebook.domain.models import DefaultModels from open_notebook.exceptions import ( DatabaseOperationError, InvalidInputError, ) # from temp.recursive_toc import graph as toc_graph -from open_notebook.models import get_model +from open_notebook.models import model_manager from open_notebook.utils import split_text, surreal_clean @@ -140,8 +139,7 @@ class Source(ObjectModel): raise DatabaseOperationError(e) def vectorize(self) -> None: - DEFAULT_MODELS = DefaultModels.load() - EMBEDDING_MODEL = get_model(DEFAULT_MODELS.default_embedding_model) + EMBEDDING_MODEL = model_manager.get_default_model("embedding") try: if not self.full_text: @@ -192,8 +190,7 @@ class Source(ObjectModel): raise DatabaseOperationError("Failed to search sources") def add_insight(self, insight_type: str, content: str) -> Any: - DEFAULT_MODELS = DefaultModels.load() - EMBEDDING_MODEL = get_model(DEFAULT_MODELS.default_embedding_model) + EMBEDDING_MODEL = model_manager.get_default_model("embedding") if not insight_type or not content: raise InvalidInputError("Insight type and content must be provided") diff --git a/open_notebook/graphs/content_processing/audio.py b/open_notebook/graphs/content_processing/audio.py index ad2d7d7..ac81481 100644 --- a/open_notebook/graphs/content_processing/audio.py +++ b/open_notebook/graphs/content_processing/audio.py @@ -4,9 +4,8 @@ from math import ceil from loguru import logger from pydub import AudioSegment -from open_notebook.domain.models import DefaultModels from open_notebook.graphs.content_processing.state import SourceState -from open_notebook.models import get_model +from open_notebook.models import model_manager # future: parallelize the transcription process @@ -73,8 +72,7 @@ def split_audio(input_file, segment_length_minutes=15, output_prefix=None): def extract_audio(data: SourceState): - DEFAULT_MODELS = DefaultModels.load() - SPEECH_TO_TEXT_MODEL = get_model(DEFAULT_MODELS.default_speech_to_text_model) + SPEECH_TO_TEXT_MODEL = model_manager.get_default_model("speech_to_text") input_audio_path = data.get("file_path") audio_files = [] diff --git a/open_notebook/graphs/utils.py b/open_notebook/graphs/utils.py index 0c0e137..4b4a896 100644 --- a/open_notebook/graphs/utils.py +++ b/open_notebook/graphs/utils.py @@ -2,7 +2,7 @@ from langchain.output_parsers import OutputFixingParser from loguru import logger from open_notebook.domain.models import DefaultModels -from open_notebook.models import get_model +from open_notebook.models import model_manager from open_notebook.prompter import Prompter from open_notebook.utils import token_count @@ -33,12 +33,12 @@ def run_pattern( or DEFAULT_MODELS.default_chat_model ) - chain = get_model(model_id) + chain = model_manager.get_default_model("transformation") if parser: chain = chain | parser if output_fixing_model_id and parser: - output_fix_model = get_model(output_fixing_model_id) + output_fix_model = model_manager.get_model(output_fixing_model_id) chain = chain | OutputFixingParser.from_llm( parser=parser, llm=output_fix_model, diff --git a/open_notebook/models/__init__.py b/open_notebook/models/__init__.py index 8cd9067..23623bb 100644 --- a/open_notebook/models/__init__.py +++ b/open_notebook/models/__init__.py @@ -1,4 +1,6 @@ -from open_notebook.domain.models import Model +from typing import Dict, Optional + +from open_notebook.domain.models import DefaultModels, Model from open_notebook.models.embedding_models import ( GeminiEmbeddingModel, OllamaEmbeddingModel, @@ -49,34 +51,137 @@ MODEL_CLASS_MAP = { } -def get_model(model_id, **kwargs): - """ - Get a model instance based on model_id and type. +# def get_model(model_id, **kwargs): +# """ +# Get a model instance based on model_id and type. - Args: - model_id: The ID of the model to retrieve - **kwargs: Additional arguments to pass to the model constructor - """ - assert model_id, "Model ID cannot be empty" - model: Model = Model.get(model_id) +# Args: +# model_id: The ID of the model to retrieve +# **kwargs: Additional arguments to pass to the model constructor +# """ +# assert model_id, "Model ID cannot be empty" +# model: Model = Model.get(model_id) - if not model: - raise ValueError(f"Model with ID {model_id} not found") +# if not model: +# raise ValueError(f"Model with ID {model_id} not found") - if not model.type or model.type not in MODEL_CLASS_MAP: - raise ValueError(f"Invalid model type: {model.type}") +# if not model.type or model.type not in MODEL_CLASS_MAP: +# raise ValueError(f"Invalid model type: {model.type}") - provider_map = MODEL_CLASS_MAP[model.type] - if model.provider not in provider_map: - raise ValueError( - f"Provider {model.provider} not compatible with {model.type} models" - ) +# provider_map = MODEL_CLASS_MAP[model.type] +# if model.provider not in provider_map: +# raise ValueError( +# f"Provider {model.provider} not compatible with {model.type} models" +# ) - model_class = provider_map[model.provider] - model_instance = model_class(model_name=model.name, **kwargs) +# model_class = provider_map[model.provider] +# model_instance = model_class(model_name=model.name, **kwargs) - # Special handling for language models that need langchain conversion - if model.type == "language": - return model_instance.to_langchain() +# # Special handling for language models that need langchain conversion +# if model.type == "language": +# return model_instance.to_langchain() - return model_instance +# return model_instance + + +class ModelManager: + _instance = None + _model_cache: Dict[str, object] = {} + _default_models: Optional[DefaultModels] = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(ModelManager, cls).__new__(cls) + return cls._instance + + def __init__(self): + if not hasattr(self, "_initialized"): + self._initialized = True + self.refresh_defaults() + + def refresh_defaults(self): + """Refresh the default models from the database""" + self._default_models = DefaultModels.load() + + @property + def defaults(self) -> DefaultModels: + """Get the default models configuration""" + if not self._default_models: + self.refresh_defaults() + return self._default_models + + def get_model(self, model_id: str, **kwargs) -> object: + """ + Get a model instance based on model_id. Uses caching to avoid recreating instances. + + Args: + model_id: The ID of the model to retrieve + **kwargs: Additional arguments to pass to the model constructor + """ + cache_key = f"{model_id}:{str(kwargs)}" + + if cache_key in self._model_cache: + return self._model_cache[cache_key] + + assert model_id, "Model ID cannot be empty" + model: Model = Model.get(model_id) + + if not model: + raise ValueError(f"Model with ID {model_id} not found") + + if not model.type or model.type not in MODEL_CLASS_MAP: + raise ValueError(f"Invalid model type: {model.type}") + + provider_map = MODEL_CLASS_MAP[model.type] + if model.provider not in provider_map: + raise ValueError( + f"Provider {model.provider} not compatible with {model.type} models" + ) + + model_class = provider_map[model.provider] + model_instance = model_class(model_name=model.name, **kwargs) + + # Special handling for language models that need langchain conversion + if model.type == "language": + model_instance = model_instance.to_langchain() + + self._model_cache[cache_key] = model_instance + return model_instance + + def get_default_model(self, model_type: str, **kwargs) -> object: + """ + Get the default model for a specific type. + + Args: + model_type: The type of model to retrieve (e.g., 'chat', 'embedding', etc.) + **kwargs: Additional arguments to pass to the model constructor + """ + model_id = None + + if model_type == "chat": + model_id = self.defaults.default_chat_model + elif model_type == "transformation": + model_id = ( + self.defaults.default_transformation_model + or self.defaults.default_chat_model + ) + elif model_type == "embedding": + model_id = self.defaults.default_embedding_model + elif model_type == "text_to_speech": + model_id = self.defaults.default_text_to_speech_model + elif model_type == "speech_to_text": + model_id = self.defaults.default_speech_to_text_model + elif model_type == "large_context": + model_id = self.defaults.large_context_model + + if not model_id: + raise ValueError(f"No default model configured for type: {model_type}") + + return self.get_model(model_id, **kwargs) + + def clear_cache(self): + """Clear the model cache""" + self._model_cache.clear() + + +model_manager = ModelManager() diff --git a/prompts/patterns/default/toc.jinja b/prompts/patterns/default/toc.jinja index c78f159..23b84f0 100644 --- a/prompts/patterns/default/toc.jinja +++ b/prompts/patterns/default/toc.jinja @@ -10,6 +10,6 @@ Analyze the provided content and create a Table of Contents: # INPUT -{{content}} +{{input_text}} # OUTPUT \ No newline at end of file From 8734b1803cf86a95da1bd708342c84cfc79b352b Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Fri, 1 Nov 2024 20:37:38 -0300 Subject: [PATCH 08/31] mm --- open_notebook/models/__init__.py | 33 -------------------------------- 1 file changed, 33 deletions(-) diff --git a/open_notebook/models/__init__.py b/open_notebook/models/__init__.py index 23623bb..5b4cdea 100644 --- a/open_notebook/models/__init__.py +++ b/open_notebook/models/__init__.py @@ -51,39 +51,6 @@ MODEL_CLASS_MAP = { } -# def get_model(model_id, **kwargs): -# """ -# Get a model instance based on model_id and type. - -# Args: -# model_id: The ID of the model to retrieve -# **kwargs: Additional arguments to pass to the model constructor -# """ -# assert model_id, "Model ID cannot be empty" -# model: Model = Model.get(model_id) - -# if not model: -# raise ValueError(f"Model with ID {model_id} not found") - -# if not model.type or model.type not in MODEL_CLASS_MAP: -# raise ValueError(f"Invalid model type: {model.type}") - -# provider_map = MODEL_CLASS_MAP[model.type] -# if model.provider not in provider_map: -# raise ValueError( -# f"Provider {model.provider} not compatible with {model.type} models" -# ) - -# model_class = provider_map[model.provider] -# model_instance = model_class(model_name=model.name, **kwargs) - -# # Special handling for language models that need langchain conversion -# if model.type == "language": -# return model_instance.to_langchain() - -# return model_instance - - class ModelManager: _instance = None _model_cache: Dict[str, object] = {} From 3b262a63f481b440a018f262800529fcf5c5f840 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Fri, 1 Nov 2024 21:11:23 -0300 Subject: [PATCH 09/31] better model mgmt --- open_notebook/graphs/chat.py | 18 ++++-------- open_notebook/graphs/multipattern.py | 8 +---- open_notebook/graphs/utils.py | 44 +++++++++++++++++----------- prompts/doc_query.jinja | 26 ---------------- prompts/recursive_toc.jinja | 24 --------------- prompts/summarize.jinja | 33 --------------------- 6 files changed, 34 insertions(+), 119 deletions(-) delete mode 100644 prompts/doc_query.jinja delete mode 100644 prompts/recursive_toc.jinja delete mode 100644 prompts/summarize.jinja diff --git a/open_notebook/graphs/chat.py b/open_notebook/graphs/chat.py index 5d75151..5d0939f 100644 --- a/open_notebook/graphs/chat.py +++ b/open_notebook/graphs/chat.py @@ -10,11 +10,9 @@ from langgraph.graph.message import add_messages from typing_extensions import TypedDict from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE -from open_notebook.domain.models import DefaultModels from open_notebook.domain.notebook import Notebook -from open_notebook.graphs.utils import run_pattern - -DEFAULT_MODELS = DefaultModels.load() +from open_notebook.graphs.utils import provision_model +from open_notebook.prompter import Prompter class ThreadState(TypedDict): @@ -25,15 +23,11 @@ class ThreadState(TypedDict): def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict: - model_id = config.get("configurable", {}).get( - "model_id", DEFAULT_MODELS.default_chat_model - ) - ai_message = run_pattern( - "chat", - model_id, - messages=state["messages"], - state=state, + system_prompt = Prompter(prompt_template="chat").render(data=state) + model = provision_model( + str(system_prompt) + str(state.get("messages", [])), config, "chat" ) + ai_message = model.invoke([system_prompt] + state.get("messages", [])) return {"messages": ai_message} diff --git a/open_notebook/graphs/multipattern.py b/open_notebook/graphs/multipattern.py index 17febca..9b95638 100644 --- a/open_notebook/graphs/multipattern.py +++ b/open_notebook/graphs/multipattern.py @@ -7,11 +7,8 @@ from langchain_core.runnables import ( from langgraph.graph import END, START, StateGraph from typing_extensions import Annotated, TypedDict -from open_notebook.domain.models import DefaultModels from open_notebook.graphs.utils import run_pattern -DEFAULT_MODELS = DefaultModels.load() - class PatternChainState(TypedDict): content_stack: Annotated[Sequence[str], operator.add] @@ -20,9 +17,6 @@ class PatternChainState(TypedDict): def call_model(state: dict, config: RunnableConfig) -> dict: - model_id = config.get("configurable", {}).get( - "model_id", DEFAULT_MODELS.default_transformation_model - ) patterns = state["patterns"] current_transformation = patterns.pop(0) if current_transformation.startswith("patterns/"): @@ -36,7 +30,7 @@ def call_model(state: dict, config: RunnableConfig) -> dict: transformation_result = run_pattern( pattern_name=current_transformation, - model_id=model_id, + config=config, state=input_args, ) return { diff --git a/open_notebook/graphs/utils.py b/open_notebook/graphs/utils.py index 4b4a896..ab78147 100644 --- a/open_notebook/graphs/utils.py +++ b/open_notebook/graphs/utils.py @@ -1,39 +1,48 @@ from langchain.output_parsers import OutputFixingParser +from langchain_core.messages import AIMessage from loguru import logger -from open_notebook.domain.models import DefaultModels from open_notebook.models import model_manager from open_notebook.prompter import Prompter from open_notebook.utils import token_count +def provision_model(content, config, default_type): + """ + Returns the best model to use based on the context size and on whether there is a specific model being requested in Config. + If context > 105_000, returns the large_context_model + If model_id is specified in Config, returns that model + Otherwise, returns the default model for the given type + """ + tokens = token_count(content) + + if tokens > 105_000: + logger.debug( + f"Using large context model because the content has {tokens} tokens" + ) + return model_manager.get_default_model("large_context") + elif config.get("configurable", {}).get("model_id"): + return model_manager.get_model(config.get("configurable", {}).get("model_id")) + else: + return model_manager.get_default_model(default_type) + + +# todo: turn into a graph def run_pattern( pattern_name: str, - model_id=None, + config, messages=[], state: dict = {}, parser=None, output_fixing_model_id=None, -) -> dict: +) -> AIMessage: system_prompt = Prompter(prompt_template=pattern_name, parser=parser).render( data=state ) - DEFAULT_MODELS = DefaultModels.load() - tokens = token_count(str(system_prompt) + str(messages)) - - if tokens > 105_000: - model_id = DEFAULT_MODELS.large_context_model - logger.debug( - f"Using large context model ({model_id}) because the content has {tokens} tokens" - ) - - model_id = ( - model_id - or DEFAULT_MODELS.default_transformation_model - or DEFAULT_MODELS.default_chat_model + chain = provision_model( + str(system_prompt) + str(messages), config, "transformation" ) - chain = model_manager.get_default_model("transformation") if parser: chain = chain | parser @@ -44,6 +53,7 @@ def run_pattern( llm=output_fix_model, ) + # todo: precisa deste if? if len(messages) > 0: response = chain.invoke([system_prompt] + messages) else: diff --git a/prompts/doc_query.jinja b/prompts/doc_query.jinja deleted file mode 100644 index 4212bff..0000000 --- a/prompts/doc_query.jinja +++ /dev/null @@ -1,26 +0,0 @@ - -# BACKGROUND - -Your are a cognitive assistant that helps me study and research. - -# OUR WORKING FRAMEWORK - -You have access to some information about the project I am working on -as well as the content of a specific item I am interested about. - -Your goal is to respond to the question using purely the content in your CONTEXT. - -If the content in CONTEXT is not enough to answer the question, do not make up any information and just reply that you can't answer that. -Kindly tell the user what sort of things you'd be able to talk about. - -# PROJECT INFO - -{{ notebook }} - -# CONTENT - -{{ doc_content }} - -# QUESTION - -{{ question}} \ No newline at end of file diff --git a/prompts/recursive_toc.jinja b/prompts/recursive_toc.jinja deleted file mode 100644 index b92512b..0000000 --- a/prompts/recursive_toc.jinja +++ /dev/null @@ -1,24 +0,0 @@ - -# SYSTEM ROLE -You are a content analysis assistant that reads through documents and provides a Table of Contents (ToC) to help users identify what the document covers more easily. -Your ToC should capture all major topics and transitions in the content and should mention them in the order theh appear. - -# TASK -Analyze the provided content and create a Table of Contents: -- Captures the core topics included in the text -- Gives a small description of what is covered - -# INSTRUCTIONS FOR LARGE DOCUMENTS - -If you see a PREVIOUS TOC section below, it means that this request is a continuation of a previous request. Most likely to handle context length issues. -Every time, you should replace the previous toc with the new one, and append the new content to the previous content. - -{% if toc %} -# PREVIOUS TOC - -{{toc}} -{% endif %} - -# CONTENT - -{{content}} diff --git a/prompts/summarize.jinja b/prompts/summarize.jinja deleted file mode 100644 index f8b65ab..0000000 --- a/prompts/summarize.jinja +++ /dev/null @@ -1,33 +0,0 @@ - -# SYSTEM ROLE -You are a content summarization assistant that creates dense, information-rich summaries optimized for machine understanding. Your summaries should capture key concepts with minimal words while maintaining complete, clear sentences. - -# TASK -Analyze the provided content and create a summary that: -- Captures the core concepts and key information -- Uses clear, direct language -- Maintains context from any previous summaries -- Includes relevant topics/tags -- Creates an appropriate title - -# OUTPUT SCHEMA -{'summary': {'type': 'string'}, - 'topics': {'items': {'type': 'string'}, 'type': 'array'}, - 'title': {'type': 'string'}} - -# OUTPUT EXAMPLE -{ - "title": "The title of the content", - "topics": ["topic1", "topic2"], - "summary": "The summary of the content" -} - -# CONTENT - -{{content}} - -{% if summary %} -# PREVIOUS SUMMARY - -{{summary}} -{% endif %} From fcd883f393c9c528db0d5e1d8d28e807b4bdc710 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Fri, 1 Nov 2024 21:32:13 -0300 Subject: [PATCH 10/31] add common checks to all pages --- app_home.py | 42 +--------------------------------------- pages/2_πŸ“’_Notebooks.py | 4 ++-- pages/3_πŸ”_Search.py | 8 ++++---- pages/5_πŸŽ™οΈ_Podcasts.py | 4 ++-- pages/7_βš™οΈ_Settings.py | 4 ++-- pages/8_πŸ›_Playground.py | 4 ++-- 6 files changed, 13 insertions(+), 53 deletions(-) diff --git a/app_home.py b/app_home.py index 9840c6c..4badc36 100644 --- a/app_home.py +++ b/app_home.py @@ -1,43 +1,3 @@ import streamlit as st -from open_notebook.database.migrate import MigrationManager - -# from open_notebook.config import DEFAULT_MODELS -from open_notebook.domain.models import DefaultModels -from stream_app.utils import version_sidebar - -default_models = DefaultModels.load() - -version_sidebar() -mm = MigrationManager() -if mm.needs_migration: - st.warning("The Open Notebook database needs a migration to run properly.") - if st.button("Run Migration"): - mm.run_migration_up() - st.success("Migration successful") - st.rerun() -elif ( - not default_models.default_chat_model - or not default_models.default_transformation_model -): - st.warning( - "You don't have default chat and transformation models selected. Please, select them on the settings page." - ) -elif not default_models.default_embedding_model: - st.warning( - "You don't have a default embedding model selected. Vector search will not be possible and your assistant will be less able to answer your queries. Please, select one on the settings page." - ) -elif not default_models.default_speech_to_text_model: - st.warning( - "You don't have a default speech to text model selected. Your assistant will not be able to transcribe audio. Please, select one on the settings page." - ) -elif not default_models.default_text_to_speech_model: - st.warning( - "You don't have a default text to speech model selected. Your assistant will not be able to generate audio and podcasts. Please, select one on the settings page." - ) -elif not default_models.large_context_model: - st.warning( - "You don't have a large context model selected. Your assistant will not be able to process large documents. Please, select one on the settings page." - ) -else: - st.switch_page("pages/2_πŸ“’_Notebooks.py") +st.switch_page("pages/2_πŸ“’_Notebooks.py") diff --git a/pages/2_πŸ“’_Notebooks.py b/pages/2_πŸ“’_Notebooks.py index e6302a3..2193425 100644 --- a/pages/2_πŸ“’_Notebooks.py +++ b/pages/2_πŸ“’_Notebooks.py @@ -5,13 +5,13 @@ from open_notebook.domain.notebook import Notebook from stream_app.chat import chat_sidebar from stream_app.note import add_note, note_card from stream_app.source import add_source, source_card -from stream_app.utils import setup_stream_state, version_sidebar +from stream_app.utils import page_commons, setup_stream_state st.set_page_config( layout="wide", page_title="πŸ“’ Open Notebook", initial_sidebar_state="expanded" ) -version_sidebar() +page_commons() def notebook_header(current_notebook): diff --git a/pages/3_πŸ”_Search.py b/pages/3_πŸ”_Search.py index cdfa4fc..b092af3 100644 --- a/pages/3_πŸ”_Search.py +++ b/pages/3_πŸ”_Search.py @@ -1,17 +1,17 @@ import streamlit as st -from open_notebook.config import load_default_models from open_notebook.domain.notebook import text_search, vector_search +from open_notebook.models import model_manager from stream_app.note import note_list_item from stream_app.source import source_list_item -from stream_app.utils import version_sidebar +from stream_app.utils import page_commons st.set_page_config( layout="wide", page_title="πŸ” Search", initial_sidebar_state="expanded" ) -version_sidebar() +page_commons() -DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models() +EMBEDDING_MODEL = model_manager.get_default_model("embedding") # search_tab, ask_tab = st.tabs(["Search", "Ask"]) # notebooks = Notebook.get_all() diff --git a/pages/5_πŸŽ™οΈ_Podcasts.py b/pages/5_πŸŽ™οΈ_Podcasts.py index db1273c..baec9c0 100644 --- a/pages/5_πŸŽ™οΈ_Podcasts.py +++ b/pages/5_πŸŽ™οΈ_Podcasts.py @@ -12,13 +12,13 @@ from open_notebook.plugins.podcasts import ( engagement_techniques, participant_roles, ) -from stream_app.utils import version_sidebar +from stream_app.utils import page_commons st.set_page_config( layout="wide", page_title="πŸŽ™οΈ Podcasts", initial_sidebar_state="expanded" ) -version_sidebar() +page_commons() text_to_speech_models = Model.get_models_by_type("text_to_speech") diff --git a/pages/7_βš™οΈ_Settings.py b/pages/7_βš™οΈ_Settings.py index 209814c..6bfc899 100644 --- a/pages/7_βš™οΈ_Settings.py +++ b/pages/7_βš™οΈ_Settings.py @@ -4,12 +4,12 @@ import streamlit as st from open_notebook.domain.models import DefaultModels, Model from open_notebook.models import MODEL_CLASS_MAP -from stream_app.utils import version_sidebar +from stream_app.utils import page_commons st.set_page_config( layout="wide", page_title="βš™οΈ Settings", initial_sidebar_state="expanded" ) -version_sidebar() +page_commons() st.title("βš™οΈ Settings") diff --git a/pages/8_πŸ›_Playground.py b/pages/8_πŸ›_Playground.py index 53de8f9..b7151a5 100644 --- a/pages/8_πŸ›_Playground.py +++ b/pages/8_πŸ›_Playground.py @@ -3,12 +3,12 @@ import yaml from open_notebook.domain.models import Model from open_notebook.graphs.multipattern import graph as pattern_graph -from stream_app.utils import version_sidebar +from stream_app.utils import page_commons st.set_page_config( layout="wide", page_title="πŸ› Playground", initial_sidebar_state="expanded" ) -version_sidebar() +page_commons() st.title("πŸ› Playground") with open("transformations.yaml", "r") as file: From 15048b08392b4da3489bb4e69cbe1ad1730e9a7b Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Fri, 1 Nov 2024 21:32:40 -0300 Subject: [PATCH 11/31] simplify model provisioning --- open_notebook/graphs/chat.py | 7 ++--- open_notebook/graphs/utils.py | 11 ++----- stream_app/utils.py | 54 +++++++++++++++++++++++++++++++++-- 3 files changed, 58 insertions(+), 14 deletions(-) diff --git a/open_notebook/graphs/chat.py b/open_notebook/graphs/chat.py index 5d0939f..c87af21 100644 --- a/open_notebook/graphs/chat.py +++ b/open_notebook/graphs/chat.py @@ -24,10 +24,9 @@ class ThreadState(TypedDict): def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict: system_prompt = Prompter(prompt_template="chat").render(data=state) - model = provision_model( - str(system_prompt) + str(state.get("messages", [])), config, "chat" - ) - ai_message = model.invoke([system_prompt] + state.get("messages", [])) + payload = [system_prompt] + state.get("messages", []) + model = provision_model(str(payload), config, "chat") + ai_message = model.invoke(payload, []) return {"messages": ai_message} diff --git a/open_notebook/graphs/utils.py b/open_notebook/graphs/utils.py index ab78147..5d8339c 100644 --- a/open_notebook/graphs/utils.py +++ b/open_notebook/graphs/utils.py @@ -39,9 +39,8 @@ def run_pattern( system_prompt = Prompter(prompt_template=pattern_name, parser=parser).render( data=state ) - chain = provision_model( - str(system_prompt) + str(messages), config, "transformation" - ) + payload = [system_prompt] + messages + chain = provision_model(str(payload), config, "transformation") if parser: chain = chain | parser @@ -53,10 +52,6 @@ def run_pattern( llm=output_fix_model, ) - # todo: precisa deste if? - if len(messages) > 0: - response = chain.invoke([system_prompt] + messages) - else: - response = chain.invoke(system_prompt) + response = chain.invoke(payload) return response diff --git a/stream_app/utils.py b/stream_app/utils.py index 55d9db1..911305c 100644 --- a/stream_app/utils.py +++ b/stream_app/utils.py @@ -1,6 +1,8 @@ import streamlit as st +from open_notebook.database.migrate import MigrationManager from open_notebook.graphs.chat import ThreadState, graph +from open_notebook.models import model_manager from open_notebook.utils import ( compare_versions, get_installed_version, @@ -40,9 +42,57 @@ def setup_stream_state(session_id) -> None: existing_state = graph.get_state({"configurable": {"thread_id": session_id}}).values if len(existing_state.keys()) == 0: st.session_state[session_id] = ThreadState( - messages=[], context=None, notebook=None, context_config={}, response=None + messages=[], context=None, notebook=None, context_config={} ) else: st.session_state[session_id] = existing_state st.session_state["active_session"] = session_id - st.session_state["active_session"] = session_id + + +def check_migration(): + mm = MigrationManager() + if mm.needs_migration: + st.warning("The Open Notebook database needs a migration to run properly.") + if st.button("Run Migration"): + mm.run_migration_up() + st.success("Migration successful") + st.rerun() + st.stop() + + +def check_models(): + default_models = model_manager.defaults + if ( + not default_models.default_chat_model + or not default_models.default_transformation_model + ): + st.warning( + "You don't have default chat and transformation models selected. Please, select them on the settings page." + ) + st.stop() + elif not default_models.default_embedding_model: + st.warning( + "You don't have a default embedding model selected. Vector search will not be possible and your assistant will be less able to answer your queries. Please, select one on the settings page." + ) + st.stop() + elif not default_models.default_speech_to_text_model: + st.warning( + "You don't have a default speech to text model selected. Your assistant will not be able to transcribe audio. Please, select one on the settings page." + ) + st.stop() + elif not default_models.default_text_to_speech_model: + st.warning( + "You don't have a default text to speech model selected. Your assistant will not be able to generate audio and podcasts. Please, select one on the settings page." + ) + st.stop() + elif not default_models.large_context_model: + st.warning( + "You don't have a large context model selected. Your assistant will not be able to process large documents. Please, select one on the settings page." + ) + st.stop() + + +def page_commons(): + version_sidebar() + check_migration() + check_models() From c15982af3f1a3ce485b8f565805866e57b0c4b25 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Fri, 1 Nov 2024 21:33:07 -0300 Subject: [PATCH 12/31] cleanup --- stream_app/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/stream_app/utils.py b/stream_app/utils.py index 911305c..fc3800c 100644 --- a/stream_app/utils.py +++ b/stream_app/utils.py @@ -13,9 +13,7 @@ from open_notebook.utils import ( def version_sidebar(): with st.sidebar: try: - current_version = get_installed_version( - "open-notebook" - ) # Note the hyphen instead of underscore + current_version = get_installed_version("open-notebook") except Exception: # Fallback to reading directly from pyproject.toml import tomli From 3b36caceb9befac998057698793884fd0762df30 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Fri, 1 Nov 2024 21:40:28 -0300 Subject: [PATCH 13/31] typing fixes --- open_notebook/domain/notebook.py | 4 +++- open_notebook/models/__init__.py | 20 +++++++++++++++----- pages/3_πŸ”_Search.py | 4 ++-- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/open_notebook/domain/notebook.py b/open_notebook/domain/notebook.py index c15e8e2..207f5ca 100644 --- a/open_notebook/domain/notebook.py +++ b/open_notebook/domain/notebook.py @@ -265,7 +265,9 @@ def text_search(keyword: str, results: int, source: bool = True, note: bool = Tr raise DatabaseOperationError("Failed to perform text search") -def vector_search(keyword: str, results: int, source: bool = True, note: bool = True): +def vector_search( + keyword: List[float], results: int, source: bool = True, note: bool = True +): if not keyword: raise InvalidInputError("Search keyword cannot be empty") try: diff --git a/open_notebook/models/__init__.py b/open_notebook/models/__init__.py index 5b4cdea..679f262 100644 --- a/open_notebook/models/__init__.py +++ b/open_notebook/models/__init__.py @@ -1,7 +1,8 @@ -from typing import Dict, Optional +from typing import Dict, Optional, Union from open_notebook.domain.models import DefaultModels, Model from open_notebook.models.embedding_models import ( + EmbeddingModel, GeminiEmbeddingModel, OllamaEmbeddingModel, OpenAIEmbeddingModel, @@ -10,6 +11,7 @@ from open_notebook.models.embedding_models import ( from open_notebook.models.llms import ( AnthropicLanguageModel, GeminiLanguageModel, + LanguageModel, LiteLLMLanguageModel, OllamaLanguageModel, OpenAILanguageModel, @@ -17,10 +19,14 @@ from open_notebook.models.llms import ( VertexAILanguageModel, VertexAnthropicLanguageModel, ) -from open_notebook.models.speech_to_text_models import OpenAISpeechToTextModel +from open_notebook.models.speech_to_text_models import ( + OpenAISpeechToTextModel, + SpeechToTextModel, +) from open_notebook.models.text_to_speech_models import ( ElevenLabsTextToSpeechModel, OpenAITextToSpeechModel, + TextToSpeechModel, ) # Unified model class map with type information @@ -77,7 +83,9 @@ class ModelManager: self.refresh_defaults() return self._default_models - def get_model(self, model_id: str, **kwargs) -> object: + def get_model( + self, model_id: str, **kwargs + ) -> Union[LanguageModel, EmbeddingModel, SpeechToTextModel, TextToSpeechModel]: """ Get a model instance based on model_id. Uses caching to avoid recreating instances. @@ -110,12 +118,14 @@ class ModelManager: # Special handling for language models that need langchain conversion if model.type == "language": - model_instance = model_instance.to_langchain() + model_instance = model_instance self._model_cache[cache_key] = model_instance return model_instance - def get_default_model(self, model_type: str, **kwargs) -> object: + def get_default_model( + self, model_type: str, **kwargs + ) -> Union[LanguageModel, EmbeddingModel, SpeechToTextModel, TextToSpeechModel]: """ Get the default model for a specific type. diff --git a/pages/3_πŸ”_Search.py b/pages/3_πŸ”_Search.py index b092af3..1d7dd4d 100644 --- a/pages/3_πŸ”_Search.py +++ b/pages/3_πŸ”_Search.py @@ -1,7 +1,7 @@ import streamlit as st from open_notebook.domain.notebook import text_search, vector_search -from open_notebook.models import model_manager +from open_notebook.models import EmbeddingModel, model_manager from stream_app.note import note_list_item from stream_app.source import source_list_item from stream_app.utils import page_commons @@ -11,7 +11,7 @@ st.set_page_config( ) page_commons() -EMBEDDING_MODEL = model_manager.get_default_model("embedding") +EMBEDDING_MODEL: EmbeddingModel = model_manager.get_default_model("embedding") # search_tab, ask_tab = st.tabs(["Search", "Ask"]) # notebooks = Notebook.get_all() From b616d1ad173bd90e5e3c214eab8e3a19fe263a63 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Fri, 1 Nov 2024 22:01:25 -0300 Subject: [PATCH 14/31] page location refactor --- pages/2_πŸ“’_Notebooks.py | 14 +++++------- pages/3_πŸ”_Search.py | 11 ++++------ pages/5_πŸŽ™οΈ_Podcasts.py | 8 ++----- pages/7_βš™οΈ_Settings.py | 7 ++---- pages/8_πŸ›_Playground.py | 7 ++---- {stream_app => pages/stream_app}/__init__.py | 0 {stream_app => pages/stream_app}/chat.py | 3 ++- {stream_app => pages/stream_app}/consts.py | 0 {stream_app => pages/stream_app}/note.py | 1 - {stream_app => pages/stream_app}/source.py | 0 {stream_app => pages/stream_app}/utils.py | 23 ++++++++++++++++++-- 11 files changed, 38 insertions(+), 36 deletions(-) rename {stream_app => pages/stream_app}/__init__.py (100%) rename {stream_app => pages/stream_app}/chat.py (98%) rename {stream_app => pages/stream_app}/consts.py (100%) rename {stream_app => pages/stream_app}/note.py (99%) rename {stream_app => pages/stream_app}/source.py (100%) rename {stream_app => pages/stream_app}/utils.py (85%) diff --git a/pages/2_πŸ“’_Notebooks.py b/pages/2_πŸ“’_Notebooks.py index 2193425..f9eb191 100644 --- a/pages/2_πŸ“’_Notebooks.py +++ b/pages/2_πŸ“’_Notebooks.py @@ -2,16 +2,12 @@ import streamlit as st from humanize import naturaltime from open_notebook.domain.notebook import Notebook -from stream_app.chat import chat_sidebar -from stream_app.note import add_note, note_card -from stream_app.source import add_source, source_card -from stream_app.utils import page_commons, setup_stream_state +from pages.stream_app.chat import chat_sidebar +from pages.stream_app.note import add_note, note_card +from pages.stream_app.source import add_source, source_card +from pages.stream_app.utils import setup_page, setup_stream_state -st.set_page_config( - layout="wide", page_title="πŸ“’ Open Notebook", initial_sidebar_state="expanded" -) - -page_commons() +setup_page("πŸ“’ Open Notebook") def notebook_header(current_notebook): diff --git a/pages/3_πŸ”_Search.py b/pages/3_πŸ”_Search.py index 1d7dd4d..c0a7b84 100644 --- a/pages/3_πŸ”_Search.py +++ b/pages/3_πŸ”_Search.py @@ -2,14 +2,11 @@ import streamlit as st from open_notebook.domain.notebook import text_search, vector_search from open_notebook.models import EmbeddingModel, model_manager -from stream_app.note import note_list_item -from stream_app.source import source_list_item -from stream_app.utils import page_commons +from pages.stream_app.note import note_list_item +from pages.stream_app.source import source_list_item +from pages.stream_app.utils import setup_page -st.set_page_config( - layout="wide", page_title="πŸ” Search", initial_sidebar_state="expanded" -) -page_commons() +setup_page("πŸ” Search") EMBEDDING_MODEL: EmbeddingModel = model_manager.get_default_model("embedding") diff --git a/pages/5_πŸŽ™οΈ_Podcasts.py b/pages/5_πŸŽ™οΈ_Podcasts.py index baec9c0..d15559e 100644 --- a/pages/5_πŸŽ™οΈ_Podcasts.py +++ b/pages/5_πŸŽ™οΈ_Podcasts.py @@ -12,13 +12,9 @@ from open_notebook.plugins.podcasts import ( engagement_techniques, participant_roles, ) -from stream_app.utils import page_commons +from pages.stream_app.utils import setup_page -st.set_page_config( - layout="wide", page_title="πŸŽ™οΈ Podcasts", initial_sidebar_state="expanded" -) - -page_commons() +setup_page("πŸŽ™οΈ Podcasts") text_to_speech_models = Model.get_models_by_type("text_to_speech") diff --git a/pages/7_βš™οΈ_Settings.py b/pages/7_βš™οΈ_Settings.py index 6bfc899..0d93594 100644 --- a/pages/7_βš™οΈ_Settings.py +++ b/pages/7_βš™οΈ_Settings.py @@ -4,12 +4,9 @@ import streamlit as st from open_notebook.domain.models import DefaultModels, Model from open_notebook.models import MODEL_CLASS_MAP -from stream_app.utils import page_commons +from pages.stream_app.utils import setup_page -st.set_page_config( - layout="wide", page_title="βš™οΈ Settings", initial_sidebar_state="expanded" -) -page_commons() +setup_page("βš™οΈ Settings") st.title("βš™οΈ Settings") diff --git a/pages/8_πŸ›_Playground.py b/pages/8_πŸ›_Playground.py index b7151a5..5bcac7a 100644 --- a/pages/8_πŸ›_Playground.py +++ b/pages/8_πŸ›_Playground.py @@ -3,12 +3,9 @@ import yaml from open_notebook.domain.models import Model from open_notebook.graphs.multipattern import graph as pattern_graph -from stream_app.utils import page_commons +from pages.stream_app.utils import setup_page -st.set_page_config( - layout="wide", page_title="πŸ› Playground", initial_sidebar_state="expanded" -) -page_commons() +setup_page("πŸ› Playground") st.title("πŸ› Playground") with open("transformations.yaml", "r") as file: diff --git a/stream_app/__init__.py b/pages/stream_app/__init__.py similarity index 100% rename from stream_app/__init__.py rename to pages/stream_app/__init__.py diff --git a/stream_app/chat.py b/pages/stream_app/chat.py similarity index 98% rename from stream_app/chat.py rename to pages/stream_app/chat.py index 00d8bd5..c3c2426 100644 --- a/stream_app/chat.py +++ b/pages/stream_app/chat.py @@ -5,7 +5,8 @@ from open_notebook.domain.notebook import Note, Source from open_notebook.graphs.chat import graph as chat_graph from open_notebook.plugins.podcasts import PodcastConfig from open_notebook.utils import token_count -from stream_app.note import make_note_from_chat + +from .note import make_note_from_chat # todo: build a smarter, more robust context manager function diff --git a/stream_app/consts.py b/pages/stream_app/consts.py similarity index 100% rename from stream_app/consts.py rename to pages/stream_app/consts.py diff --git a/stream_app/note.py b/pages/stream_app/note.py similarity index 99% rename from stream_app/note.py rename to pages/stream_app/note.py index 2d1f5a4..f22e29c 100644 --- a/stream_app/note.py +++ b/pages/stream_app/note.py @@ -94,7 +94,6 @@ def note_card(session_id, note): def note_list_item(note_id, score=None): - logger.debug(note_id) note: Note = Note.get(note_id) if note.note_type == "human": icon = "🀡" diff --git a/stream_app/source.py b/pages/stream_app/source.py similarity index 100% rename from stream_app/source.py rename to pages/stream_app/source.py diff --git a/stream_app/utils.py b/pages/stream_app/utils.py similarity index 85% rename from stream_app/utils.py rename to pages/stream_app/utils.py index fc3800c..69bf78e 100644 --- a/stream_app/utils.py +++ b/pages/stream_app/utils.py @@ -1,4 +1,5 @@ import streamlit as st +from loguru import logger from open_notebook.database.migrate import MigrationManager from open_notebook.graphs.chat import ThreadState, graph @@ -90,7 +91,25 @@ def check_models(): st.stop() -def page_commons(): - version_sidebar() +def handle_error(func): + """Decorator for consistent error handling""" + + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + logger.error(f"Error in {func.__name__}: {str(e)}") + logger.exception(e) + st.error(f"An error occurred: {str(e)}") + + return wrapper + + +def setup_page(title: str, layout="wide", sidebar_state="expanded"): + """Common page setup for all pages""" + st.set_page_config( + page_title=title, layout=layout, initial_sidebar_state=sidebar_state + ) check_migration() check_models() + version_sidebar() From 212d3a33b070e2c2da6b509216762bf0a11919e7 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Fri, 1 Nov 2024 22:29:59 -0300 Subject: [PATCH 15/31] improve object typing --- open_notebook/domain/base.py | 13 +- open_notebook/domain/models.py | 125 +++++++++++++++++- open_notebook/domain/notebook.py | 8 +- .../graphs/content_processing/audio.py | 3 +- open_notebook/graphs/tools.py | 15 +-- open_notebook/graphs/utils.py | 2 +- open_notebook/models/__init__.py | 125 +++--------------- 7 files changed, 151 insertions(+), 140 deletions(-) diff --git a/open_notebook/domain/base.py b/open_notebook/domain/base.py index 84a57c9..c33a292 100644 --- a/open_notebook/domain/base.py +++ b/open_notebook/domain/base.py @@ -55,7 +55,8 @@ class ObjectModel(BaseModel): result = repo_query(f"SELECT * FROM {id}") if result: return cls(**result[0]) - return None + else: + raise NotFoundError(f"{cls.table_name} with id {id} not found") except Exception as e: logger.error(f"Error fetching {cls.table_name} with id {id}: {str(e)}") logger.exception(e) @@ -68,12 +69,12 @@ class ObjectModel(BaseModel): return None def save(self) -> None: - from open_notebook.models import model_manager + from open_notebook.domain.models import model_manager + from open_notebook.models import EmbeddingModel - EMBEDDING_MODEL = model_manager.get_default_model("embedding") + EMBEDDING_MODEL: EmbeddingModel = model_manager.embedding_model try: - logger.debug(f"Validating {self.__class__.__name__}") self.model_validate(self.model_dump(), strict=True) data = self._prepare_save_data() data["updated"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S") @@ -90,7 +91,7 @@ class ObjectModel(BaseModel): else: data["created"] = ( self.created.strftime("%Y-%m-%d %H:%M:%S") - if type(self.created) == datetime + if isinstance(self.created, datetime) else self.created ) logger.debug(f"Updating record with id {self.id}") @@ -118,8 +119,6 @@ class ObjectModel(BaseModel): def _prepare_save_data(self) -> Dict[str, Any]: data = self.model_dump() - # del data["created"] - # del data["updated"] return {key: value for key, value in data.items() if value is not None} def delete(self) -> bool: diff --git a/open_notebook/domain/models.py b/open_notebook/domain/models.py index 94549b1..cf6c8f9 100644 --- a/open_notebook/domain/models.py +++ b/open_notebook/domain/models.py @@ -1,7 +1,15 @@ -from typing import ClassVar, Optional +from typing import ClassVar, Dict, Optional from open_notebook.database.repository import repo_query from open_notebook.domain.base import ObjectModel, RecordModel +from open_notebook.models import ( + MODEL_CLASS_MAP, + EmbeddingModel, + LanguageModel, + ModelType, + SpeechToTextModel, + TextToSpeechModel, +) class Model(ObjectModel): @@ -18,7 +26,6 @@ class Model(ObjectModel): return [Model(**model) for model in models] -# todo: future: colocar um cache aqui class DefaultModels(RecordModel): record_id: ClassVar[str] = "open_notebook:default_models" @@ -29,3 +36,117 @@ class DefaultModels(RecordModel): default_speech_to_text_model: Optional[str] = None # default_vision_model: Optional[str] = None default_embedding_model: Optional[str] = None + + +class ModelManager: + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(ModelManager, cls).__new__(cls) + return cls._instance + + def __init__(self): + if not hasattr(self, "_initialized"): + self._initialized = True + self._model_cache: Dict[str, ModelType] = {} + self._default_models = None + self.refresh_defaults() + + def get_model(self, model_id: str, **kwargs) -> ModelType: + cache_key = f"{model_id}:{str(kwargs)}" + + if cache_key in self._model_cache: + cached_model = self._model_cache[cache_key] + if not isinstance( + cached_model, + (LanguageModel, EmbeddingModel, SpeechToTextModel, TextToSpeechModel), + ): + raise TypeError( + f"Cached model is of unexpected type: {type(cached_model)}" + ) + return cached_model + + assert model_id, "Model ID cannot be empty" + model: Model = Model.get(model_id) + + if not model: + raise ValueError(f"Model with ID {model_id} not found") + + if not model.type or model.type not in MODEL_CLASS_MAP: + raise ValueError(f"Invalid model type: {model.type}") + + provider_map = MODEL_CLASS_MAP[model.type] + if model.provider not in provider_map: + raise ValueError( + f"Provider {model.provider} not compatible with {model.type} models" + ) + + model_class = provider_map[model.provider] + model_instance = model_class(model_name=model.name, **kwargs) + + # Special handling for language models that need langchain conversion + if model.type == "language": + model_instance = model_instance + + self._model_cache[cache_key] = model_instance + return model_instance + + def refresh_defaults(self): + """Refresh the default models from the database""" + self._default_models = DefaultModels.load() + + @property + def defaults(self) -> DefaultModels: + """Get the default models configuration""" + if not self._default_models: + self.refresh_defaults() + if not self._default_models: + raise RuntimeError("Failed to initialize default models configuration") + return self._default_models + + @property + def embedding_model(self, **kwargs) -> EmbeddingModel: + """Get the default embedding model""" + model = self.get_default_model("embedding", **kwargs) + if not isinstance(model, EmbeddingModel): + raise TypeError(f"Expected EmbeddingModel but got {type(model)}") + return model + + def get_default_model(self, model_type: str, **kwargs) -> ModelType: + """ + Get the default model for a specific type. + + Args: + model_type: The type of model to retrieve (e.g., 'chat', 'embedding', etc.) + **kwargs: Additional arguments to pass to the model constructor + """ + model_id = None + + if model_type == "chat": + model_id = self.defaults.default_chat_model + elif model_type == "transformation": + model_id = ( + self.defaults.default_transformation_model + or self.defaults.default_chat_model + ) + elif model_type == "embedding": + model_id = self.defaults.default_embedding_model + elif model_type == "text_to_speech": + model_id = self.defaults.default_text_to_speech_model + elif model_type == "speech_to_text": + model_id = self.defaults.default_speech_to_text_model + elif model_type == "large_context": + model_id = self.defaults.large_context_model + + if not model_id: + raise ValueError(f"No default model configured for type: {model_type}") + + return self.get_model(model_id, **kwargs) + + def clear_cache(self): + """Clear the model cache""" + self._model_cache.clear() + + +model_manager = ModelManager() diff --git a/open_notebook/domain/notebook.py b/open_notebook/domain/notebook.py index 207f5ca..1eb2edd 100644 --- a/open_notebook/domain/notebook.py +++ b/open_notebook/domain/notebook.py @@ -9,13 +9,11 @@ from open_notebook.database.repository import ( repo_query, ) from open_notebook.domain.base import ObjectModel +from open_notebook.domain.models import model_manager from open_notebook.exceptions import ( DatabaseOperationError, InvalidInputError, ) - -# from temp.recursive_toc import graph as toc_graph -from open_notebook.models import model_manager from open_notebook.utils import split_text, surreal_clean @@ -139,7 +137,7 @@ class Source(ObjectModel): raise DatabaseOperationError(e) def vectorize(self) -> None: - EMBEDDING_MODEL = model_manager.get_default_model("embedding") + EMBEDDING_MODEL = model_manager.embedding_model try: if not self.full_text: @@ -190,7 +188,7 @@ class Source(ObjectModel): raise DatabaseOperationError("Failed to search sources") def add_insight(self, insight_type: str, content: str) -> Any: - EMBEDDING_MODEL = model_manager.get_default_model("embedding") + EMBEDDING_MODEL = model_manager.embedding_model if not insight_type or not content: raise InvalidInputError("Insight type and content must be provided") diff --git a/open_notebook/graphs/content_processing/audio.py b/open_notebook/graphs/content_processing/audio.py index ac81481..be8c441 100644 --- a/open_notebook/graphs/content_processing/audio.py +++ b/open_notebook/graphs/content_processing/audio.py @@ -4,9 +4,10 @@ from math import ceil from loguru import logger from pydub import AudioSegment +from open_notebook.domain.models import model_manager from open_notebook.graphs.content_processing.state import SourceState -from open_notebook.models import model_manager +# todo: remove reference to model_manager # future: parallelize the transcription process diff --git a/open_notebook/graphs/tools.py b/open_notebook/graphs/tools.py index 96aeacc..9c3df13 100644 --- a/open_notebook/graphs/tools.py +++ b/open_notebook/graphs/tools.py @@ -3,6 +3,7 @@ from datetime import datetime from langchain.tools import tool +# todo: turn this into a system prompt variable @tool def get_current_timestamp() -> str: """ @@ -10,17 +11,3 @@ def get_current_timestamp() -> str: Returns the current timestamp in the format YYYYMMDDHHmmss. """ return datetime.now().strftime("%Y%m%d%H%M%S") - - -# @tool -# def doc_query(doc_id: str, question: str): -# """ -# name: doc_query -# Use this tool if you need to investigate into a particular document. -# Another LLM will read the document and answer the question that you might have. -# Use this when the user question cannot be answered with the content you have in context. -# """ -# from temp.doc_query import graph - -# result = graph.invoke({"doc_id": doc_id, "question": question}) -# return result["answer"] diff --git a/open_notebook/graphs/utils.py b/open_notebook/graphs/utils.py index 5d8339c..d67d5a6 100644 --- a/open_notebook/graphs/utils.py +++ b/open_notebook/graphs/utils.py @@ -2,7 +2,7 @@ from langchain.output_parsers import OutputFixingParser from langchain_core.messages import AIMessage from loguru import logger -from open_notebook.models import model_manager +from open_notebook.domain.models import model_manager from open_notebook.prompter import Prompter from open_notebook.utils import token_count diff --git a/open_notebook/models/__init__.py b/open_notebook/models/__init__.py index 679f262..7e95d9c 100644 --- a/open_notebook/models/__init__.py +++ b/open_notebook/models/__init__.py @@ -1,6 +1,5 @@ -from typing import Dict, Optional, Union +from typing import Dict, Type, Union -from open_notebook.domain.models import DefaultModels, Model from open_notebook.models.embedding_models import ( EmbeddingModel, GeminiEmbeddingModel, @@ -29,8 +28,12 @@ from open_notebook.models.text_to_speech_models import ( TextToSpeechModel, ) -# Unified model class map with type information -MODEL_CLASS_MAP = { +ModelType = Union[LanguageModel, EmbeddingModel, SpeechToTextModel, TextToSpeechModel] + + +ProviderMap = Dict[str, Type[ModelType]] + +MODEL_CLASS_MAP: Dict[str, ProviderMap] = { "language": { "ollama": OllamaLanguageModel, "openrouter": OpenRouterLanguageModel, @@ -56,109 +59,11 @@ MODEL_CLASS_MAP = { }, } - -class ModelManager: - _instance = None - _model_cache: Dict[str, object] = {} - _default_models: Optional[DefaultModels] = None - - def __new__(cls): - if cls._instance is None: - cls._instance = super(ModelManager, cls).__new__(cls) - return cls._instance - - def __init__(self): - if not hasattr(self, "_initialized"): - self._initialized = True - self.refresh_defaults() - - def refresh_defaults(self): - """Refresh the default models from the database""" - self._default_models = DefaultModels.load() - - @property - def defaults(self) -> DefaultModels: - """Get the default models configuration""" - if not self._default_models: - self.refresh_defaults() - return self._default_models - - def get_model( - self, model_id: str, **kwargs - ) -> Union[LanguageModel, EmbeddingModel, SpeechToTextModel, TextToSpeechModel]: - """ - Get a model instance based on model_id. Uses caching to avoid recreating instances. - - Args: - model_id: The ID of the model to retrieve - **kwargs: Additional arguments to pass to the model constructor - """ - cache_key = f"{model_id}:{str(kwargs)}" - - if cache_key in self._model_cache: - return self._model_cache[cache_key] - - assert model_id, "Model ID cannot be empty" - model: Model = Model.get(model_id) - - if not model: - raise ValueError(f"Model with ID {model_id} not found") - - if not model.type or model.type not in MODEL_CLASS_MAP: - raise ValueError(f"Invalid model type: {model.type}") - - provider_map = MODEL_CLASS_MAP[model.type] - if model.provider not in provider_map: - raise ValueError( - f"Provider {model.provider} not compatible with {model.type} models" - ) - - model_class = provider_map[model.provider] - model_instance = model_class(model_name=model.name, **kwargs) - - # Special handling for language models that need langchain conversion - if model.type == "language": - model_instance = model_instance - - self._model_cache[cache_key] = model_instance - return model_instance - - def get_default_model( - self, model_type: str, **kwargs - ) -> Union[LanguageModel, EmbeddingModel, SpeechToTextModel, TextToSpeechModel]: - """ - Get the default model for a specific type. - - Args: - model_type: The type of model to retrieve (e.g., 'chat', 'embedding', etc.) - **kwargs: Additional arguments to pass to the model constructor - """ - model_id = None - - if model_type == "chat": - model_id = self.defaults.default_chat_model - elif model_type == "transformation": - model_id = ( - self.defaults.default_transformation_model - or self.defaults.default_chat_model - ) - elif model_type == "embedding": - model_id = self.defaults.default_embedding_model - elif model_type == "text_to_speech": - model_id = self.defaults.default_text_to_speech_model - elif model_type == "speech_to_text": - model_id = self.defaults.default_speech_to_text_model - elif model_type == "large_context": - model_id = self.defaults.large_context_model - - if not model_id: - raise ValueError(f"No default model configured for type: {model_type}") - - return self.get_model(model_id, **kwargs) - - def clear_cache(self): - """Clear the model cache""" - self._model_cache.clear() - - -model_manager = ModelManager() +__all__ = [ + "MODEL_CLASS_MAP", + "EmbeddingModel", + "LanguageModel", + "SpeechToTextModel", + "TextToSpeechModel", + "ModelType", +] From 4f586ad513b34704a919c8c6e3743852e171b04e Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Fri, 1 Nov 2024 22:32:03 -0300 Subject: [PATCH 16/31] model_manager import --- pages/3_πŸ”_Search.py | 3 ++- pages/stream_app/utils.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pages/3_πŸ”_Search.py b/pages/3_πŸ”_Search.py index c0a7b84..a9abad5 100644 --- a/pages/3_πŸ”_Search.py +++ b/pages/3_πŸ”_Search.py @@ -1,7 +1,8 @@ import streamlit as st +from open_notebook.domain.models import model_manager from open_notebook.domain.notebook import text_search, vector_search -from open_notebook.models import EmbeddingModel, model_manager +from open_notebook.models import EmbeddingModel from pages.stream_app.note import note_list_item from pages.stream_app.source import source_list_item from pages.stream_app.utils import setup_page diff --git a/pages/stream_app/utils.py b/pages/stream_app/utils.py index 69bf78e..88f8849 100644 --- a/pages/stream_app/utils.py +++ b/pages/stream_app/utils.py @@ -2,8 +2,8 @@ import streamlit as st from loguru import logger from open_notebook.database.migrate import MigrationManager +from open_notebook.domain.models import model_manager from open_notebook.graphs.chat import ThreadState, graph -from open_notebook.models import model_manager from open_notebook.utils import ( compare_versions, get_installed_version, From 223f1bdaf5b4aa32b6de181632128cd2756a3c7e Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Fri, 1 Nov 2024 22:38:21 -0300 Subject: [PATCH 17/31] improve default_models --- open_notebook/domain/models.py | 16 ++++++++++++++++ open_notebook/graphs/content_processing/audio.py | 2 +- open_notebook/graphs/utils.py | 4 ++-- pages/3_πŸ”_Search.py | 3 +-- 4 files changed, 20 insertions(+), 5 deletions(-) diff --git a/open_notebook/domain/models.py b/open_notebook/domain/models.py index cf6c8f9..d66ad19 100644 --- a/open_notebook/domain/models.py +++ b/open_notebook/domain/models.py @@ -105,6 +105,22 @@ class ModelManager: raise RuntimeError("Failed to initialize default models configuration") return self._default_models + @property + def speech_to_text(self, **kwargs) -> SpeechToTextModel: + """Get the default speech-to-text model""" + model = self.get_default_model("speech_to_text", **kwargs) + if not isinstance(model, SpeechToTextModel): + raise TypeError(f"Expected SpeechToTextModel but got {type(model)}") + return model + + @property + def text_to_speech(self, **kwargs) -> TextToSpeechModel: + """Get the default text-to-speech model""" + model = self.get_default_model("text_to_speech", **kwargs) + if not isinstance(model, TextToSpeechModel): + raise TypeError(f"Expected TextToSpeechModel but got {type(model)}") + return model + @property def embedding_model(self, **kwargs) -> EmbeddingModel: """Get the default embedding model""" diff --git a/open_notebook/graphs/content_processing/audio.py b/open_notebook/graphs/content_processing/audio.py index be8c441..3f99277 100644 --- a/open_notebook/graphs/content_processing/audio.py +++ b/open_notebook/graphs/content_processing/audio.py @@ -73,7 +73,7 @@ def split_audio(input_file, segment_length_minutes=15, output_prefix=None): def extract_audio(data: SourceState): - SPEECH_TO_TEXT_MODEL = model_manager.get_default_model("speech_to_text") + SPEECH_TO_TEXT_MODEL = model_manager.speech_to_text input_audio_path = data.get("file_path") audio_files = [] diff --git a/open_notebook/graphs/utils.py b/open_notebook/graphs/utils.py index d67d5a6..9429bca 100644 --- a/open_notebook/graphs/utils.py +++ b/open_notebook/graphs/utils.py @@ -20,11 +20,11 @@ def provision_model(content, config, default_type): logger.debug( f"Using large context model because the content has {tokens} tokens" ) - return model_manager.get_default_model("large_context") + return model_manager.get_default_model("large_context").to_langchain() elif config.get("configurable", {}).get("model_id"): return model_manager.get_model(config.get("configurable", {}).get("model_id")) else: - return model_manager.get_default_model(default_type) + return model_manager.get_default_model(default_type).to_langchain() # todo: turn into a graph diff --git a/pages/3_πŸ”_Search.py b/pages/3_πŸ”_Search.py index a9abad5..4438af5 100644 --- a/pages/3_πŸ”_Search.py +++ b/pages/3_πŸ”_Search.py @@ -2,14 +2,13 @@ import streamlit as st from open_notebook.domain.models import model_manager from open_notebook.domain.notebook import text_search, vector_search -from open_notebook.models import EmbeddingModel from pages.stream_app.note import note_list_item from pages.stream_app.source import source_list_item from pages.stream_app.utils import setup_page setup_page("πŸ” Search") -EMBEDDING_MODEL: EmbeddingModel = model_manager.get_default_model("embedding") +EMBEDDING_MODEL = model_manager.embedding_model # search_tab, ask_tab = st.tabs(["Search", "Ask"]) # notebooks = Notebook.get_all() From 7dc37a3ac75e1d6337c8d14522af89324f9c6bbb Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Fri, 1 Nov 2024 22:43:33 -0300 Subject: [PATCH 18/31] model fixes --- open_notebook/graphs/chat.py | 2 +- open_notebook/graphs/utils.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/open_notebook/graphs/chat.py b/open_notebook/graphs/chat.py index c87af21..0c0e517 100644 --- a/open_notebook/graphs/chat.py +++ b/open_notebook/graphs/chat.py @@ -26,7 +26,7 @@ def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict system_prompt = Prompter(prompt_template="chat").render(data=state) payload = [system_prompt] + state.get("messages", []) model = provision_model(str(payload), config, "chat") - ai_message = model.invoke(payload, []) + ai_message = model.invoke(payload) return {"messages": ai_message} diff --git a/open_notebook/graphs/utils.py b/open_notebook/graphs/utils.py index 9429bca..3f84f3b 100644 --- a/open_notebook/graphs/utils.py +++ b/open_notebook/graphs/utils.py @@ -22,7 +22,9 @@ def provision_model(content, config, default_type): ) return model_manager.get_default_model("large_context").to_langchain() elif config.get("configurable", {}).get("model_id"): - return model_manager.get_model(config.get("configurable", {}).get("model_id")) + return model_manager.get_model( + config.get("configurable", {}).get("model_id") + ).to_langchain() else: return model_manager.get_default_model(default_type).to_langchain() From d9c0c93debacc18d9d45969542638b9dac3e71ac Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Fri, 1 Nov 2024 22:50:27 -0300 Subject: [PATCH 19/31] improved typing --- open_notebook/graphs/chat.py | 4 ++-- open_notebook/graphs/utils.py | 33 ++++++++++++--------------------- open_notebook/utils.py | 10 +++++----- 3 files changed, 19 insertions(+), 28 deletions(-) diff --git a/open_notebook/graphs/chat.py b/open_notebook/graphs/chat.py index 0c0e517..e403ace 100644 --- a/open_notebook/graphs/chat.py +++ b/open_notebook/graphs/chat.py @@ -11,7 +11,7 @@ from typing_extensions import TypedDict from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE from open_notebook.domain.notebook import Notebook -from open_notebook.graphs.utils import provision_model +from open_notebook.graphs.utils import provision_langchain_model from open_notebook.prompter import Prompter @@ -25,7 +25,7 @@ class ThreadState(TypedDict): def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict: system_prompt = Prompter(prompt_template="chat").render(data=state) payload = [system_prompt] + state.get("messages", []) - model = provision_model(str(payload), config, "chat") + model = provision_langchain_model(str(payload), config, "chat") ai_message = model.invoke(payload) return {"messages": ai_message} diff --git a/open_notebook/graphs/utils.py b/open_notebook/graphs/utils.py index 3f84f3b..a3874e8 100644 --- a/open_notebook/graphs/utils.py +++ b/open_notebook/graphs/utils.py @@ -1,13 +1,14 @@ -from langchain.output_parsers import OutputFixingParser -from langchain_core.messages import AIMessage +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import BaseMessage from loguru import logger from open_notebook.domain.models import model_manager +from open_notebook.models.llms import LanguageModel from open_notebook.prompter import Prompter from open_notebook.utils import token_count -def provision_model(content, config, default_type): +def provision_langchain_model(content, config, default_type) -> BaseChatModel: """ Returns the best model to use based on the context size and on whether there is a specific model being requested in Config. If context > 105_000, returns the large_context_model @@ -20,13 +21,14 @@ def provision_model(content, config, default_type): logger.debug( f"Using large context model because the content has {tokens} tokens" ) - return model_manager.get_default_model("large_context").to_langchain() + model = model_manager.get_default_model("large_context") elif config.get("configurable", {}).get("model_id"): - return model_manager.get_model( - config.get("configurable", {}).get("model_id") - ).to_langchain() + model = model_manager.get_model(config.get("configurable", {}).get("model_id")) else: - return model_manager.get_default_model(default_type).to_langchain() + model = model_manager.get_default_model(default_type) + + assert isinstance(model, LanguageModel), f"Model is not a LanguageModel: {model}" + return model.to_langchain() # todo: turn into a graph @@ -36,23 +38,12 @@ def run_pattern( messages=[], state: dict = {}, parser=None, - output_fixing_model_id=None, -) -> AIMessage: +) -> BaseMessage: system_prompt = Prompter(prompt_template=pattern_name, parser=parser).render( data=state ) payload = [system_prompt] + messages - chain = provision_model(str(payload), config, "transformation") - - if parser: - chain = chain | parser - - if output_fixing_model_id and parser: - output_fix_model = model_manager.get_model(output_fixing_model_id) - chain = chain | OutputFixingParser.from_llm( - parser=parser, - llm=output_fix_model, - ) + chain = provision_langchain_model(str(payload), config, "transformation") response = chain.invoke(payload) diff --git a/open_notebook/utils.py b/open_notebook/utils.py index 340762e..86479e2 100644 --- a/open_notebook/utils.py +++ b/open_notebook/utils.py @@ -28,7 +28,7 @@ def split_text(txt: str, chunk=1000, overlap=0, separator=" "): return text_splitter.split_text(txt) -def token_count(input_string): +def token_count(input_string) -> int: """ Count the number of tokens in the input string using the 'o200k_base' encoding. @@ -46,7 +46,7 @@ def token_count(input_string): return token_count -def token_cost(token_count, cost_per_million=0.150): +def token_cost(token_count, cost_per_million=0.150) -> float: """ Calculate the cost of tokens based on the token count and cost per million tokens. @@ -60,11 +60,11 @@ def token_cost(token_count, cost_per_million=0.150): return cost_per_million * (token_count / 1_000_000) -def remove_non_ascii(text): +def remove_non_ascii(text) -> str: return re.sub(r"[^\x00-\x7F]+", "", text) -def remove_non_printable(text): +def remove_non_printable(text) -> str: # Remove control characters, except newlines and tabs text = "".join( char for char in text if unicodedata.category(char)[0] != "C" or char in "\n\t" @@ -74,7 +74,7 @@ def remove_non_printable(text): return re.sub(r"[^\w\s.,!?\-\n\t]", "", text, flags=re.UNICODE) -def surreal_clean(text): +def surreal_clean(text) -> str: """ Clean the input text by removing non-ASCII and non-printable characters, and adjusting colon placement for SurrealDB compatibility. From b4ba3ef4c819da252d32f0fc26bf0d4cf0410c57 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Mon, 4 Nov 2024 09:49:11 -0300 Subject: [PATCH 20/31] change model provisioning strategy --- open_notebook/graphs/chat.py | 2 +- open_notebook/graphs/utils.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/open_notebook/graphs/chat.py b/open_notebook/graphs/chat.py index e403ace..5e3b4ca 100644 --- a/open_notebook/graphs/chat.py +++ b/open_notebook/graphs/chat.py @@ -25,7 +25,7 @@ class ThreadState(TypedDict): def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict: system_prompt = Prompter(prompt_template="chat").render(data=state) payload = [system_prompt] + state.get("messages", []) - model = provision_langchain_model(str(payload), config, "chat") + model = provision_langchain_model(str(payload), config, "chat", max_tokens=2000) ai_message = model.invoke(payload) return {"messages": ai_message} diff --git a/open_notebook/graphs/utils.py b/open_notebook/graphs/utils.py index a3874e8..07365ea 100644 --- a/open_notebook/graphs/utils.py +++ b/open_notebook/graphs/utils.py @@ -8,7 +8,7 @@ from open_notebook.prompter import Prompter from open_notebook.utils import token_count -def provision_langchain_model(content, config, default_type) -> BaseChatModel: +def provision_langchain_model(content, config, default_type, **kwargs) -> BaseChatModel: """ Returns the best model to use based on the context size and on whether there is a specific model being requested in Config. If context > 105_000, returns the large_context_model @@ -21,11 +21,13 @@ def provision_langchain_model(content, config, default_type) -> BaseChatModel: logger.debug( f"Using large context model because the content has {tokens} tokens" ) - model = model_manager.get_default_model("large_context") + model = model_manager.get_default_model("large_context", **kwargs) elif config.get("configurable", {}).get("model_id"): - model = model_manager.get_model(config.get("configurable", {}).get("model_id")) + model = model_manager.get_model( + config.get("configurable", {}).get("model_id"), **kwargs + ) else: - model = model_manager.get_default_model(default_type) + model = model_manager.get_default_model(default_type, **kwargs) assert isinstance(model, LanguageModel), f"Model is not a LanguageModel: {model}" return model.to_langchain() From 1f23ba549077f4189ddb00f65cec62648c5c1462 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Mon, 4 Nov 2024 09:49:45 -0300 Subject: [PATCH 21/31] add autoload to recordmodel --- open_notebook/domain/base.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/open_notebook/domain/base.py b/open_notebook/domain/base.py index c33a292..146589d 100644 --- a/open_notebook/domain/base.py +++ b/open_notebook/domain/base.py @@ -154,17 +154,21 @@ class ObjectModel(BaseModel): class RecordModel(BaseModel): - record_id: ClassVar[str] = "open_notebook:default_models" + record_id: ClassVar[str] - @classmethod - def load(cls): - result = repo_query(f"SELECT * FROM {cls.record_id};") + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.load() + + def load(self): + result = repo_query(f"SELECT * FROM {self.record_id};") if result: result = result[0] - dm = cls(**result) - return dm - return cls() + for key, value in result.items(): + if hasattr(self, key): + setattr(self, key, value) + return self - @classmethod def update(self, data): repo_update(self.record_id, data) + return self.load() From 62cd5a9dfb892b1f60c0aa07dc0f718e8026f0d7 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Mon, 4 Nov 2024 09:50:15 -0300 Subject: [PATCH 22/31] add concept of tools model --- open_notebook/domain/models.py | 7 ++++++- pages/7_βš™οΈ_Settings.py | 18 ++++++++++++++---- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/open_notebook/domain/models.py b/open_notebook/domain/models.py index d66ad19..9cb84ca 100644 --- a/open_notebook/domain/models.py +++ b/open_notebook/domain/models.py @@ -36,6 +36,7 @@ class DefaultModels(RecordModel): default_speech_to_text_model: Optional[str] = None # default_vision_model: Optional[str] = None default_embedding_model: Optional[str] = None + default_tools_model: Optional[str] = None class ModelManager: @@ -94,7 +95,7 @@ class ModelManager: def refresh_defaults(self): """Refresh the default models from the database""" - self._default_models = DefaultModels.load() + self._default_models = DefaultModels() @property def defaults(self) -> DefaultModels: @@ -146,6 +147,10 @@ class ModelManager: self.defaults.default_transformation_model or self.defaults.default_chat_model ) + elif model_type == "tools": + model_id = ( + self.defaults.default_tools_model or self.defaults.default_chat_model + ) elif model_type == "embedding": model_id = self.defaults.default_embedding_model elif model_type == "text_to_speech": diff --git a/pages/7_βš™οΈ_Settings.py b/pages/7_βš™οΈ_Settings.py index 0d93594..aad2132 100644 --- a/pages/7_βš™οΈ_Settings.py +++ b/pages/7_βš™οΈ_Settings.py @@ -2,7 +2,7 @@ import os import streamlit as st -from open_notebook.domain.models import DefaultModels, Model +from open_notebook.domain.models import DefaultModels, Model, model_manager from open_notebook.models import MODEL_CLASS_MAP from pages.stream_app.utils import setup_page @@ -118,7 +118,7 @@ def get_selected_index(models, model_id, default=0): with model_defaults_tab: - default_models = DefaultModels.load().model_dump() + default_models = DefaultModels().model_dump() all_models = Model.get_all() text_generation_models = [model for model in all_models if model.type == "language"] @@ -154,7 +154,16 @@ with model_defaults_tab: text_generation_models, default_models.get("default_transformation_model") ), ) - st.caption("You can override this model on individual transformations") + st.divider() + defs["default_tools_model"] = st.selectbox( + "Default Tools Model", + text_generation_models, + format_func=lambda x: x.name, + help="This model will be used for calling tools. Currently, it's best to use Open AI and Anthropic for this.", + index=get_selected_index( + text_generation_models, default_models.get("default_tools_model") + ), + ) st.divider() defs["large_context_model"] = st.selectbox( "Large Context Model", @@ -216,4 +225,5 @@ with model_defaults_tab: for k, v in defs.items(): if v: defs[k] = v.id - DefaultModels.update(defs) + DefaultModels().update(defs) + model_manager.refresh_defaults() From 8398539df86118ded48fdab7ab1ef49b1e6ea9b1 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Mon, 4 Nov 2024 09:51:20 -0300 Subject: [PATCH 23/31] add hybrid search to combine text and vector searches --- open_notebook/domain/notebook.py | 145 +++++++++++++++++++++++++++++-- 1 file changed, 138 insertions(+), 7 deletions(-) diff --git a/open_notebook/domain/notebook.py b/open_notebook/domain/notebook.py index 1eb2edd..ee2035e 100644 --- a/open_notebook/domain/notebook.py +++ b/open_notebook/domain/notebook.py @@ -260,23 +260,154 @@ def text_search(keyword: str, results: int, source: bool = True, note: bool = Tr except Exception as e: logger.error(f"Error performing text search: {str(e)}") logger.exception(e) - raise DatabaseOperationError("Failed to perform text search") + raise DatabaseOperationError(e) -def vector_search( - keyword: List[float], results: int, source: bool = True, note: bool = True -): +# def hybrid_search( +# keyword_search: List[str], +# embed_search: List[str], +# results: int = 50, +# source: bool = True, +# note: bool = True, +# ): +# EMBEDDING_MODEL = model_manager.embedding_model +# embed1_vector = EMBEDDING_MODEL.embed(embed1) if embed1 else None + + +# todo: mover o embedding pra ca +def vector_search(keyword: str, results: int, source: bool = True, note: bool = True): if not keyword: raise InvalidInputError("Search keyword cannot be empty") try: + EMBEDDING_MODEL = model_manager.embedding_model + embed = EMBEDDING_MODEL.embed(keyword) results = repo_query( """ - SELECT * FROM fn::vector_search($keyword, $results, $source, $note); + SELECT * FROM fn::vector_search($embed, $results, $source, $note, 0.15); """, - {"keyword": keyword, "results": results, "source": source, "note": note}, + {"embed": embed, "results": results, "source": source, "note": note}, ) return results except Exception as e: logger.error(f"Error performing vector search: {str(e)}") logger.exception(e) - raise DatabaseOperationError("Failed to perform vector search") + raise DatabaseOperationError(e) + + +def hybrid_search( + keyword_search: List[str], + embed_search: List[str], + results: int = 50, + source: bool = True, + note: bool = True, + max_chunks_per_doc: int = 3, + min_results_per_query: int = 3, +) -> Dict[str, List[Dict]]: + if not keyword_search and not embed_search: + raise InvalidInputError("At least one search term required") + + # Process keyword searches + all_keyword_results = {} # Dictionary to store results per keyword + for keyword in keyword_search: + try: + search_results = text_search(keyword, results, source, note) + # Sort results by relevance + sorted_results = sorted( + search_results, key=lambda x: x.get("relevance", 0), reverse=True + ) + # Group by parent_id and limit chunks per document + seen_parent_ids = {} + filtered_results = [] + for result in sorted_results: + parent_id = result["parent_id"] + if parent_id not in seen_parent_ids: + seen_parent_ids[parent_id] = 1 + filtered_results.append(result) + elif seen_parent_ids[parent_id] < max_chunks_per_doc: + seen_parent_ids[parent_id] += 1 + filtered_results.append(result) + all_keyword_results[keyword] = filtered_results + except Exception as e: + logger.warning(f"Error in keyword search for term '{keyword}': {str(e)}") + continue + + # Ensure minimum results from each keyword query + keyword_results = [] + remaining_slots = results + + # First pass: add minimum results from each query + for keyword, query_results in all_keyword_results.items(): + keyword_results.extend(query_results[:min_results_per_query]) + remaining_slots -= min(len(query_results), min_results_per_query) + + # Second pass: fill remaining slots with best results + all_remaining = [] + for keyword, query_results in all_keyword_results.items(): + all_remaining.extend(query_results[min_results_per_query:]) + + # Sort remaining by relevance and add until we hit the limit + all_remaining = sorted( + all_remaining, key=lambda x: x.get("relevance", 0), reverse=True + ) + seen_ids = {r["id"] for r in keyword_results} + for result in all_remaining: + if remaining_slots <= 0: + break + if result["id"] not in seen_ids: + keyword_results.append(result) + seen_ids.add(result["id"]) + remaining_slots -= 1 + + # Process vector searches with the same approach + all_vector_results = {} # Dictionary to store results per embedding + for embed in embed_search: + try: + search_results = vector_search(embed, results, source, note) + # Sort results by similarity + sorted_results = sorted( + search_results, key=lambda x: x.get("similarity", 0), reverse=True + ) + # Group by parent_id and limit chunks per document + seen_parent_ids = {} + filtered_results = [] + for result in sorted_results: + parent_id = result["parent_id"] + if parent_id not in seen_parent_ids: + seen_parent_ids[parent_id] = 1 + filtered_results.append(result) + elif seen_parent_ids[parent_id] < max_chunks_per_doc: + seen_parent_ids[parent_id] += 1 + filtered_results.append(result) + all_vector_results[embed] = filtered_results + except Exception as e: + logger.warning(f"Error in vector search for term '{embed}': {str(e)}") + continue + + # Ensure minimum results from each vector query + vector_results = [] + remaining_slots = results + + # First pass: add minimum results from each query + for embed, query_results in all_vector_results.items(): + vector_results.extend(query_results[:min_results_per_query]) + remaining_slots -= min(len(query_results), min_results_per_query) + + # Second pass: fill remaining slots with best results + all_remaining = [] + for embed, query_results in all_vector_results.items(): + all_remaining.extend(query_results[min_results_per_query:]) + + # Sort remaining by similarity and add until we hit the limit + all_remaining = sorted( + all_remaining, key=lambda x: x.get("similarity", 0), reverse=True + ) + seen_ids = {r["id"] for r in vector_results} + for result in all_remaining: + if remaining_slots <= 0: + break + if result["id"] not in seen_ids: + vector_results.append(result) + seen_ids.add(result["id"]) + remaining_slots -= 1 + + return {"keyword_results": keyword_results, "vector_results": vector_results} From 56e745d668a5aea9accc557e751e3b8afe7552f3 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Mon, 4 Nov 2024 09:53:09 -0300 Subject: [PATCH 24/31] improve search functions --- migrations/3.surrealql | 139 ++++++++++++++++++++++++++++++ migrations/3_down.surrealql | 105 ++++++++++++++++++++++ open_notebook/database/migrate.py | 2 + 3 files changed, 246 insertions(+) create mode 100644 migrations/3.surrealql create mode 100644 migrations/3_down.surrealql diff --git a/migrations/3.surrealql b/migrations/3.surrealql new file mode 100644 index 0000000..73b79a7 --- /dev/null +++ b/migrations/3.surrealql @@ -0,0 +1,139 @@ +REMOVE FUNCTION fn::vector_search; + +DEFINE FUNCTION IF NOT EXISTS fn::vector_search($query: array, $match_count: int, $sources: bool, $show_notes: bool, $min_similarity: float) { + let $source_embedding_search = + IF $sources {( + SELECT + id, + source.title as title, + content, + source.id as parent_id, + vector::similarity::cosine(embedding, $query) as similarity + FROM source_embedding + WHERE vector::similarity::cosine(embedding, $query) >= $min_similarity + ORDER BY similarity DESC + LIMIT $match_count + )} + ELSE { [] }; + + -- Busca em source_insight com threshold + let $source_insight_search = + IF $sources {( + SELECT + id, + insight_type + ' - ' + source.title as title, + content, + source.id as parent_id, + vector::similarity::cosine(embedding, $query) as similarity + FROM source_insight + WHERE vector::similarity::cosine(embedding, $query) >= $min_similarity + ORDER BY similarity DESC + LIMIT $match_count + )} + ELSE { [] }; + + + let $note_content_search = + IF $show_notes {( + SELECT + id, + title, + content, + id as parent_id, + vector::similarity::cosine(embedding, $query) as similarity + FROM note + WHERE vector::similarity::cosine(embedding, $query) >= $min_similarity + ORDER BY similarity DESC + LIMIT $match_count + )} + ELSE { [] }; + + + let $all_results = array::union( + array::union($source_embedding_search, $source_insight_search), + $note_content_search + ); + + + RETURN ( + SELECT + id, title, content, parent_id, + math::max(similarity) as similarity + FROM $all_results + GROUP BY id + ORDER BY similarity DESC + LIMIT $match_count + ); +}; + + +REMOVE FUNCTION fn::text_search; + + + DEFINE FUNCTION IF NOT EXISTS fn::text_search($query_text: string, $match_count: int, $sources:bool, $show_notes:bool) { + + let $source_title_search = + IF $sources {( + SELECT id, title, + search::highlight('`', '`', 1) as content, + id as parent_id, + math::max(search::score(1)) AS relevance + FROM source + WHERE title @1@ $query_text + GROUP BY id)} + ELSE { [] }; + + let $source_embedding_search = + IF $sources {( + SELECT id as id, source.title as title, search::highlight('`', '`', 1) as content, source.id as parent_id, math::max(search::score(1)) AS relevance + FROM source_embedding + WHERE content @1@ $query_text + GROUP BY id)} + ELSE { [] }; + + let $source_full_search = + IF $sources {( + SELECT source.id as id, source.title as title, search::highlight('`', '`', 1) as content, source.id as parent_id, math::max(search::score(1)) AS relevance + FROM source + WHERE full_text @1@ $query_text + GROUP BY id)} + ELSE { [] }; + + let $source_insight_search = + IF $sources {( + SELECT id, insight_type + " - " + source.title as title, search::highlight('`', '`', 1) as content, source.id as parent_id, math::max(search::score(1)) AS relevance + FROM source_insight + WHERE content @1@ $query_text + GROUP BY id)} + ELSE { [] }; + + let $note_title_search = + IF $show_notes {( + SELECT id, title, search::highlight('`', '`', 1) as content, id as parent_id, math::max(search::score(1)) AS relevance + FROM note + WHERE title @1@ $query_text + GROUP BY id)} + ELSE { [] }; + + let $note_content_search = + IF $show_notes {( + SELECT id, title, search::highlight('`', '`', 1) as content, id as parent_id, math::max(search::score(1)) AS relevance + FROM note + WHERE content @1@ $query_text + GROUP BY id)} + ELSE { [] }; + + let $source_chunk_results = array::union($source_embedding_search, $source_full_search); + + let $source_asset_results = array::union($source_title_search, $source_insight_search); + + let $source_results = array::union($source_chunk_results, $source_asset_results ); + let $note_results = array::union($note_title_search, $note_content_search ); + let $final_results = array::union($source_results, $note_results ); + + RETURN (SELECT id, title, content, parent_id, math::max(relevance) as relevance from $final_results + where id is not None +group by id, title, content, parent_id ORDER BY relevance DESC LIMIT $match_count); + + +}; diff --git a/migrations/3_down.surrealql b/migrations/3_down.surrealql new file mode 100644 index 0000000..aaab4d9 --- /dev/null +++ b/migrations/3_down.surrealql @@ -0,0 +1,105 @@ +REMOVE FUNCTION fn::vector_search; + + +DEFINE FUNCTION IF NOT EXISTS fn::vector_search($query: array, $match_count: int, $sources:bool, $show_notes:bool) { + + let $source_embedding_search = + IF $sources {( + SELECT source as item_id, content, vector::similarity::cosine(embedding, $query) as similarity + FROM source_embedding LIMIT $match_count)} + ELSE { [] }; + + + let $source_insight_search = + IF $sources {( + SELECT source as item_id, content, vector::similarity::cosine(embedding, $query) as similarity + FROM source_insight LIMIT $match_count)} + ELSE { [] }; + + + let $note_content_search = + IF $show_notes {( + SELECT id as item_id, content, vector::similarity::cosine(embedding, $query) as similarity + FROM note LIMIT $match_count)} + + ELSE { [] }; + + let $source_chunk_results = array::union($source_embedding_search, $source_insight_search); + + let $source_results = array::union($source_chunk_results, $source_insight_search); + + let $note_results = $note_content_search; + let $final_results = array::union($source_results, $note_results ); + + RETURN (SELECT item_id, math::max(similarity) as similarity from $final_results + group by item_id ORDER BY similarity DESC LIMIT $match_count); + + +}; + +REMOVE FUNCTION fn::text_search; + + +DEFINE FUNCTION IF NOT EXISTS fn::text_search($query_text: string, $match_count: int, $sources:bool, $show_notes:bool) { + + let $source_title_search = + IF $sources {( + SELECT id as item_id, math::max(search::score(1)) AS relevance + FROM source + WHERE title @1@ $query_text + GROUP BY item_id)} + ELSE { [] }; + + let $source_embedding_search = + IF $sources {( + SELECT source as item_id, math::max(search::score(1)) AS relevance + FROM source_embedding + WHERE content @1@ $query_text + GROUP BY item_id)} + ELSE { [] }; + + let $source_full_search = + IF $sources {( + SELECT source as item_id, math::max(search::score(1)) AS relevance + FROM source + WHERE full_text @1@ $query_text + GROUP BY item_id)} + ELSE { [] }; + + let $source_insight_search = + IF $sources {( + SELECT source as item_id, math::max(search::score(1)) AS relevance + FROM source_insight + WHERE content @1@ $query_text + GROUP BY item_id)} + ELSE { [] }; + + let $note_title_search = + IF $show_notes {( + SELECT id as item_id, math::max(search::score(1)) AS relevance + FROM note + WHERE title @1@ $query_text + GROUP BY item_id)} + ELSE { [] }; + + let $note_content_search = + IF $show_notes {( + SELECT id as item_id, math::max(search::score(1)) AS relevance + FROM note + WHERE content @1@ $query_text + GROUP BY item_id)} + ELSE { [] }; + + let $source_chunk_results = array::union($source_embedding_search, $source_full_search); + + let $source_asset_results = array::union($source_title_search, $source_insight_search); + + let $source_results = array::union($source_chunk_results, $source_asset_results ); + let $note_results = array::union($note_title_search, $note_content_search ); + let $final_results = array::union($source_results, $note_results ); + + RETURN (SELECT item_id, math::max(relevance) as relevance from $final_results + group by item_id ORDER BY relevance DESC LIMIT $match_count); + + +}; diff --git a/open_notebook/database/migrate.py b/open_notebook/database/migrate.py index f890091..085caf4 100644 --- a/open_notebook/database/migrate.py +++ b/open_notebook/database/migrate.py @@ -21,12 +21,14 @@ class MigrationManager: self.up_migrations = [ Migration.from_file("migrations/1.surrealql"), Migration.from_file("migrations/2.surrealql"), + Migration.from_file("migrations/3.surrealql"), ] self.down_migrations = [ Migration.from_file( "migrations/1_down.surrealql", ), Migration.from_file("migrations/2_down.surrealql"), + Migration.from_file("migrations/3_down.surrealql"), ] self.runner = MigrationRunner( up_migrations=self.up_migrations, From 418c67f69fcb3be54b39b55b05ef8150bee7f85f Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Mon, 4 Nov 2024 09:53:49 -0300 Subject: [PATCH 25/31] add search and rag functions in beta --- .../graphs/content_processing/__init__.py | 9 --- open_notebook/graphs/rag.py | 44 ++++++++++++ open_notebook/graphs/tools.py | 14 ++++ pages/3_πŸ”_Ask_and_Search.py | 67 +++++++++++++++++++ pages/3_πŸ”_Search.py | 66 ------------------ prompts/rag.jinja | 59 ++++++++++++++++ 6 files changed, 184 insertions(+), 75 deletions(-) create mode 100644 open_notebook/graphs/rag.py create mode 100644 pages/3_πŸ”_Ask_and_Search.py delete mode 100644 pages/3_πŸ”_Search.py create mode 100644 prompts/rag.jinja diff --git a/open_notebook/graphs/content_processing/__init__.py b/open_notebook/graphs/content_processing/__init__.py index 2c772dc..915da23 100644 --- a/open_notebook/graphs/content_processing/__init__.py +++ b/open_notebook/graphs/content_processing/__init__.py @@ -48,15 +48,6 @@ def file_type(state: SourceState): return return_dict -# def _get_title(url): -# """ -# Get the content of a URL -# """ -# response = extract_url(dict(url=url)) -# if "title" in response: -# return response["title"] - - def file_type_edge(data: SourceState): assert data.get("identified_type"), "Type not identified" identified_type = data["identified_type"] diff --git a/open_notebook/graphs/rag.py b/open_notebook/graphs/rag.py new file mode 100644 index 0000000..24dc435 --- /dev/null +++ b/open_notebook/graphs/rag.py @@ -0,0 +1,44 @@ +from typing import Annotated + +from langchain_core.runnables import ( + RunnableConfig, +) +from langgraph.graph import START, StateGraph +from langgraph.graph.message import add_messages +from langgraph.prebuilt import ToolNode, tools_condition +from typing_extensions import TypedDict + +from open_notebook.graphs.tools import repository_search +from open_notebook.graphs.utils import provision_langchain_model +from open_notebook.prompter import Prompter + +tools = [repository_search] +tool_node = ToolNode(tools) + + +class ThreadState(TypedDict): + messages: Annotated[list, add_messages] + # notebook: Optional[Notebook] + # context: Optional[str] + # context_config: Optional[dict] + + +def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict: + system_prompt = Prompter(prompt_template="rag").render(data=state) + payload = [system_prompt] + state.get("messages", []) + model = provision_langchain_model(str(payload), config, "tools", max_tokens=2000) + model = model.bind_tools(tools) + ai_message = model.invoke(payload) + return {"messages": ai_message} + + +agent_state = StateGraph(ThreadState) +agent_state.add_node("agent", call_model_with_messages) +agent_state.add_node("tools", tool_node) +agent_state.add_edge(START, "agent") +agent_state.add_conditional_edges( + "agent", + tools_condition, +) +agent_state.add_edge("tools", "agent") +graph = agent_state.compile() diff --git a/open_notebook/graphs/tools.py b/open_notebook/graphs/tools.py index 9c3df13..620fac4 100644 --- a/open_notebook/graphs/tools.py +++ b/open_notebook/graphs/tools.py @@ -1,7 +1,10 @@ from datetime import datetime +from typing import List from langchain.tools import tool +from open_notebook.domain.notebook import hybrid_search + # todo: turn this into a system prompt variable @tool @@ -11,3 +14,14 @@ def get_current_timestamp() -> str: Returns the current timestamp in the format YYYYMMDDHHmmss. """ return datetime.now().strftime("%Y%m%d%H%M%S") + + +@tool +def repository_search(keyword_searches: List[str], vector_searches: List[str]) -> str: + """ + name: repository_search + Makes a search in the content repository for the given query. + keyword_searches: List[str] - A list of search terms to search for using keyword search. + vector_searches: List[str] - A list of search terms to search for using vector search. + """ + return hybrid_search(keyword_searches, vector_searches, 20) diff --git a/pages/3_πŸ”_Ask_and_Search.py b/pages/3_πŸ”_Ask_and_Search.py new file mode 100644 index 0000000..d70a3e9 --- /dev/null +++ b/pages/3_πŸ”_Ask_and_Search.py @@ -0,0 +1,67 @@ +import streamlit as st + +from open_notebook.domain.models import Model +from open_notebook.domain.notebook import text_search, vector_search +from open_notebook.graphs.rag import graph as rag_graph +from pages.stream_app.utils import setup_page + +setup_page("πŸ” Search") + +ask_tab, search_tab = st.tabs(["Ask Your Knowledge Base (beta)", "Search"]) + +if "search_results" not in st.session_state: + st.session_state["search_results"] = [] + + +def results_card(item): + score = item.get("relevance", item.get("similarity", item.get("score", 0))) + with st.expander(f"[{score:.2f}] **{item['title']}**"): + st.markdown(f"**{item['content']}**") + st.write(item["id"]) + st.write(item["parent_id"]) + + +with ask_tab: + st.subheader("Ask Your Knowledge Base (beta)") + st.caption( + "The LLM will answer your query based on the documents in your knowledge base. " + ) + st.warning( + "This functionality requires the use of Tools and, at this moment, works well with Open AI and Anthropic models only." + ) + question = st.text_input("Question", "") + models = Model.get_models_by_type("language") + model: Model = st.selectbox("Model", models, format_func=lambda x: x.name) + if st.button("Ask"): + st.write(f"Searching for {question}") + messages = [question] + rag_results = rag_graph.invoke( + dict( + messages=messages + ), # config=dict(configurable=dict(model_id=model.id)) + ) + st.markdown(rag_results["messages"][-1].content) + with st.expander("Details (for debugging)"): + st.json(rag_results) + +with search_tab: + with st.container(border=True): + st.subheader("πŸ” Search") + st.caption("Search your knowledge base for specific keywords or concepts") + search_term = st.text_input("Search", "") + search_type = st.radio("Search Type", ["Text Search", "Vector Search"]) + search_sources = st.checkbox("Search Sources", value=True) + search_notes = st.checkbox("Search Notes", value=True) + if st.button("Search"): + if search_type == "Text Search": + st.write(f"Searching for {search_term}") + st.session_state["search_results"] = text_search( + search_term, 100, search_sources, search_notes + ) + elif search_type == "Vector Search": + st.write(f"Searching for {search_term}") + st.session_state["search_results"] = vector_search( + search_term, 100, search_sources, search_notes + ) + for item in st.session_state["search_results"]: + results_card(item) diff --git a/pages/3_πŸ”_Search.py b/pages/3_πŸ”_Search.py deleted file mode 100644 index 4438af5..0000000 --- a/pages/3_πŸ”_Search.py +++ /dev/null @@ -1,66 +0,0 @@ -import streamlit as st - -from open_notebook.domain.models import model_manager -from open_notebook.domain.notebook import text_search, vector_search -from pages.stream_app.note import note_list_item -from pages.stream_app.source import source_list_item -from pages.stream_app.utils import setup_page - -setup_page("πŸ” Search") - -EMBEDDING_MODEL = model_manager.embedding_model - -# search_tab, ask_tab = st.tabs(["Search", "Ask"]) -# notebooks = Notebook.get_all() - -if "search_results" not in st.session_state: - st.session_state["search_results"] = [] - -# with search_tab: -with st.container(border=True): - st.subheader("πŸ” Search") - st.caption("Search your knowledge base for specific keywords or concepts") - search_term = st.text_input("Search", "") - search_type = st.radio("Search Type", ["Text Search", "Vector Search"]) - search_sources = st.checkbox("Search Sources", value=True) - search_notes = st.checkbox("Search Notes", value=True) - if st.button("Search"): - if search_type == "Text Search": - st.write(f"Searching for {search_term}") - st.session_state["search_results"] = text_search( - search_term, 100, search_sources, search_notes - ) - elif search_type == "Vector Search": - st.write(f"Searching for {search_term}") - embed_query = EMBEDDING_MODEL.embed(search_term) - st.session_state["search_results"] = vector_search( - embed_query, 100, search_sources, search_notes - ) - for item in st.session_state["search_results"]: - score = item.get("relevance", item.get("similarity", 0)) - if item.get("item_id"): - if "source:" in item["item_id"]: - source_list_item(item["item_id"], score) - elif "note:" in item["item_id"]: - note_list_item(item["item_id"], score) - -# coming soon -# with ask_tab: -# with st.form(key="ask_form"): -# st.subheader("Ask Your Knowledge Base") -# st.caption("Let the LLM formulate an answer based on your query") -# question = st.text_input("Your question", "") - -# notebooks = st.multiselect( -# "Notebooks", -# notebooks, -# notebooks, -# format_func=lambda x: x.name, -# ) -# search_sources = st.multiselect( -# "Use Sources", -# ["Sources", "Notes"], -# ["Sources", "Notes"], -# ) -# if st.form_submit_button("Search"): -# st.write(f"Searching for {search_term}") diff --git a/prompts/rag.jinja b/prompts/rag.jinja new file mode 100644 index 0000000..3d8d057 --- /dev/null +++ b/prompts/rag.jinja @@ -0,0 +1,59 @@ +# SYSTEM ROLE +You are a cognitive study assistant that helps users research and learn by engaging in focused discussions about documents in their workspace. + +You have access to a search tool that you can use in order to reply to the user query. + +The tool accepts 2 arrays as parameters: + +- keyword_searches: List[str] - A list of search terms to search for using keyword search. +- vector_searches: List[str] - A list of search terms to search for using vector search. + +It's very important that your response contains references to the searched documents so the user can follow-up and read more about the topic. The way you do that is by adding the id of the specific document in between brackets like this: [document_id]. + +# EXAMPLE + +User: Can you tell me more about the concept of "Deep Learning"? + +Assistant: Deep learning is a subset of machine learning in artificial intelligence (AI) that enables networks to learn unsupervised from unstructured or unlabeled data. [note:iuiodadalknda]. It can also be categorized into three main types: supervised, unsupervised, and reinforcement learning. [insight:adadadadadadad]. + +Please note, "note:iuiodadalknda" and "insight:adadadadadadad" are examples of document IDs with different prefixes. You should not make up document IDs or copy the IDs from this example. You should use the IDs of the documents that you have access to through the search tool. + +# IMPORTANT + +- Do not make up documents or document ids. Only use the ids of the documents that you have access through the query you made. +- The ID is composed of the type of document and a random string, such as "source:randomstring", "note:randomstring", or "insight:randomstring". There are various types of documents, including notes, insights, and sources. **Always use the complete ID exactly as it is provided, including its type prefix. Do not add, remove, or modify any part of the ID.** +- Do not assume or change the type prefix of any document ID. If a document ID is "note:xyz", use it exactly as "note:xyz". Do not change it to "source:xyz" or any other variation. +- **Use document IDs exactly as they are returned from the search tool. Do not add any prefixes or modify them in any way.** + + +{# +You are a cognitive study assistant designed to help users research and learn by engaging in focused discussions about documents in their workspace. Your primary goal is to provide informative, accurate responses to user queries while properly citing relevant documents from the available search tool. + +To answer this question effectively, you have access to a search tool with the following parameters: +- keyword_searches: List[str] - A list of search terms for keyword search +- vector_searches: List[str] - A list of search terms for vector search + +Follow these steps to formulate your response: + +1. Analyze the user's question and determine appropriate search terms. +2. Use the search tool to find relevant information. +3. Carefully review the search results, paying close attention to document IDs and content relevance. +4. Compose a clear, informative response that directly addresses the user's question. +5. Include relevant document citations using the exact document IDs provided by the search tool. +6. Review your response for accuracy and relevance before delivering it to the user. + +Important guidelines: +- Always use the complete document ID as provided by the search tool, including its type prefix (e.g., "note:", "insight:", "source:"). +- Do not make up or modify document IDs in any way. +- Ensure that each citation is directly relevant to the information it supports. +- Prioritize accuracy and relevance in your search strategy and response composition. + +Before composing your final response, wrap your thought process in tags to analyze the question, plan your search strategy, and evaluate the search results. This will help ensure that you retrieve the most relevant information and use the correct document IDs in your citations. Include the following steps: +a. Analyze the question and identify key concepts +b. Plan search strategy (both keyword and vector searches) +c. Evaluate search results and note relevant document IDs +d. Outline the main points for the response + +Your final response should be conversational in tone, directly addressing the user's question while seamlessly incorporating document citations. Use square brackets with the full document ID for each citation, like this: [document_id]. + +Remember, the quality and accuracy of your response, including proper document citations, are crucial for helping the user in their research and learning process. #} \ No newline at end of file From 3be1ecae8a5111e345c33c45ab6aae5754470663 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Mon, 4 Nov 2024 15:06:13 -0300 Subject: [PATCH 26/31] improve text splitter --- open_notebook/utils.py | 60 ++++++++++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 20 deletions(-) diff --git a/open_notebook/utils.py b/open_notebook/utils.py index 86479e2..b25a79c 100644 --- a/open_notebook/utils.py +++ b/open_notebook/utils.py @@ -5,29 +5,10 @@ from urllib.parse import urlparse import requests import tomli -from langchain_text_splitters import CharacterTextSplitter +from langchain_text_splitters import RecursiveCharacterTextSplitter from packaging.version import parse as parse_version -def split_text(txt: str, chunk=1000, overlap=0, separator=" "): - """ - Split the input text into chunks. - - Args: - txt (str): The input text to be split. - chunk (int): The size of each chunk. Default is 1000. - overlap (int): The number of characters to overlap between chunks. Default is 0. - separator (str): The separator to use when splitting the text. Default is " ". - - Returns: - list: A list of text chunks. - """ - text_splitter = CharacterTextSplitter( - chunk_size=chunk, chunk_overlap=overlap, separator=separator - ) - return text_splitter.split_text(txt) - - def token_count(input_string) -> int: """ Count the number of tokens in the input string using the 'o200k_base' encoding. @@ -60,15 +41,54 @@ def token_cost(token_count, cost_per_million=0.150) -> float: return cost_per_million * (token_count / 1_000_000) +def split_text(txt: str, chunk_size=500): + """ + Split the input text into chunks. + + Args: + txt (str): The input text to be split. + chunk (int): The size of each chunk. Default is 1000. + overlap (int): The number of characters to overlap between chunks. Default is 0. + separator (str): The separator to use when splitting the text. Default is " ". + + Returns: + list: A list of text chunks. + """ + overlap = int(chunk_size * 0.15) + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=chunk_size, + chunk_overlap=overlap, + length_function=token_count, + separators=[ + "\n\n", + "\n", + ".", + ",", + " ", + "\u200b", # Zero-width space + "\uff0c", # Fullwidth comma + "\u3001", # Ideographic comma + "\uff0e", # Fullwidth full stop + "\u3002", # Ideographic full stop + "", + ], + ) + return text_splitter.split_text(txt) + + def remove_non_ascii(text) -> str: return re.sub(r"[^\x00-\x7F]+", "", text) def remove_non_printable(text) -> str: + # Replace any special Unicode whitespace characters with a regular space + text = re.sub(r"[\u2000-\u200B\u202F\u205F\u3000]", " ", text) + # Remove control characters, except newlines and tabs text = "".join( char for char in text if unicodedata.category(char)[0] != "C" or char in "\n\t" ) + # Replace non-breaking spaces with regular spaces text = text.replace("\xa0", " ").strip() # Keep letters (including accented ones), numbers, spaces, newlines, tabs, and basic punctuation return re.sub(r"[^\w\s.,!?\-\n\t]", "", text, flags=re.UNICODE) From 0f2216207be2bb5bd8eeeec0ad4e259228304cf1 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Mon, 4 Nov 2024 15:08:14 -0300 Subject: [PATCH 27/31] enable multiple chat sessions --- migrations/3.surrealql | 14 +++- migrations/3_down.surrealql | 5 ++ open_notebook/domain/notebook.py | 121 +++++++++++++++++-------------- pages/2_πŸ“’_Notebooks.py | 83 ++++++++++++--------- pages/stream_app/chat.py | 89 ++++++++++++++++++----- pages/stream_app/note.py | 26 +++---- pages/stream_app/source.py | 9 ++- pages/stream_app/utils.py | 62 ++++++++++++++-- 8 files changed, 276 insertions(+), 133 deletions(-) diff --git a/migrations/3.surrealql b/migrations/3.surrealql index 73b79a7..f2a067f 100644 --- a/migrations/3.surrealql +++ b/migrations/3.surrealql @@ -1,4 +1,11 @@ -REMOVE FUNCTION fn::vector_search; + +DEFINE TABLE IF NOT EXISTS chat_session SCHEMALESS; + +DEFINE TABLE IF NOT EXISTS refers_to +TYPE RELATION +FROM chat_session TO notebook; + +REMOVE FUNCTION IF EXISTS fn::vector_search; DEFINE FUNCTION IF NOT EXISTS fn::vector_search($query: array, $match_count: int, $sources: bool, $show_notes: bool, $min_similarity: float) { let $source_embedding_search = @@ -16,7 +23,6 @@ DEFINE FUNCTION IF NOT EXISTS fn::vector_search($query: array, $match_cou )} ELSE { [] }; - -- Busca em source_insight com threshold let $source_insight_search = IF $sources {( SELECT @@ -67,10 +73,10 @@ DEFINE FUNCTION IF NOT EXISTS fn::vector_search($query: array, $match_cou }; -REMOVE FUNCTION fn::text_search; +REMOVE FUNCTION IF EXISTS fn::text_search; - DEFINE FUNCTION IF NOT EXISTS fn::text_search($query_text: string, $match_count: int, $sources:bool, $show_notes:bool) { +DEFINE FUNCTION IF NOT EXISTS fn::text_search($query_text: string, $match_count: int, $sources:bool, $show_notes:bool) { let $source_title_search = IF $sources {( diff --git a/migrations/3_down.surrealql b/migrations/3_down.surrealql index aaab4d9..b8438e0 100644 --- a/migrations/3_down.surrealql +++ b/migrations/3_down.surrealql @@ -1,3 +1,8 @@ +REMOVE TABLE IF EXISTS chat_session; + +REMOVE TABLE IF EXISTS refers_to; + + REMOVE FUNCTION fn::vector_search; diff --git a/open_notebook/domain/notebook.py b/open_notebook/domain/notebook.py index ee2035e..86944f8 100644 --- a/open_notebook/domain/notebook.py +++ b/open_notebook/domain/notebook.py @@ -1,11 +1,9 @@ -import os from typing import Any, ClassVar, Dict, List, Literal, Optional from loguru import logger from pydantic import BaseModel, Field, field_validator from open_notebook.database.repository import ( - repo_create, repo_query, ) from open_notebook.domain.base import ObjectModel @@ -68,6 +66,27 @@ class Notebook(ObjectModel): logger.exception(e) raise DatabaseOperationError(e) + @property + def chat_sessions(self) -> List["ChatSession"]: + try: + srcs = repo_query(f""" + select * from ( + select + <- chat_session as chat_session + from refers_to + where out={self.id} + fetch chat_session + ) + order by chat_session.updated desc + """) + return ( + [ChatSession(**src["chat_session"][0]) for src in srcs] if srcs else [] + ) + except Exception as e: + logger.error(f"Error fetching notes for notebook {self.id}: {str(e)}") + logger.exception(e) + raise DatabaseOperationError(e) + class Asset(BaseModel): file_path: Optional[str] = None @@ -99,6 +118,22 @@ class Source(ObjectModel): else: return dict(id=self.id, title=self.title, insights=self.insights) + @property + def embedded_chunks(self) -> int: + try: + result = repo_query( + f""" + select count() as chunks from source_embedding where source={self.id} GROUP ALL + """ + ) + if len(result) == 0: + return 0 + return result[0]["chunks"] + except Exception as e: + logger.error(f"Error fetching insights for source {self.id}: {str(e)}") + logger.exception(e) + raise DatabaseOperationError(f"Failed to count chunks for source: {str(e)}") + @property def insights(self) -> List[SourceInsight]: try: @@ -118,24 +153,6 @@ class Source(ObjectModel): raise InvalidInputError("Notebook ID must be provided") return self.relate("reference", notebook_id) - def save_chunks(self, text: str) -> None: - if not text: - raise InvalidInputError("Text cannot be empty") - try: - chunks = split_text(text, chunk=500000, overlap=1000) - logger.debug(f"Split into {len(chunks)} chunks") - for i, chunk in enumerate(chunks): - logger.debug(f"Saving chunk {i}") - data = {"source": self.id, "order": i, "content": surreal_clean(chunk)} - repo_create( - "source_chunk", - data, - ) - except Exception as e: - logger.exception(e) - logger.error(f"Error saving chunks for source {self.id}: {str(e)}") - raise DatabaseOperationError(e) - def vectorize(self) -> None: EMBEDDING_MODEL = model_manager.embedding_model @@ -144,8 +161,6 @@ class Source(ObjectModel): return chunks = split_text( self.full_text, - chunk=int(os.environ.get("EMBEDDING_CHUNK_SIZE", 1000)), - overlap=int(os.environ.get("EMBEDDING_CHUNK_OVERLAP", 1000)), ) logger.debug(f"Split into {len(chunks)} chunks") @@ -166,26 +181,26 @@ class Source(ObjectModel): logger.exception(e) raise DatabaseOperationError(e) - @classmethod - def search(cls, query: str) -> List[Dict[str, Any]]: - if not query: - raise InvalidInputError("Search query cannot be empty") - try: - result = repo_query( - """ - SELECT * omit full_text - FROM source - WHERE string::lowercase(title) CONTAINS $query or title @@ $query - OR string::lowercase(summary) CONTAINS $query or summary @@ $query - OR string::lowercase(full_text) CONTAINS $query or full_text @@ $query - """, - {"query": query}, - ) - return result - except Exception as e: - logger.error(f"Error searching sources: {str(e)}") - logger.exception(e) - raise DatabaseOperationError("Failed to search sources") + # @classmethod + # def search(cls, query: str) -> List[Dict[str, Any]]: + # if not query: + # raise InvalidInputError("Search query cannot be empty") + # try: + # result = repo_query( + # """ + # SELECT * omit full_text + # FROM source + # WHERE string::lowercase(title) CONTAINS $query or title @@ $query + # OR string::lowercase(summary) CONTAINS $query or summary @@ $query + # OR string::lowercase(full_text) CONTAINS $query or full_text @@ $query + # """, + # {"query": query}, + # ) + # return result + # except Exception as e: + # logger.error(f"Error searching sources: {str(e)}") + # logger.exception(e) + # raise DatabaseOperationError("Failed to search sources") def add_insight(self, insight_type: str, content: str) -> Any: EMBEDDING_MODEL = model_manager.embedding_model @@ -246,6 +261,16 @@ class Note(ObjectModel): return self.content +class ChatSession(ObjectModel): + table_name: ClassVar[str] = "chat_session" + title: Optional[str] = None + + def relate_to_notebook(self, notebook_id: str) -> Any: + if not notebook_id: + raise InvalidInputError("Notebook ID must be provided") + return self.relate("refers_to", notebook_id) + + def text_search(keyword: str, results: int, source: bool = True, note: bool = True): if not keyword: raise InvalidInputError("Search keyword cannot be empty") @@ -263,18 +288,6 @@ def text_search(keyword: str, results: int, source: bool = True, note: bool = Tr raise DatabaseOperationError(e) -# def hybrid_search( -# keyword_search: List[str], -# embed_search: List[str], -# results: int = 50, -# source: bool = True, -# note: bool = True, -# ): -# EMBEDDING_MODEL = model_manager.embedding_model -# embed1_vector = EMBEDDING_MODEL.embed(embed1) if embed1 else None - - -# todo: mover o embedding pra ca def vector_search(keyword: str, results: int, source: bool = True, note: bool = True): if not keyword: raise InvalidInputError("Search keyword cannot be empty") diff --git a/pages/2_πŸ“’_Notebooks.py b/pages/2_πŸ“’_Notebooks.py index f9eb191..f1e3354 100644 --- a/pages/2_πŸ“’_Notebooks.py +++ b/pages/2_πŸ“’_Notebooks.py @@ -10,11 +10,14 @@ from pages.stream_app.utils import setup_page, setup_stream_state setup_page("πŸ“’ Open Notebook") -def notebook_header(current_notebook): +def notebook_header(current_notebook: Notebook): + """ + Defines the header of the notebook page, including the ability to edit the notebook's name and description. + """ c1, c2, c3 = st.columns([8, 2, 2]) c1.header(current_notebook.name) if c2.button("Back to the list", icon="πŸ”™"): - st.session_state["current_notebook"] = None + st.session_state["current_notebook_id"] = None st.rerun() if c3.button("Refresh", icon="πŸ”„"): @@ -49,20 +52,20 @@ def notebook_header(current_notebook): st.toast("Notebook unarchived", icon="πŸ—ƒοΈ") if c3.button("Delete forever", type="primary", icon="☠️"): current_notebook.delete() - st.session_state["current_notebook"] = None + st.session_state["current_notebook_id"] = None st.rerun() -def notebook_page(current_notebook_id): - current_notebook: Notebook = Notebook.get(current_notebook_id) - if not current_notebook: - st.error("Notebook not found") - return - if current_notebook_id not in st.session_state.keys(): - st.session_state[current_notebook_id] = current_notebook +def notebook_page(current_notebook: Notebook): + # Guarantees that we have an entry for this notebook in the session state + if current_notebook.id not in st.session_state: + st.session_state[current_notebook.id] = {"notebook": current_notebook} + + # sets up the active session + current_session = setup_stream_state( + current_notebook=current_notebook, + ) - session_id = st.session_state["active_session"] - st.session_state[session_id]["notebook"] = current_notebook sources = current_notebook.sources notes = current_notebook.notes @@ -74,18 +77,18 @@ def notebook_page(current_notebook_id): with sources_tab: with st.container(border=True): if st.button("Add Source", icon="βž•"): - add_source(session_id) + add_source(current_notebook.id) for source in sources: - source_card(session_id=session_id, source=source) + source_card(source=source, notebook_id=current_notebook.id) with notes_tab: with st.container(border=True): if st.button("Write a Note", icon="πŸ“"): - add_note(session_id) + add_note(current_notebook.id) for note in notes: - note_card(session_id=session_id, note=note) + note_card(note=note, notebook_id=current_notebook.id) with chat_tab: - chat_sidebar(session_id=session_id) + chat_sidebar(current_notebook=current_notebook, current_session=current_session) def notebook_list_item(notebook): @@ -96,40 +99,50 @@ def notebook_list_item(notebook): ) st.write(notebook.description) if st.button("Open", key=f"open_notebook_{notebook.id}"): - setup_stream_state(notebook.id) - st.session_state["current_notebook"] = notebook.id + st.session_state["current_notebook_id"] = notebook.id st.rerun() -if "current_notebook" not in st.session_state: - st.session_state["current_notebook"] = None +if "current_notebook_id" not in st.session_state: + st.session_state["current_notebook_id"] = None -if st.session_state["current_notebook"]: - notebook_page(st.session_state["current_notebook"]) +# todo: get the notebook, check if it exists and if it's archived +if st.session_state["current_notebook_id"]: + current_notebook: Notebook = Notebook.get(st.session_state["current_notebook_id"]) + if not current_notebook: + st.error("Notebook not found") + st.stop() + notebook_page(current_notebook) st.stop() st.title("πŸ“’ My Notebooks") -st.caption("Here are all your notebooks") +st.caption( + "Notebooks are a great way to organize your thoughts, ideas, and sources. You can create notebooks for different research topics and projects, to create new articles, etc. " +) + +with st.expander("βž• **New Notebook**"): + new_notebook_title = st.text_input("New Notebook Name") + new_notebook_description = st.text_area( + "Description", + placeholder="Explain the purpose of this notebook. The more details the better.", + ) + if st.button("Create a new Notebook", icon="βž•"): + notebook = Notebook( + name=new_notebook_title, description=new_notebook_description + ) + notebook.save() + st.toast("Notebook created successfully", icon="πŸ“’") notebooks = Notebook.get_all(order_by="updated desc") +archived_notebooks = [nb for nb in notebooks if nb.archived] for notebook in notebooks: if notebook.archived: continue notebook_list_item(notebook) -with st.expander("βž• **New Notebook**"): - new_notebook_title = st.text_input("New Notebook Name") - new_notebook_description = st.text_area("Description") - if st.button("Create a new Notebook", icon="βž•"): - notebook = Notebook( - name=new_notebook_title, description=new_notebook_description - ) - notebook.save() - st.rerun() - -archived_notebooks = [nb for nb in notebooks if nb.archived] if len(archived_notebooks) > 0: with st.expander(f"**πŸ—ƒοΈ {len(archived_notebooks)} archived Notebooks**"): + st.write("β„Ή Archived Notebooks can still be accessed and used in search.") for notebook in archived_notebooks: notebook_list_item(notebook) diff --git a/pages/stream_app/chat.py b/pages/stream_app/chat.py index c3c2426..cb33852 100644 --- a/pages/stream_app/chat.py +++ b/pages/stream_app/chat.py @@ -1,19 +1,21 @@ +import humanize import streamlit as st from langchain_core.runnables import RunnableConfig -from open_notebook.domain.notebook import Note, Source +from open_notebook.domain.notebook import ChatSession, Note, Notebook, Source from open_notebook.graphs.chat import graph as chat_graph from open_notebook.plugins.podcasts import PodcastConfig from open_notebook.utils import token_count +from pages.stream_app.utils import create_session_for_notebook from .note import make_note_from_chat # todo: build a smarter, more robust context manager function -def build_context(session_id): - st.session_state[session_id]["context"] = dict(note=[], source=[]) +def build_context(notebook_id): + st.session_state[notebook_id]["context"] = dict(note=[], source=[]) - for id, status in st.session_state[session_id]["context_config"].items(): + for id, status in st.session_state[notebook_id]["context_config"].items(): if not id: continue @@ -24,6 +26,7 @@ def build_context(session_id): if "not in" in status: continue + # todo: there is problably a better way to handle this if item_type == "note": item: Note = Note.get(id) elif item_type == "source": @@ -34,30 +37,33 @@ def build_context(session_id): if not item: continue if "summary" in status: - st.session_state[session_id]["context"][item_type] += [ + st.session_state[notebook_id]["context"][item_type] += [ item.get_context(context_size="short") ] elif "full content" in status: - st.session_state[session_id]["context"][item_type] += [ + st.session_state[notebook_id]["context"][item_type] += [ item.get_context(context_size="long") ] - return st.session_state[session_id]["context"] + return st.session_state[notebook_id]["context"] -def execute_chat(txt_input, session_id): - current_state = st.session_state[session_id] +def execute_chat(txt_input, current_session): + current_state = st.session_state[current_session.id] current_state["messages"] += [txt_input] result = chat_graph.invoke( input=current_state, - config=RunnableConfig(configurable={"thread_id": session_id}), + config=RunnableConfig(configurable={"thread_id": current_session.id}), ) + current_session.save() return result -def chat_sidebar(session_id): - context = build_context(session_id=session_id) - tokens = token_count(str(context) + str(st.session_state[session_id]["messages"])) +def chat_sidebar(current_notebook: Notebook, current_session: ChatSession): + context = build_context(notebook_id=current_notebook.id) + tokens = token_count( + str(context) + str(st.session_state[current_session.id]["messages"]) + ) chat_tab, podcast_tab = st.tabs(["Chat", "Podcast"]) with st.expander(f"Context ({tokens} tokens), {len(str(context))} chars"): st.json(context) @@ -91,15 +97,64 @@ def chat_sidebar(session_id): st.success("Episode generated successfully") st.page_link("pages/5_πŸŽ™οΈ_Podcasts.py", label="πŸŽ™οΈ Go to Podcasts") with chat_tab: + with st.expander( + f"**Session:** {current_session.title} - {humanize.naturaltime(current_session.updated)}" + ): + new_session_name = st.text_input( + "Current Session", + key="new_session_name", + value=current_session.title, + ) + c1, c2 = st.columns(2) + if c1.button("Rename", key="rename_session"): + current_session.title = new_session_name + current_session.save() + st.rerun() + if c2.button("Delete", key="delete_session_1"): + current_session.delete() + st.session_state[current_notebook.id]["active_session"] = None + st.rerun() + st.divider() + new_session_name = st.text_input( + "New Session Name", + key="new_session_name_f", + placeholder="Enter a name for the new session...", + ) + st.caption("If no name provided, we'll use the current date.") + if st.button("Create New Session", key="create_new_session"): + new_session = create_session_for_notebook( + notebook_id=current_notebook.id, session_name=new_session_name + ) + st.session_state[current_notebook.id]["active_session"] = new_session.id + st.rerun() + st.divider() + sessions = current_notebook.chat_sessions + if len(sessions) > 1: + st.markdown("**Other Sessions:**") + for session in sessions: + if session.id == current_session.id: + continue + + st.markdown( + f"{session.title} - {humanize.naturaltime(session.updated)}" + ) + if st.button(label="Load", key=f"load_session_{session.id}"): + st.session_state[current_notebook.id]["active_session"] = ( + session.id + ) + st.rerun() with st.container(border=True): request = st.chat_input("Enter your question") # removing for now since it's not multi-model capable right now st.caption(f"Total tokens: {tokens}") if request: - response = execute_chat(txt_input=request, session_id=session_id) - st.session_state[session_id]["messages"] = response["messages"] + response = execute_chat( + txt_input=request, + current_session=current_session, + ) + st.session_state[current_session.id]["messages"] = response["messages"] - for msg in st.session_state[session_id]["messages"][::-1]: + for msg in st.session_state[current_session.id]["messages"][::-1]: if msg.type not in ["human", "ai"]: continue if not msg.content: @@ -111,6 +166,6 @@ def chat_sidebar(session_id): if st.button("πŸ’Ύ New Note", key=f"render_save_{msg.id}"): make_note_from_chat( content=msg.content, - notebook_id=st.session_state[session_id]["notebook"].id, + notebook_id=current_notebook.id, ) st.rerun() diff --git a/pages/stream_app/note.py b/pages/stream_app/note.py index f22e29c..0516a1c 100644 --- a/pages/stream_app/note.py +++ b/pages/stream_app/note.py @@ -1,3 +1,5 @@ +from typing import Optional + import streamlit as st from humanize import naturaltime from loguru import logger @@ -11,22 +13,20 @@ from .consts import context_icons @st.dialog("Write a Note", width="large") -def add_note(session_id): +def add_note(notebook_id): note_title = st.text_input("Title") note_content = st.text_area("Content") if st.button("Save", key="add_note"): logger.debug("Adding note") note = Note(title=note_title, content=note_content, note_type="human") note.save() - note.add_to_notebook(st.session_state[session_id]["notebook"].id) + note.add_to_notebook(notebook_id) st.rerun() @st.dialog("Add a Source", width="large") -def note_panel(session_id=None, note_id=None): - if note_id: - note: Note = Note.get(note_id) - else: +def note_panel(notebook_id=None, note: Optional[Note] = None): + if not note: note: Note = Note(note_type="human") t_preview, t_edit = st.tabs(["Preview", "Edit"]) @@ -38,13 +38,13 @@ def note_panel(session_id=None, note_id=None): note.content = st_monaco( value=note.content, height="600px", language="markdown" ) - if st.button("Save", key=f"pn_edit_note_{note_id}"): + if st.button("Save", key=f"pn_edit_note_{note.id or 'new'}"): logger.debug("Editing note") note.save() if not note.id: - note.add_to_notebook(st.session_state[session_id]["notebook"].id) + note.add_to_notebook(notebook_id) st.rerun() - if st.button("Delete", type="primary", key=f"delete_note_{note_id}"): + if st.button("Delete", type="primary", key=f"delete_note_{note.id or 'new'}"): logger.debug("Deleting note") note.delete() st.rerun() @@ -70,7 +70,7 @@ def make_note_from_chat(content, notebook_id=None): st.rerun() -def note_card(session_id, note): +def note_card(note, notebook_id): if note.note_type == "human": icon = "🀡" else: @@ -88,9 +88,9 @@ def note_card(session_id, note): st.caption(f"Updated: {naturaltime(note.updated)}") if st.button("Expand", icon="πŸ“", key=f"edit_note_{note.id}"): - note_panel(session_id, note.id) + note_panel(notebook_id=notebook_id, note=note) - st.session_state[session_id]["context_config"][note.id] = context_state + st.session_state[notebook_id]["context_config"][note.id] = context_state def note_list_item(note_id, score=None): @@ -105,4 +105,4 @@ def note_list_item(note_id, score=None): ): st.write(note.content) if st.button("Edit Note", icon="πŸ“", key=f"x_edit_note_{note.id}"): - note_panel(note_id=note.id) + note_panel(note=note) diff --git a/pages/stream_app/source.py b/pages/stream_app/source.py index 25f880b..c249b36 100644 --- a/pages/stream_app/source.py +++ b/pages/stream_app/source.py @@ -95,6 +95,7 @@ def source_panel(source_id): if st.button( "Embed vectors", icon="🦾", + disabled=source.embedded_chunks > 0, help="This will generate your embedding vectors on the database for powerful search capabilities", ): source.vectorize() @@ -119,7 +120,7 @@ def source_panel(source_id): @st.dialog("Add a Source", width="large") -def add_source(session_id): +def add_source(notebook_id): source_link = None source_file = None source_text = None @@ -167,7 +168,7 @@ def add_source(session_id): title=result.get("title"), ) source.save() - source.add_to_notebook(st.session_state[session_id]["notebook"].id) + source.add_to_notebook(notebook_id) st.write("Summarizing...") generate_toc_and_title(source) except UnsupportedTypeException as e: @@ -188,7 +189,7 @@ def add_source(session_id): st.rerun() -def source_card(session_id, source): +def source_card(source, notebook_id): # todo: more descriptive icons icon = "πŸ”—" @@ -208,7 +209,7 @@ def source_card(session_id, source): if st.button("Expand", icon="πŸ“", key=source.id): source_panel(source.id) - st.session_state[session_id]["context_config"][source.id] = context_state + st.session_state[notebook_id]["context_config"][source.id] = context_state def source_list_item(source_id, score=None): diff --git a/pages/stream_app/utils.py b/pages/stream_app/utils.py index 88f8849..b790db0 100644 --- a/pages/stream_app/utils.py +++ b/pages/stream_app/utils.py @@ -1,8 +1,12 @@ +from datetime import datetime +from typing import List, Union + import streamlit as st from loguru import logger from open_notebook.database.migrate import MigrationManager from open_notebook.domain.models import model_manager +from open_notebook.domain.notebook import ChatSession, Notebook from open_notebook.graphs.chat import ThreadState, graph from open_notebook.utils import ( compare_versions, @@ -33,19 +37,65 @@ def version_sidebar(): ) -def setup_stream_state(session_id) -> None: +def create_session_for_notebook(notebook_id: str, session_name: str = None): + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + title = f"Chat Session {current_time}" if not session_name else session_name + chat_session = ChatSession(title=title) + chat_session.save() + chat_session.relate_to_notebook(notebook_id) + return chat_session + + +def setup_stream_state(current_notebook: Notebook) -> ChatSession: """ Sets the value of the current session_id for langgraph thread state. If there is no existing thread state for this session_id, it creates a new one. + Finally, it acquires the existing state for the session from Langgraph state and sets it in the streamlit session state. """ - existing_state = graph.get_state({"configurable": {"thread_id": session_id}}).values - if len(existing_state.keys()) == 0: - st.session_state[session_id] = ThreadState( + assert ( + current_notebook is not None and current_notebook.id + ), "Current Notebook not selected properly" + + if "context_config" not in st.session_state[current_notebook.id]: + st.session_state[current_notebook.id]["context_config"] = {} + + current_session_id = st.session_state[current_notebook.id].get("active_session") + + # gets the chat session if provided + chat_session: Union[ChatSession, None] = ( + ChatSession.get(current_session_id) if current_session_id else None + ) + + # if there is no chat session, create one or get the first one + if not chat_session: + sessions: List[ChatSession] = current_notebook.chat_sessions + if not sessions or len(sessions) == 0: + logger.debug("Creating new chat session") + chat_session = create_session_for_notebook(current_notebook.id) + else: + logger.debug("Getting last updated session") + chat_session = sessions[0] + + logger.debug(f"Chat session: {chat_session}") + + if not chat_session or chat_session.id is None: + raise ValueError("Problem acquiring chat session") + # sets the active session for the notebook + st.session_state[current_notebook.id]["active_session"] = chat_session.id + + # gets the existing state for the session from Langgraph state + existing_state = graph.get_state( + {"configurable": {"thread_id": chat_session.id}} + ).values + if not existing_state or len(existing_state.keys()) == 0: + st.session_state[chat_session.id] = ThreadState( messages=[], context=None, notebook=None, context_config={} ) else: - st.session_state[session_id] = existing_state - st.session_state["active_session"] = session_id + st.session_state[chat_session.id] = existing_state + + st.session_state[current_notebook.id]["active_session"] = chat_session.id + return chat_session def check_migration(): From 35c68dff11225e18dd7df10302ab74f6758a6ec2 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Tue, 5 Nov 2024 13:57:10 -0300 Subject: [PATCH 28/31] better readme --- README.md | 197 ++++++++++++++++++++++++++++++++++++------- docs/assets/hero.svg | 60 +++++++++++++ 2 files changed, 226 insertions(+), 31 deletions(-) create mode 100644 docs/assets/hero.svg diff --git a/README.md b/README.md index 7cfb8d2..3cd9175 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,100 @@ -# Open Notebook + + + +[![Forks][forks-shield]][forks-url] +[![Stargazers][stars-shield]][stars-url] +[![Issues][issues-shield]][issues-url] +[![MIT License][license-shield]][license-url] + + + + +
+
+ + Logo + + +

Open Notebook

+ +

+ An open source, privacy-focused alternative to Google's Notebook LM! +
+ Checkout our website Β» +
+
+ Setup + Β· + Usage + Β· + Transformations + Β· + Podcasts + +

+
+ + + + +
+ Table of Contents +
    +
  1. + About The Project + +
  2. +
  3. + Setting Up + +
  4. +
  5. Usage
  6. +
  7. Features
  8. +
  9. Roadmap
  10. +
  11. Contributing
  12. +
  13. License
  14. +
  15. Contact
  16. +
  17. Acknowledgments
  18. +
+
+ + + + +## About The Project + +![New Notebook](docs/assets/asset_list.png) + An open source, privacy-focused alternative to Google's Notebook LM. Why give Google more of our data when we can take control of our own research workflows? -In a world dominated by Artificial Intelligence, having the ability to think 🧠 and acquire new knowledge πŸ’‘, is a skill that should not be a privilege for a few, nor restricted to a single company. +In a world dominated by Artificial Intelligence, having the ability to think 🧠 and acquire new knowledge πŸ’‘, is a skill that should not be a privilege for a few, nor restricted to a single provider. Open Notebook empowers you to manage your research, generate AI-assisted notes, and interact with your contentβ€”on your terms. Learn more about our project at [https://www.open-notebook.ai](https://www.open-notebook.ai) + + +

(back to top)

+ + + +### Built With + +* [![Python][Python]][Python-url] +* [![SurrealDB][SurrealDB]][SurrealDB-url] +* [![LangChain][LangChain]][LangChain-url] +* [![Streamlit][Streamlit]][Streamlit-url] + +

(back to top)

+ + ## βš™οΈ Setting Up Go to the [Setup Guide](docs/SETUP.md) to learn how to set up the tool in details. @@ -44,15 +131,19 @@ volumes: ``` -## Usage Instructions +Take a look at the [Open Notebook Boilerplate](https://github.com/lfnovo/open-notebook-boilerplate) repo with a sample of how to set it up for maximum feature usability. + + +

(back to top)

+ +## Usage Go to the [Usage](docs/USAGE.md) page to learn how to use all features. +

(back to top)

+ ## Features -![New Notebook](docs/assets/asset_list.png) - - - **Multi-Notebook Support**: Organize your research across multiple notebooks effortlessly. - **Multi-model support**: Open AI, Anthropic, Gemini, Vertex AI, Open Router, Ollama. - **Podcast Generator**: Automatically convert your notes into a podcast format. @@ -61,6 +152,8 @@ Go to the [Usage](docs/USAGE.md) page to learn how to use all features. - **Integrated Search Engines**: Built-in full-text and vector search for faster information retrieval. - **Fine-Grained Context Management**: Choose exactly what to share with the AI to maintain control. +

(back to top)

+ ## πŸš€ New Features ### v0.0.7 - Model Management πŸ—‚οΈ @@ -130,30 +223,30 @@ Locate anything across your research with ease using full-text and vector-based Jinja based prompts that are easy to customize to your own preferences. +

(back to top)

-## 🌟 Roadmap + +## Roadmap -- **Enhanced Citations**: Improved layout and finer control for citations. -- **Better Embeddings & Summarization**: Smarter ways to distill information. -- **Multiple Chat Sessions**: Juggle different discussions within the same notebook. -- **Live Front-End Updates**: Real-time UI updates for a smoother experience. -- **Async Processing**: Faster UI through asynchronous content processing. -- **Cross-Notebook Sources and Notes**: Reuse research notes across projects. -- **Bookmark Integration**: Integrate with your favorite bookmarking app. -- **Multi-model support**: Open AI, Anthropic, Vertex AI, Open Router, Ollama, etc. βœ… 0.0.2 -- **Insight Generation**: New tools for creating insights - [transformations](docs/TRANSFORMATIONS.md) βœ… 0.0.3 -- **Podcast Generator**: Automatically convert your notes into a podcast format. βœ… 0.0.4 +- [ ] **Live Front-End Updates**: Real-time UI updates for a smoother experience. +- [ ] **Async Processing**: Faster UI through asynchronous content processing. +- [ ] **Cross-Notebook Sources and Notes**: Reuse research notes across projects. +- [ ] **Bookmark Integration**: Integrate with your favorite bookmarking app. +- βœ… **Multi-model support**: Open AI, Anthropic, Vertex AI, Open Router, Ollama, etc. +- βœ… **Insight Generation**: New tools for creating insights - [transformations](docs/TRANSFORMATIONS.md) +- βœ… **Podcast Generator**: Automatically convert your notes into a podcast format. +- βœ… **Multiple Chat Sessions**: Juggle different discussions within the same notebook. +- βœ… **Enhanced Citations**: Improved layout and finer control for citations. +- βœ… **Better Embeddings & Summarization**: Smarter ways to distill information. + +See the [open issues](https://github.com/lfnovo/open-notebook/issues) for a full list of proposed features (and known issues). + +

(back to top)

-## πŸ’» Tech Stack -- **Streamlit**: For the front-end (Looking to move out of Streamlit. Contributors welcome!). -- **SurrealDB**: Fast, scalable database solution. -- **Langchain/Langgraph**: The backbone for LLM interactions. -- **Podcastfy**: For generating podcasts from your notes. - - -## πŸ™Œ Help Wanted + +## Contributing We would love your contributions! Specifically, we're looking for help with: - **Front-End Development**: Improve the UI/UX by moving beyond Streamlit. @@ -161,16 +254,58 @@ We would love your contributions! Specifically, we're looking for help with: - **Feature Development**: Let’s make the coolest note-taking tool together! See more at [CONTRIBUTING](CONTRIBUTING.md) -## πŸ“„ License + +

(back to top)

+ + + +## License Open Notebook is MIT licensed. See the [LICENSE](LICENSE) file for details. ---- +

(back to top)

-Your contributions, feature requests, and bug reports are always welcome. Let's build a research tool that respects our privacy and makes learning truly open for everyone. ✨ ---- -This project uses the following third-party libraries: + +## Contact -- [Podcastfy](https://github.com/souzatharsis/podcastfy) - Licensed under the Apache License 2.0 \ No newline at end of file +Luis Novo - [@lfnovo](https://twitter.com/lfnovo) + +

(back to top)

+ + + + +## Acknowledgments + +This project uses some amazing third-party libraries + +* [Podcastfy](https://github.com/souzatharsis/podcastfy) - Licensed under the Apache License 2.0 + +

(back to top)

+ + + + +[contributors-shield]: https://img.shields.io/github/contributors/lfnovo/open-notebook.svg?style=for-the-badge +[contributors-url]: https://github.com/lfnovo/open-notebook/graphs/contributors +[forks-shield]: https://img.shields.io/github/forks/lfnovo/open-notebook.svg?style=for-the-badge +[forks-url]: https://github.com/lfnovo/open-notebook/network/members +[stars-shield]: https://img.shields.io/github/stars/lfnovo/open-notebook.svg?style=for-the-badge +[stars-url]: https://github.com/lfnovo/open-notebook/stargazers +[issues-shield]: https://img.shields.io/github/issues/lfnovo/open-notebook.svg?style=for-the-badge +[issues-url]: https://github.com/lfnovo/open-notebook/issues +[license-shield]: https://img.shields.io/github/license/lfnovo/open-notebook.svg?style=for-the-badge +[license-url]: https://github.com/lfnovo/open-notebook/blob/master/LICENSE.txt +[linkedin-shield]: https://img.shields.io/badge/-LinkedIn-black.svg?style=for-the-badge&logo=linkedin&colorB=555 +[linkedin-url]: https://linkedin.com/in/lfnovo +[product-screenshot]: images/screenshot.png +[Streamlit]: https://img.shields.io/badge/Streamlit-FF4B4B?style=for-the-badge&logo=streamlit&logoColor=white +[Streamlit-url]: https://streamlit.io/ +[Python]: https://img.shields.io/badge/Python-3776AB?style=for-the-badge&logo=python&logoColor=white +[Python-url]: https://www.python.org/ +[LangChain]: https://img.shields.io/badge/LangChain-3A3A3A?style=for-the-badge&logo=chainlink&logoColor=white +[LangChain-url]: https://www.langchain.com/ +[SurrealDB]: https://img.shields.io/badge/SurrealDB-FF5E00?style=for-the-badge&logo=databricks&logoColor=white +[SurrealDB-url]: https://surrealdb.com/ diff --git a/docs/assets/hero.svg b/docs/assets/hero.svg new file mode 100644 index 0000000..8701347 --- /dev/null +++ b/docs/assets/hero.svg @@ -0,0 +1,60 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file From 3ea4e41a782f4feaa3fd9c5f55ac596fc4c0f538 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Tue, 5 Nov 2024 16:55:59 -0300 Subject: [PATCH 29/31] improve citations and add object page --- app_home.py | 34 +++++++- open_notebook/domain/base.py | 54 +++++++++++-- open_notebook/domain/notebook.py | 34 +++++++- pages/3_πŸ”_Ask_and_Search.py | 4 +- pages/components/__init__.py | 11 +++ pages/components/note_panel.py | 30 ++++++++ pages/components/source_embedding_panel.py | 17 ++++ pages/components/source_insight.py | 18 +++++ pages/components/source_panel.py | 84 ++++++++++++++++++++ pages/stream_app/chat.py | 23 +++--- pages/stream_app/note.py | 31 ++------ pages/stream_app/source.py | 90 ++-------------------- pages/stream_app/utils.py | 42 +++++++++- prompts/chat.jinja | 28 ++++++- 14 files changed, 362 insertions(+), 138 deletions(-) create mode 100644 pages/components/__init__.py create mode 100644 pages/components/note_panel.py create mode 100644 pages/components/source_embedding_panel.py create mode 100644 pages/components/source_insight.py create mode 100644 pages/components/source_panel.py diff --git a/app_home.py b/app_home.py index 4badc36..a8df6c9 100644 --- a/app_home.py +++ b/app_home.py @@ -1,3 +1,35 @@ import streamlit as st -st.switch_page("pages/2_πŸ“’_Notebooks.py") +from open_notebook.domain.base import ObjectModel +from open_notebook.exceptions import NotFoundError +from pages.components import ( + note_panel, + source_embedding_panel, + source_insight_panel, + source_panel, +) +from pages.stream_app.utils import setup_page + +setup_page("πŸ“’ Open Notebook", sidebar_state="collapsed") + +if "object_id" not in st.query_params: + st.switch_page("pages/2_πŸ“’_Notebooks.py") + st.stop() + +object_id = st.query_params["object_id"] +try: + obj = ObjectModel.get(object_id) +except NotFoundError: + st.switch_page("pages/2_πŸ“’_Notebooks.py") + st.stop() + +obj_type = object_id.split(":")[0] + +if obj_type == "note": + note_panel(object_id) +elif obj_type == "source": + source_panel(object_id) +elif obj_type == "source_insight": + source_insight_panel(object_id) +elif obj_type == "source_embedding": + source_embedding_panel(object_id) diff --git a/open_notebook/domain/base.py b/open_notebook/domain/base.py index 146589d..8514e89 100644 --- a/open_notebook/domain/base.py +++ b/open_notebook/domain/base.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar +from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, cast from loguru import logger from pydantic import BaseModel, ValidationError, field_validator @@ -29,15 +29,26 @@ class ObjectModel(BaseModel): @classmethod def get_all(cls: Type[T], order_by=None) -> List[T]: try: + # If called from a specific subclass, use its table_name + if cls.table_name: + target_class = cls + table_name = cls.table_name + else: + # This path is taken if called directly from ObjectModel + raise InvalidInputError( + "get_all() must be called from a specific model class" + ) + if order_by: order = f" ORDER BY {order_by}" else: order = "" - result = repo_query(f"SELECT * FROM {cls.table_name} {order}") + + result = repo_query(f"SELECT * FROM {table_name} {order}") objects = [] for obj in result: try: - objects.append(cls(**obj)) + objects.append(target_class(**obj)) except Exception as e: logger.critical(f"Error creating object: {str(e)}") @@ -52,15 +63,44 @@ class ObjectModel(BaseModel): if not id: raise InvalidInputError("ID cannot be empty") try: + # Get the table name from the ID (everything before the first colon) + table_name = id.split(":")[0] if ":" in id else id + + # If we're calling from a specific subclass and IDs match, use that class + if cls.table_name and cls.table_name == table_name: + target_class: Type[T] = cls + else: + # Otherwise, find the appropriate subclass based on table_name + found_class = cls._get_class_by_table_name(table_name) + if not found_class: + raise InvalidInputError(f"No class found for table {table_name}") + target_class = cast(Type[T], found_class) + result = repo_query(f"SELECT * FROM {id}") if result: - return cls(**result[0]) + return target_class(**result[0]) else: - raise NotFoundError(f"{cls.table_name} with id {id} not found") + raise NotFoundError(f"{table_name} with id {id} not found") except Exception as e: - logger.error(f"Error fetching {cls.table_name} with id {id}: {str(e)}") + logger.error(f"Error fetching object with id {id}: {str(e)}") logger.exception(e) - raise NotFoundError(f"{cls.table_name} with id {id} not found") + raise NotFoundError(f"Object with id {id} not found - {str(e)}") + + @classmethod + def _get_class_by_table_name(cls, table_name: str) -> Optional[Type["ObjectModel"]]: + """Find the appropriate subclass based on table_name.""" + + def get_all_subclasses(c: Type["ObjectModel"]) -> List[Type["ObjectModel"]]: + all_subclasses: List[Type["ObjectModel"]] = [] + for subclass in c.__subclasses__(): + all_subclasses.append(subclass) + all_subclasses.extend(get_all_subclasses(subclass)) + return all_subclasses + + for subclass in get_all_subclasses(ObjectModel): + if hasattr(subclass, "table_name") and subclass.table_name == table_name: + return subclass + return None def needs_embedding(self) -> bool: return False diff --git a/open_notebook/domain/notebook.py b/open_notebook/domain/notebook.py index 86944f8..e473193 100644 --- a/open_notebook/domain/notebook.py +++ b/open_notebook/domain/notebook.py @@ -93,10 +93,42 @@ class Asset(BaseModel): url: Optional[str] = None +class SourceEmbedding(ObjectModel): + table_name: ClassVar[str] = "source_embedding" + content: str + + @property + def source(self) -> "Source": + try: + src = repo_query(f""" + select source.* from {self.id} fetch source + + """) + return Source(**src[0]["source"]) + except Exception as e: + logger.error(f"Error fetching source for embedding {self.id}: {str(e)}") + logger.exception(e) + raise DatabaseOperationError(e) + + class SourceInsight(ObjectModel): + table_name: ClassVar[str] = "source_insight" insight_type: str content: str + @property + def source(self) -> "Source": + try: + src = repo_query(f""" + select source.* from {self.id} fetch source + + """) + return Source(**src[0]["source"]) + except Exception as e: + logger.error(f"Error fetching source for insight {self.id}: {str(e)}") + logger.exception(e) + raise DatabaseOperationError(e) + class Source(ObjectModel): table_name: ClassVar[str] = "source" @@ -112,7 +144,7 @@ class Source(ObjectModel): return dict( id=self.id, title=self.title, - insights=self.insights, + insights=[insight.model_dump() for insight in self.insights], full_text=self.full_text, ) else: diff --git a/pages/3_πŸ”_Ask_and_Search.py b/pages/3_πŸ”_Ask_and_Search.py index d70a3e9..964f208 100644 --- a/pages/3_πŸ”_Ask_and_Search.py +++ b/pages/3_πŸ”_Ask_and_Search.py @@ -3,7 +3,7 @@ import streamlit as st from open_notebook.domain.models import Model from open_notebook.domain.notebook import text_search, vector_search from open_notebook.graphs.rag import graph as rag_graph -from pages.stream_app.utils import setup_page +from pages.stream_app.utils import convert_source_references, setup_page setup_page("πŸ” Search") @@ -40,7 +40,7 @@ with ask_tab: messages=messages ), # config=dict(configurable=dict(model_id=model.id)) ) - st.markdown(rag_results["messages"][-1].content) + st.markdown(convert_source_references(rag_results["messages"][-1].content)) with st.expander("Details (for debugging)"): st.json(rag_results) diff --git a/pages/components/__init__.py b/pages/components/__init__.py new file mode 100644 index 0000000..d6561b5 --- /dev/null +++ b/pages/components/__init__.py @@ -0,0 +1,11 @@ +from pages.components.note_panel import note_panel +from pages.components.source_embedding_panel import source_embedding_panel +from pages.components.source_insight import source_insight_panel +from pages.components.source_panel import source_panel + +__all__ = [ + "note_panel", + "source_embedding_panel", + "source_insight_panel", + "source_panel", +] diff --git a/pages/components/note_panel.py b/pages/components/note_panel.py new file mode 100644 index 0000000..d1cba31 --- /dev/null +++ b/pages/components/note_panel.py @@ -0,0 +1,30 @@ +import streamlit as st +from loguru import logger +from streamlit_monaco import st_monaco # type: ignore + +from open_notebook.domain.notebook import Note + + +def note_panel(note_id, notebook_id=None): + note: Note = Note.get(note_id) + if not note: + raise ValueError(f"Note not fonud {note_id}") + t_preview, t_edit = st.tabs(["Preview", "Edit"]) + with t_preview: + st.subheader(note.title) + st.markdown(note.content) + with t_edit: + note.title = st.text_input("Title", value=note.title) + note.content = st_monaco( + value=note.content, height="600px", language="markdown" + ) + if st.button("Save", key=f"pn_edit_note_{note.id or 'new'}"): + logger.debug("Editing note") + note.save() + if not note.id and notebook_id: + note.add_to_notebook(notebook_id) + st.rerun() + if st.button("Delete", type="primary", key=f"delete_note_{note.id or 'new'}"): + logger.debug("Deleting note") + note.delete() + st.rerun() diff --git a/pages/components/source_embedding_panel.py b/pages/components/source_embedding_panel.py new file mode 100644 index 0000000..4ef4a29 --- /dev/null +++ b/pages/components/source_embedding_panel.py @@ -0,0 +1,17 @@ +import streamlit as st + +from open_notebook.domain.notebook import SourceEmbedding + + +def source_embedding_panel(source_embedding_id): + si: SourceEmbedding = SourceEmbedding.get(source_embedding_id) + if not si: + raise ValueError(f"Embedding not found {source_embedding_id}") + with st.container(border=True): + url = f"Navigator?object_id={si.source.id}" + st.markdown("**Original Source**") + st.markdown(f"{si.source.title} [link](%s)" % url) + st.markdown(si.content) + if st.button("Delete", type="primary", key=f"delete_embedding_{si.id or 'new'}"): + si.delete() + st.rerun() diff --git a/pages/components/source_insight.py b/pages/components/source_insight.py new file mode 100644 index 0000000..ad38793 --- /dev/null +++ b/pages/components/source_insight.py @@ -0,0 +1,18 @@ +import streamlit as st + +from open_notebook.domain.notebook import SourceInsight + + +def source_insight_panel(source, notebook_id=None): + si: SourceInsight = SourceInsight.get(source) + if not si: + raise ValueError(f"insight not found {source}") + st.subheader(si.insight_type) + with st.container(border=True): + url = f"Navigator?object_id={si.source.id}" + st.markdown("**Original Source**") + st.markdown(f"{si.source.title} [link](%s)" % url) + st.markdown(si.content) + if st.button("Delete", type="primary", key=f"delete_insight_{si.id or 'new'}"): + si.delete() + st.rerun() diff --git a/pages/components/source_panel.py b/pages/components/source_panel.py new file mode 100644 index 0000000..8fa130d --- /dev/null +++ b/pages/components/source_panel.py @@ -0,0 +1,84 @@ +import streamlit as st +import streamlit_scrollable_textbox as stx # type: ignore +import yaml +from humanize import naturaltime + +from open_notebook.domain.notebook import Source +from open_notebook.utils import surreal_clean +from pages.stream_app.utils import run_patterns + + +def source_panel(source_id: str, modal=False): + source: Source = Source.get(source_id) + if not source: + raise ValueError(f"Source not found: {source_id}") + + current_title = source.title if source.title else "No Title" + source.title = st.text_input("Title", value=current_title) + if source.title != current_title: + st.toast("Saved new Title") + source.save() + + process_tab, source_tab = st.tabs(["Process", "Source"]) + with process_tab: + c1, c2 = st.columns([3, 1]) + with c1: + title = st.empty() + if source.title: + title.subheader(source.title) + if source.asset and source.asset.url: + from_src = f"from URL: {source.asset.url}" + elif source.asset and source.asset.file_path: + from_src = f"from file: {source.asset.file_path}" + else: + from_src = "from text" + st.caption(f"Created {naturaltime(source.created)}, {from_src}") + for insight in source.insights: + with st.expander(f"**{insight.insight_type}**"): + st.markdown(insight.content) + if st.button( + "Delete", type="primary", key=f"delete_insight_{insight.id}" + ): + insight.delete() + st.rerun(scope="fragment" if modal else "app") + + with c2: + with open("transformations.yaml", "r") as file: + transformations = yaml.safe_load(file) + for transformation in transformations["source_insights"]: + if st.button( + transformation["name"], help=transformation["description"] + ): + result = run_patterns( + source.full_text, transformation["patterns"] + ) + source.add_insight( + transformation["insight_type"], surreal_clean(result) + ) + st.rerun(scope="fragment" if modal else "app") + + if st.button( + "Embed vectors", + icon="🦾", + disabled=source.embedded_chunks > 0, + help="This will generate your embedding vectors on the database for powerful search capabilities", + ): + source.vectorize() + st.success("Embedding complete") + + chk_delete = st.checkbox( + "πŸ—‘οΈ Delete source", key=f"delete_source_{source.id}", value=False + ) + if chk_delete: + st.warning( + "Source will be deleted with all its insights and embeddings" + ) + if st.button( + "Delete", type="primary", key=f"bt_delete_source_{source.id}" + ): + source.delete() + st.rerun() + + with source_tab: + st.subheader("Content") + stx.scrollableTextbox(source.full_text, height=300) diff --git a/pages/stream_app/chat.py b/pages/stream_app/chat.py index cb33852..71c0437 100644 --- a/pages/stream_app/chat.py +++ b/pages/stream_app/chat.py @@ -1,12 +1,18 @@ +from typing import Union + import humanize import streamlit as st from langchain_core.runnables import RunnableConfig +from open_notebook.domain.base import ObjectModel from open_notebook.domain.notebook import ChatSession, Note, Notebook, Source from open_notebook.graphs.chat import graph as chat_graph from open_notebook.plugins.podcasts import PodcastConfig from open_notebook.utils import token_count -from pages.stream_app.utils import create_session_for_notebook +from pages.stream_app.utils import ( + convert_source_references, + create_session_for_notebook, +) from .note import make_note_from_chat @@ -26,13 +32,7 @@ def build_context(notebook_id): if "not in" in status: continue - # todo: there is problably a better way to handle this - if item_type == "note": - item: Note = Note.get(id) - elif item_type == "source": - item: Source = Source.get(id) - else: - continue + item: Union[Note, Source] = ObjectModel.get(id) if not item: continue @@ -48,9 +48,10 @@ def build_context(notebook_id): return st.session_state[notebook_id]["context"] -def execute_chat(txt_input, current_session): +def execute_chat(txt_input, context, current_session): current_state = st.session_state[current_session.id] current_state["messages"] += [txt_input] + current_state["context"] = context result = chat_graph.invoke( input=current_state, config=RunnableConfig(configurable={"thread_id": current_session.id}), @@ -146,10 +147,10 @@ def chat_sidebar(current_notebook: Notebook, current_session: ChatSession): with st.container(border=True): request = st.chat_input("Enter your question") # removing for now since it's not multi-model capable right now - st.caption(f"Total tokens: {tokens}") if request: response = execute_chat( txt_input=request, + context=context, current_session=current_session, ) st.session_state[current_session.id]["messages"] = response["messages"] @@ -161,7 +162,7 @@ def chat_sidebar(current_notebook: Notebook, current_session: ChatSession): continue with st.chat_message(name=msg.type): - st.write(msg.content) + st.markdown(convert_source_references(msg.content)) if msg.type == "ai": if st.button("πŸ’Ύ New Note", key=f"render_save_{msg.id}"): make_note_from_chat( diff --git a/pages/stream_app/note.py b/pages/stream_app/note.py index 0516a1c..3dc96d4 100644 --- a/pages/stream_app/note.py +++ b/pages/stream_app/note.py @@ -3,11 +3,11 @@ from typing import Optional import streamlit as st from humanize import naturaltime from loguru import logger -from streamlit_monaco import st_monaco # type: ignore from open_notebook.domain.notebook import Note from open_notebook.graphs.multipattern import graph as pattern_graph from open_notebook.utils import surreal_clean +from pages.components import note_panel from .consts import context_icons @@ -25,29 +25,8 @@ def add_note(notebook_id): @st.dialog("Add a Source", width="large") -def note_panel(notebook_id=None, note: Optional[Note] = None): - if not note: - note: Note = Note(note_type="human") - - t_preview, t_edit = st.tabs(["Preview", "Edit"]) - with t_preview: - st.subheader(note.title) - st.markdown(note.content) - with t_edit: - note.title = st.text_input("Title", value=note.title) - note.content = st_monaco( - value=note.content, height="600px", language="markdown" - ) - if st.button("Save", key=f"pn_edit_note_{note.id or 'new'}"): - logger.debug("Editing note") - note.save() - if not note.id: - note.add_to_notebook(notebook_id) - st.rerun() - if st.button("Delete", type="primary", key=f"delete_note_{note.id or 'new'}"): - logger.debug("Deleting note") - note.delete() - st.rerun() +def note_panel_dialog(note: Optional[Note] = None, notebook_id=None): + note_panel(note_id=note.id, notebook_id=notebook_id) def make_note_from_chat(content, notebook_id=None): @@ -88,7 +67,7 @@ def note_card(note, notebook_id): st.caption(f"Updated: {naturaltime(note.updated)}") if st.button("Expand", icon="πŸ“", key=f"edit_note_{note.id}"): - note_panel(notebook_id=notebook_id, note=note) + note_panel_dialog(notebook_id=notebook_id, note=note) st.session_state[notebook_id]["context_config"][note.id] = context_state @@ -105,4 +84,4 @@ def note_list_item(note_id, score=None): ): st.write(note.content) if st.button("Edit Note", icon="πŸ“", key=f"x_edit_note_{note.id}"): - note_panel(note=note) + note_panel_dialog(note=note) diff --git a/pages/stream_app/source.py b/pages/stream_app/source.py index c249b36..fdc56c5 100644 --- a/pages/stream_app/source.py +++ b/pages/stream_app/source.py @@ -2,8 +2,6 @@ import os from pathlib import Path import streamlit as st -import streamlit_scrollable_textbox as stx # type: ignore -import yaml from humanize import naturaltime from loguru import logger @@ -11,17 +9,13 @@ from open_notebook.config import UPLOADS_FOLDER from open_notebook.domain.notebook import Asset, Source from open_notebook.exceptions import UnsupportedTypeException from open_notebook.graphs.content_processing import graph -from open_notebook.graphs.multipattern import graph as transform_graph from open_notebook.utils import surreal_clean +from pages.components import source_panel +from pages.stream_app.utils import run_patterns from .consts import context_icons -def run_patterns(input_text, patterns): - output = transform_graph.invoke(dict(content_stack=[input_text], patterns=patterns)) - return output["output"] - - # moved it here to replace it with the pipeline on 0.1.0 def generate_toc_and_title(source) -> "Source": try: @@ -43,80 +37,8 @@ def generate_toc_and_title(source) -> "Source": @st.dialog("Source", width="large") -def source_panel(source_id): - source: Source = Source.get(source_id) - if not source: - st.error("Source not found") - return - current_title = source.title if source.title else "No Title" - source.title = st.text_input("Title", value=current_title) - if source.title != current_title: - st.toast("Saved new Title") - source.save() - - process_tab, source_tab = st.tabs(["Process", "Source"]) - with process_tab: - c1, c2 = st.columns([3, 1]) - with c1: - title = st.empty() - if source.title: - title.subheader(source.title) - if source.asset.url: - from_src = f"from URL: {source.asset.url}" - elif source.asset.file_path: - from_src = f"from file: {source.asset.file_path}" - else: - from_src = "from text" - st.caption(f"Created {naturaltime(source.created)}, {from_src}") - for insight in source.insights: - with st.expander(f"**{insight.insight_type}**"): - st.markdown(insight.content) - if st.button( - "Delete", type="primary", key=f"delete_insight_{insight.id}" - ): - insight.delete() - st.rerun(scope="fragment") - - with c2: - with open("transformations.yaml", "r") as file: - transformations = yaml.safe_load(file) - for transformation in transformations["source_insights"]: - if st.button( - transformation["name"], help=transformation["description"] - ): - result = run_patterns( - source.full_text, transformation["patterns"] - ) - source.add_insight( - transformation["insight_type"], surreal_clean(result) - ) - st.rerun(scope="fragment") - - if st.button( - "Embed vectors", - icon="🦾", - disabled=source.embedded_chunks > 0, - help="This will generate your embedding vectors on the database for powerful search capabilities", - ): - source.vectorize() - st.success("Embedding complete") - - chk_delete = st.checkbox( - "πŸ—‘οΈ Delete source", key=f"delete_source_{source.id}", value=False - ) - if chk_delete: - st.warning( - "Source will be deleted with all its insights and embeddings" - ) - if st.button( - "Delete", type="primary", key=f"bt_delete_source_{source.id}" - ): - source.delete() - st.rerun() - - with source_tab: - st.subheader("Content") - stx.scrollableTextbox(source.full_text, height=300) +def source_panel_dialog(source_id): + source_panel(source_id) @st.dialog("Add a Source", width="large") @@ -207,7 +129,7 @@ def source_card(source, notebook_id): f"Updated: {naturaltime(source.updated)}, **{len(source.insights)}** insights" ) if st.button("Expand", icon="πŸ“", key=source.id): - source_panel(source.id) + source_panel_dialog(source.id) st.session_state[notebook_id]["context_config"][source.id] = context_state @@ -226,4 +148,4 @@ def source_list_item(source_id, score=None): st.markdown(f"**{insight.insight_type}**") st.write(insight.content) if st.button("Edit source", icon="πŸ“", key=f"x_edit_source_{source.id}"): - source_panel(source_id=source.id) + source_panel_dialog(source_id=source.id) diff --git a/pages/stream_app/utils.py b/pages/stream_app/utils.py index b790db0..2caae54 100644 --- a/pages/stream_app/utils.py +++ b/pages/stream_app/utils.py @@ -1,3 +1,4 @@ +import re from datetime import datetime from typing import List, Union @@ -8,6 +9,7 @@ from open_notebook.database.migrate import MigrationManager from open_notebook.domain.models import model_manager from open_notebook.domain.notebook import ChatSession, Notebook from open_notebook.graphs.chat import ThreadState, graph +from open_notebook.graphs.multipattern import graph as transform_graph from open_notebook.utils import ( compare_versions, get_installed_version, @@ -15,6 +17,11 @@ from open_notebook.utils import ( ) +def run_patterns(input_text, patterns): + output = transform_graph.invoke(dict(content_stack=[input_text], patterns=patterns)) + return output["output"] + + def version_sidebar(): with st.sidebar: try: @@ -76,8 +83,6 @@ def setup_stream_state(current_notebook: Notebook) -> ChatSession: logger.debug("Getting last updated session") chat_session = sessions[0] - logger.debug(f"Chat session: {chat_session}") - if not chat_session or chat_session.id is None: raise ValueError("Problem acquiring chat session") # sets the active session for the notebook @@ -163,3 +168,36 @@ def setup_page(title: str, layout="wide", sidebar_state="expanded"): check_migration() check_models() version_sidebar() + + +def convert_source_references(text): + """ + Converts source references in brackets to markdown-style links. + + Matches patterns like [source_insight:id], [note:id], [source:id], or [source_embedding:id] + and converts them to markdown links. + + Args: + text (str): The input text containing source references + + Returns: + str: Text with source references converted to markdown links + + Example: + >>> text = "Here is a reference [source_insight:abc123]" + >>> convert_source_references(text) + 'Here is a reference [source_insight:abc123](/?object_id=source_insight:abc123)' + """ + + # Pattern matches [type:id] where type can be source_insight, note, source, or source_embedding + pattern = r"\[((?:source_insight|note|source|source_embedding):[\w\d]+)\]" + + def replace_match(match): + """Helper function to create the markdown link""" + source_ref = match.group(1) # Gets the content inside brackets + return f"[[{source_ref}]](/?object_id={source_ref})" + + # Replace all matches in the text + converted_text = re.sub(pattern, replace_match, text) + + return converted_text diff --git a/prompts/chat.jinja b/prompts/chat.jinja index 0ba3471..dab2eda 100644 --- a/prompts/chat.jinja +++ b/prompts/chat.jinja @@ -5,9 +5,8 @@ You are a cognitive study assistant that helps users research and learn by engag - Access to project information and selected documents (CONTEXT) - Can engage in natural dialogue while maintaining academic rigor -# FORMULATE YOUR DATA -- Generate your answer based on the CONTEXT information -- Ensure that your response is accurate and relevant to the user's query +# YOUR OPERATING METHOD +Whenever a user asks you a question, you need to identify the query context and the user intent. The user might be continuing a previous conversation or asking a new question. Looking at the CONTEXT will probably give you a hint of what the user is looking for. Once you identify the user intent, formulate your answer accordingly paying attention to the CITING INSTRUCTIONS below. {% if notebook %} # PROJECT INFORMATION @@ -18,5 +17,26 @@ You are a cognitive study assistant that helps users research and learn by engag {% if context %} # CONTEXT +The user has selected this context to help you with your response: + {{context}} -{% endif %} \ No newline at end of file +{% endif %} + +# CITING INSTRUCTIONS + +If your answer is based off of any item in the context, it's very important that your response contains references to the searched documents so the user can follow-up and read more about the topic. The way you do that is by adding the id of the specific document in between brackets like this: [document_id]. + +## EXAMPLE + +User: Can you tell me more about the concept of "Deep Learning"? + +Assistant: Deep learning is a subset of machine learning in artificial intelligence (AI) that enables networks to learn unsupervised from unstructured or unlabeled data. [note:iuiodadalknda]. It can also be categorized into three main types: supervised, unsupervised, and reinforcement learning. [insight:adadadadadadad]. + +Please note, "note:iuiodadalknda" and "insight:adadadadadadad" are examples of document IDs with different prefixes. You should not make up document IDs or copy the IDs from this example. You should use the IDs of the documents that you have access to through the search tool. + +## IMPORTANT + +- Do not make up documents or document ids. Only use the ids of the documents that you have access through the query you made. +- The ID is composed of the type of document and a random string, such as "source:randomstring", "note:randomstring", or "insight:randomstring". There are various types of documents, including notes, insights, and sources. **Always use the complete ID exactly as it is provided, including its type prefix. Do not add, remove, or modify any part of the ID.** +- Do not assume or change the type prefix of any document ID. If a document ID is "note:xyz", use it exactly as "note:xyz". Do not change it to "source:xyz" or any other variation. +- **Use document IDs exactly as they are returned from the search tool. Do not add any prefixes or modify them in any way.** From 5aab6bdb69667d52f3bc3d1b73c85780b561973d Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Wed, 6 Nov 2024 13:22:23 -0300 Subject: [PATCH 30/31] custom model in search --- pages/3_πŸ”_Ask_and_Search.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pages/3_πŸ”_Ask_and_Search.py b/pages/3_πŸ”_Ask_and_Search.py index 964f208..c085867 100644 --- a/pages/3_πŸ”_Ask_and_Search.py +++ b/pages/3_πŸ”_Ask_and_Search.py @@ -37,8 +37,9 @@ with ask_tab: messages = [question] rag_results = rag_graph.invoke( dict( - messages=messages - ), # config=dict(configurable=dict(model_id=model.id)) + messages=messages, + ), + config=dict(configurable=dict(model_id=model.id)), ) st.markdown(convert_source_references(rag_results["messages"][-1].content)) with st.expander("Details (for debugging)"): From 64c123c1304fc1287eb6d644c9d659594c926e18 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Wed, 6 Nov 2024 13:32:39 -0300 Subject: [PATCH 31/31] version bump --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 811bec9..26f5b06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "open-notebook" -version = "0.0.8" +version = "0.0.9" description = "An open source implementation of a research assistant, inspired by Google Notebook LM" authors = ["Luis Novo "] license = "MIT"