+
+
+
+
+
+
+## About The Project
+
+
+
An open source, privacy-focused alternative to Google's Notebook LM. Why give Google more of our data when we can take control of our own research workflows?
-In a world dominated by Artificial Intelligence, having the ability to think π§ and acquire new knowledge π‘, is a skill that should not be a privilege for a few, nor restricted to a single company.
+In a world dominated by Artificial Intelligence, having the ability to think π§ and acquire new knowledge π‘, is a skill that should not be a privilege for a few, nor restricted to a single provider.
Open Notebook empowers you to manage your research, generate AI-assisted notes, and interact with your contentβon your terms.
Learn more about our project at [https://www.open-notebook.ai](https://www.open-notebook.ai)
+
+
+
+
+
## βοΈ Setting Up
Go to the [Setup Guide](docs/SETUP.md) to learn how to set up the tool in details.
@@ -44,15 +131,19 @@ volumes:
```
-## Usage Instructions
+Take a look at the [Open Notebook Boilerplate](https://github.com/lfnovo/open-notebook-boilerplate) repo with a sample of how to set it up for maximum feature usability.
+
+
+
+
## Features
-
-
-
- **Multi-Notebook Support**: Organize your research across multiple notebooks effortlessly.
- **Multi-model support**: Open AI, Anthropic, Gemini, Vertex AI, Open Router, Ollama.
- **Podcast Generator**: Automatically convert your notes into a podcast format.
@@ -61,6 +152,8 @@ Go to the [Usage](docs/USAGE.md) page to learn how to use all features.
- **Integrated Search Engines**: Built-in full-text and vector search for faster information retrieval.
- **Fine-Grained Context Management**: Choose exactly what to share with the AI to maintain control.
+
+
## π New Features
### v0.0.7 - Model Management ποΈ
@@ -130,30 +223,30 @@ Locate anything across your research with ease using full-text and vector-based
Jinja based prompts that are easy to customize to your own preferences.
+
-## π Roadmap
+
+## Roadmap
-- **Enhanced Citations**: Improved layout and finer control for citations.
-- **Better Embeddings & Summarization**: Smarter ways to distill information.
-- **Multiple Chat Sessions**: Juggle different discussions within the same notebook.
-- **Live Front-End Updates**: Real-time UI updates for a smoother experience.
-- **Async Processing**: Faster UI through asynchronous content processing.
-- **Cross-Notebook Sources and Notes**: Reuse research notes across projects.
-- **Bookmark Integration**: Integrate with your favorite bookmarking app.
-- **Multi-model support**: Open AI, Anthropic, Vertex AI, Open Router, Ollama, etc. β 0.0.2
-- **Insight Generation**: New tools for creating insights - [transformations](docs/TRANSFORMATIONS.md) β 0.0.3
-- **Podcast Generator**: Automatically convert your notes into a podcast format. β 0.0.4
+- [ ] **Live Front-End Updates**: Real-time UI updates for a smoother experience.
+- [ ] **Async Processing**: Faster UI through asynchronous content processing.
+- [ ] **Cross-Notebook Sources and Notes**: Reuse research notes across projects.
+- [ ] **Bookmark Integration**: Integrate with your favorite bookmarking app.
+- β **Multi-model support**: Open AI, Anthropic, Vertex AI, Open Router, Ollama, etc.
+- β **Insight Generation**: New tools for creating insights - [transformations](docs/TRANSFORMATIONS.md)
+- β **Podcast Generator**: Automatically convert your notes into a podcast format.
+- β **Multiple Chat Sessions**: Juggle different discussions within the same notebook.
+- β **Enhanced Citations**: Improved layout and finer control for citations.
+- β **Better Embeddings & Summarization**: Smarter ways to distill information.
+
+See the [open issues](https://github.com/lfnovo/open-notebook/issues) for a full list of proposed features (and known issues).
+
+
-## π» Tech Stack
-- **Streamlit**: For the front-end (Looking to move out of Streamlit. Contributors welcome!).
-- **SurrealDB**: Fast, scalable database solution.
-- **Langchain/Langgraph**: The backbone for LLM interactions.
-- **Podcastfy**: For generating podcasts from your notes.
-
-
-## π Help Wanted
+
+## Contributing
We would love your contributions! Specifically, we're looking for help with:
- **Front-End Development**: Improve the UI/UX by moving beyond Streamlit.
@@ -161,16 +254,58 @@ We would love your contributions! Specifically, we're looking for help with:
- **Feature Development**: Letβs make the coolest note-taking tool together!
See more at [CONTRIBUTING](CONTRIBUTING.md)
-## π License
+
+
-Your contributions, feature requests, and bug reports are always welcome. Let's build a research tool that respects our privacy and makes learning truly open for everyone. β¨
----
-This project uses the following third-party libraries:
+
+## Contact
-- [Podcastfy](https://github.com/souzatharsis/podcastfy) - Licensed under the Apache License 2.0
\ No newline at end of file
+Luis Novo - [@lfnovo](https://twitter.com/lfnovo)
+
+
+
+
+
+
+[contributors-shield]: https://img.shields.io/github/contributors/lfnovo/open-notebook.svg?style=for-the-badge
+[contributors-url]: https://github.com/lfnovo/open-notebook/graphs/contributors
+[forks-shield]: https://img.shields.io/github/forks/lfnovo/open-notebook.svg?style=for-the-badge
+[forks-url]: https://github.com/lfnovo/open-notebook/network/members
+[stars-shield]: https://img.shields.io/github/stars/lfnovo/open-notebook.svg?style=for-the-badge
+[stars-url]: https://github.com/lfnovo/open-notebook/stargazers
+[issues-shield]: https://img.shields.io/github/issues/lfnovo/open-notebook.svg?style=for-the-badge
+[issues-url]: https://github.com/lfnovo/open-notebook/issues
+[license-shield]: https://img.shields.io/github/license/lfnovo/open-notebook.svg?style=for-the-badge
+[license-url]: https://github.com/lfnovo/open-notebook/blob/master/LICENSE.txt
+[linkedin-shield]: https://img.shields.io/badge/-LinkedIn-black.svg?style=for-the-badge&logo=linkedin&colorB=555
+[linkedin-url]: https://linkedin.com/in/lfnovo
+[product-screenshot]: images/screenshot.png
+[Streamlit]: https://img.shields.io/badge/Streamlit-FF4B4B?style=for-the-badge&logo=streamlit&logoColor=white
+[Streamlit-url]: https://streamlit.io/
+[Python]: https://img.shields.io/badge/Python-3776AB?style=for-the-badge&logo=python&logoColor=white
+[Python-url]: https://www.python.org/
+[LangChain]: https://img.shields.io/badge/LangChain-3A3A3A?style=for-the-badge&logo=chainlink&logoColor=white
+[LangChain-url]: https://www.langchain.com/
+[SurrealDB]: https://img.shields.io/badge/SurrealDB-FF5E00?style=for-the-badge&logo=databricks&logoColor=white
+[SurrealDB-url]: https://surrealdb.com/
diff --git a/app_home.py b/app_home.py
index 9840c6c..a8df6c9 100644
--- a/app_home.py
+++ b/app_home.py
@@ -1,43 +1,35 @@
import streamlit as st
-from open_notebook.database.migrate import MigrationManager
+from open_notebook.domain.base import ObjectModel
+from open_notebook.exceptions import NotFoundError
+from pages.components import (
+ note_panel,
+ source_embedding_panel,
+ source_insight_panel,
+ source_panel,
+)
+from pages.stream_app.utils import setup_page
-# from open_notebook.config import DEFAULT_MODELS
-from open_notebook.domain.models import DefaultModels
-from stream_app.utils import version_sidebar
+setup_page("π Open Notebook", sidebar_state="collapsed")
-default_models = DefaultModels.load()
-
-version_sidebar()
-mm = MigrationManager()
-if mm.needs_migration:
- st.warning("The Open Notebook database needs a migration to run properly.")
- if st.button("Run Migration"):
- mm.run_migration_up()
- st.success("Migration successful")
- st.rerun()
-elif (
- not default_models.default_chat_model
- or not default_models.default_transformation_model
-):
- st.warning(
- "You don't have default chat and transformation models selected. Please, select them on the settings page."
- )
-elif not default_models.default_embedding_model:
- st.warning(
- "You don't have a default embedding model selected. Vector search will not be possible and your assistant will be less able to answer your queries. Please, select one on the settings page."
- )
-elif not default_models.default_speech_to_text_model:
- st.warning(
- "You don't have a default speech to text model selected. Your assistant will not be able to transcribe audio. Please, select one on the settings page."
- )
-elif not default_models.default_text_to_speech_model:
- st.warning(
- "You don't have a default text to speech model selected. Your assistant will not be able to generate audio and podcasts. Please, select one on the settings page."
- )
-elif not default_models.large_context_model:
- st.warning(
- "You don't have a large context model selected. Your assistant will not be able to process large documents. Please, select one on the settings page."
- )
-else:
+if "object_id" not in st.query_params:
st.switch_page("pages/2_π_Notebooks.py")
+ st.stop()
+
+object_id = st.query_params["object_id"]
+try:
+ obj = ObjectModel.get(object_id)
+except NotFoundError:
+ st.switch_page("pages/2_π_Notebooks.py")
+ st.stop()
+
+obj_type = object_id.split(":")[0]
+
+if obj_type == "note":
+ note_panel(object_id)
+elif obj_type == "source":
+ source_panel(object_id)
+elif obj_type == "source_insight":
+ source_insight_panel(object_id)
+elif obj_type == "source_embedding":
+ source_embedding_panel(object_id)
diff --git a/docs/assets/hero.svg b/docs/assets/hero.svg
new file mode 100644
index 0000000..8701347
--- /dev/null
+++ b/docs/assets/hero.svg
@@ -0,0 +1,60 @@
+
+
\ No newline at end of file
diff --git a/migrations/2.surrealql b/migrations/2.surrealql
new file mode 100644
index 0000000..66c1017
--- /dev/null
+++ b/migrations/2.surrealql
@@ -0,0 +1 @@
+DEFINE FIELD IF NOT EXISTS note_type ON TABLE note TYPE option;
diff --git a/migrations/2_down.surrealql b/migrations/2_down.surrealql
new file mode 100644
index 0000000..d0213a0
--- /dev/null
+++ b/migrations/2_down.surrealql
@@ -0,0 +1 @@
+REMOVE FIELD IF EXISTS note_type ON TABLE note;
diff --git a/migrations/3.surrealql b/migrations/3.surrealql
new file mode 100644
index 0000000..f2a067f
--- /dev/null
+++ b/migrations/3.surrealql
@@ -0,0 +1,145 @@
+
+DEFINE TABLE IF NOT EXISTS chat_session SCHEMALESS;
+
+DEFINE TABLE IF NOT EXISTS refers_to
+TYPE RELATION
+FROM chat_session TO notebook;
+
+REMOVE FUNCTION IF EXISTS fn::vector_search;
+
+DEFINE FUNCTION IF NOT EXISTS fn::vector_search($query: array, $match_count: int, $sources: bool, $show_notes: bool, $min_similarity: float) {
+ let $source_embedding_search =
+ IF $sources {(
+ SELECT
+ id,
+ source.title as title,
+ content,
+ source.id as parent_id,
+ vector::similarity::cosine(embedding, $query) as similarity
+ FROM source_embedding
+ WHERE vector::similarity::cosine(embedding, $query) >= $min_similarity
+ ORDER BY similarity DESC
+ LIMIT $match_count
+ )}
+ ELSE { [] };
+
+ let $source_insight_search =
+ IF $sources {(
+ SELECT
+ id,
+ insight_type + ' - ' + source.title as title,
+ content,
+ source.id as parent_id,
+ vector::similarity::cosine(embedding, $query) as similarity
+ FROM source_insight
+ WHERE vector::similarity::cosine(embedding, $query) >= $min_similarity
+ ORDER BY similarity DESC
+ LIMIT $match_count
+ )}
+ ELSE { [] };
+
+
+ let $note_content_search =
+ IF $show_notes {(
+ SELECT
+ id,
+ title,
+ content,
+ id as parent_id,
+ vector::similarity::cosine(embedding, $query) as similarity
+ FROM note
+ WHERE vector::similarity::cosine(embedding, $query) >= $min_similarity
+ ORDER BY similarity DESC
+ LIMIT $match_count
+ )}
+ ELSE { [] };
+
+
+ let $all_results = array::union(
+ array::union($source_embedding_search, $source_insight_search),
+ $note_content_search
+ );
+
+
+ RETURN (
+ SELECT
+ id, title, content, parent_id,
+ math::max(similarity) as similarity
+ FROM $all_results
+ GROUP BY id
+ ORDER BY similarity DESC
+ LIMIT $match_count
+ );
+};
+
+
+REMOVE FUNCTION IF EXISTS fn::text_search;
+
+
+DEFINE FUNCTION IF NOT EXISTS fn::text_search($query_text: string, $match_count: int, $sources:bool, $show_notes:bool) {
+
+ let $source_title_search =
+ IF $sources {(
+ SELECT id, title,
+ search::highlight('`', '`', 1) as content,
+ id as parent_id,
+ math::max(search::score(1)) AS relevance
+ FROM source
+ WHERE title @1@ $query_text
+ GROUP BY id)}
+ ELSE { [] };
+
+ let $source_embedding_search =
+ IF $sources {(
+ SELECT id as id, source.title as title, search::highlight('`', '`', 1) as content, source.id as parent_id, math::max(search::score(1)) AS relevance
+ FROM source_embedding
+ WHERE content @1@ $query_text
+ GROUP BY id)}
+ ELSE { [] };
+
+ let $source_full_search =
+ IF $sources {(
+ SELECT source.id as id, source.title as title, search::highlight('`', '`', 1) as content, source.id as parent_id, math::max(search::score(1)) AS relevance
+ FROM source
+ WHERE full_text @1@ $query_text
+ GROUP BY id)}
+ ELSE { [] };
+
+ let $source_insight_search =
+ IF $sources {(
+ SELECT id, insight_type + " - " + source.title as title, search::highlight('`', '`', 1) as content, source.id as parent_id, math::max(search::score(1)) AS relevance
+ FROM source_insight
+ WHERE content @1@ $query_text
+ GROUP BY id)}
+ ELSE { [] };
+
+ let $note_title_search =
+ IF $show_notes {(
+ SELECT id, title, search::highlight('`', '`', 1) as content, id as parent_id, math::max(search::score(1)) AS relevance
+ FROM note
+ WHERE title @1@ $query_text
+ GROUP BY id)}
+ ELSE { [] };
+
+ let $note_content_search =
+ IF $show_notes {(
+ SELECT id, title, search::highlight('`', '`', 1) as content, id as parent_id, math::max(search::score(1)) AS relevance
+ FROM note
+ WHERE content @1@ $query_text
+ GROUP BY id)}
+ ELSE { [] };
+
+ let $source_chunk_results = array::union($source_embedding_search, $source_full_search);
+
+ let $source_asset_results = array::union($source_title_search, $source_insight_search);
+
+ let $source_results = array::union($source_chunk_results, $source_asset_results );
+ let $note_results = array::union($note_title_search, $note_content_search );
+ let $final_results = array::union($source_results, $note_results );
+
+ RETURN (SELECT id, title, content, parent_id, math::max(relevance) as relevance from $final_results
+ where id is not None
+group by id, title, content, parent_id ORDER BY relevance DESC LIMIT $match_count);
+
+
+};
diff --git a/migrations/3_down.surrealql b/migrations/3_down.surrealql
new file mode 100644
index 0000000..b8438e0
--- /dev/null
+++ b/migrations/3_down.surrealql
@@ -0,0 +1,110 @@
+REMOVE TABLE IF EXISTS chat_session;
+
+REMOVE TABLE IF EXISTS refers_to;
+
+
+REMOVE FUNCTION fn::vector_search;
+
+
+DEFINE FUNCTION IF NOT EXISTS fn::vector_search($query: array, $match_count: int, $sources:bool, $show_notes:bool) {
+
+ let $source_embedding_search =
+ IF $sources {(
+ SELECT source as item_id, content, vector::similarity::cosine(embedding, $query) as similarity
+ FROM source_embedding LIMIT $match_count)}
+ ELSE { [] };
+
+
+ let $source_insight_search =
+ IF $sources {(
+ SELECT source as item_id, content, vector::similarity::cosine(embedding, $query) as similarity
+ FROM source_insight LIMIT $match_count)}
+ ELSE { [] };
+
+
+ let $note_content_search =
+ IF $show_notes {(
+ SELECT id as item_id, content, vector::similarity::cosine(embedding, $query) as similarity
+ FROM note LIMIT $match_count)}
+
+ ELSE { [] };
+
+ let $source_chunk_results = array::union($source_embedding_search, $source_insight_search);
+
+ let $source_results = array::union($source_chunk_results, $source_insight_search);
+
+ let $note_results = $note_content_search;
+ let $final_results = array::union($source_results, $note_results );
+
+ RETURN (SELECT item_id, math::max(similarity) as similarity from $final_results
+ group by item_id ORDER BY similarity DESC LIMIT $match_count);
+
+
+};
+
+REMOVE FUNCTION fn::text_search;
+
+
+DEFINE FUNCTION IF NOT EXISTS fn::text_search($query_text: string, $match_count: int, $sources:bool, $show_notes:bool) {
+
+ let $source_title_search =
+ IF $sources {(
+ SELECT id as item_id, math::max(search::score(1)) AS relevance
+ FROM source
+ WHERE title @1@ $query_text
+ GROUP BY item_id)}
+ ELSE { [] };
+
+ let $source_embedding_search =
+ IF $sources {(
+ SELECT source as item_id, math::max(search::score(1)) AS relevance
+ FROM source_embedding
+ WHERE content @1@ $query_text
+ GROUP BY item_id)}
+ ELSE { [] };
+
+ let $source_full_search =
+ IF $sources {(
+ SELECT source as item_id, math::max(search::score(1)) AS relevance
+ FROM source
+ WHERE full_text @1@ $query_text
+ GROUP BY item_id)}
+ ELSE { [] };
+
+ let $source_insight_search =
+ IF $sources {(
+ SELECT source as item_id, math::max(search::score(1)) AS relevance
+ FROM source_insight
+ WHERE content @1@ $query_text
+ GROUP BY item_id)}
+ ELSE { [] };
+
+ let $note_title_search =
+ IF $show_notes {(
+ SELECT id as item_id, math::max(search::score(1)) AS relevance
+ FROM note
+ WHERE title @1@ $query_text
+ GROUP BY item_id)}
+ ELSE { [] };
+
+ let $note_content_search =
+ IF $show_notes {(
+ SELECT id as item_id, math::max(search::score(1)) AS relevance
+ FROM note
+ WHERE content @1@ $query_text
+ GROUP BY item_id)}
+ ELSE { [] };
+
+ let $source_chunk_results = array::union($source_embedding_search, $source_full_search);
+
+ let $source_asset_results = array::union($source_title_search, $source_insight_search);
+
+ let $source_results = array::union($source_chunk_results, $source_asset_results );
+ let $note_results = array::union($note_title_search, $note_content_search );
+ let $final_results = array::union($source_results, $note_results );
+
+ RETURN (SELECT item_id, math::max(relevance) as relevance from $final_results
+ group by item_id ORDER BY relevance DESC LIMIT $match_count);
+
+
+};
diff --git a/open_notebook/config.py b/open_notebook/config.py
index c07ab03..096850c 100644
--- a/open_notebook/config.py
+++ b/open_notebook/config.py
@@ -3,9 +3,6 @@ import os
import yaml
from loguru import logger
-from open_notebook.domain.models import DefaultModels
-from open_notebook.models import get_model
-
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
config_path = os.path.join(project_root, "open_notebook_config.yaml")
@@ -34,22 +31,3 @@ os.makedirs(UPLOADS_FOLDER, exist_ok=True)
# PODCASTS FOLDER
PODCASTS_FOLDER = f"{DATA_FOLDER}/podcasts"
os.makedirs(PODCASTS_FOLDER, exist_ok=True)
-
-
-def load_default_models():
- default_models = DefaultModels.load()
- embedding_model = (
- get_model(default_models.default_embedding_model, model_type="embedding")
- if default_models.default_embedding_model
- else None
- )
-
- speech_to_text_model = (
- get_model(
- default_models.default_speech_to_text_model, model_type="speech_to_text"
- )
- if default_models.default_speech_to_text_model
- else None
- )
-
- return default_models, embedding_model, speech_to_text_model
diff --git a/open_notebook/database/migrate.py b/open_notebook/database/migrate.py
index 7d99fbd..085caf4 100644
--- a/open_notebook/database/migrate.py
+++ b/open_notebook/database/migrate.py
@@ -18,8 +18,18 @@ class MigrationManager:
database=os.environ["SURREAL_DATABASE"],
encrypted=False, # Set to True if using SSL
)
- self.up_migrations = [Migration.from_file("migrations/1.surrealql")]
- self.down_migrations = [Migration.from_file("migrations/1_down.surrealql")]
+ self.up_migrations = [
+ Migration.from_file("migrations/1.surrealql"),
+ Migration.from_file("migrations/2.surrealql"),
+ Migration.from_file("migrations/3.surrealql"),
+ ]
+ self.down_migrations = [
+ Migration.from_file(
+ "migrations/1_down.surrealql",
+ ),
+ Migration.from_file("migrations/2_down.surrealql"),
+ Migration.from_file("migrations/3_down.surrealql"),
+ ]
self.runner = MigrationRunner(
up_migrations=self.up_migrations,
down_migrations=self.down_migrations,
diff --git a/open_notebook/domain/base.py b/open_notebook/domain/base.py
index 2aa648e..8514e89 100644
--- a/open_notebook/domain/base.py
+++ b/open_notebook/domain/base.py
@@ -1,5 +1,5 @@
from datetime import datetime
-from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar
+from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, cast
from loguru import logger
from pydantic import BaseModel, ValidationError, field_validator
@@ -29,15 +29,26 @@ class ObjectModel(BaseModel):
@classmethod
def get_all(cls: Type[T], order_by=None) -> List[T]:
try:
+ # If called from a specific subclass, use its table_name
+ if cls.table_name:
+ target_class = cls
+ table_name = cls.table_name
+ else:
+ # This path is taken if called directly from ObjectModel
+ raise InvalidInputError(
+ "get_all() must be called from a specific model class"
+ )
+
if order_by:
order = f" ORDER BY {order_by}"
else:
order = ""
- result = repo_query(f"SELECT * FROM {cls.table_name} {order}")
+
+ result = repo_query(f"SELECT * FROM {table_name} {order}")
objects = []
for obj in result:
try:
- objects.append(cls(**obj))
+ objects.append(target_class(**obj))
except Exception as e:
logger.critical(f"Error creating object: {str(e)}")
@@ -52,14 +63,44 @@ class ObjectModel(BaseModel):
if not id:
raise InvalidInputError("ID cannot be empty")
try:
+ # Get the table name from the ID (everything before the first colon)
+ table_name = id.split(":")[0] if ":" in id else id
+
+ # If we're calling from a specific subclass and IDs match, use that class
+ if cls.table_name and cls.table_name == table_name:
+ target_class: Type[T] = cls
+ else:
+ # Otherwise, find the appropriate subclass based on table_name
+ found_class = cls._get_class_by_table_name(table_name)
+ if not found_class:
+ raise InvalidInputError(f"No class found for table {table_name}")
+ target_class = cast(Type[T], found_class)
+
result = repo_query(f"SELECT * FROM {id}")
if result:
- return cls(**result[0])
- return None
+ return target_class(**result[0])
+ else:
+ raise NotFoundError(f"{table_name} with id {id} not found")
except Exception as e:
- logger.error(f"Error fetching {cls.table_name} with id {id}: {str(e)}")
+ logger.error(f"Error fetching object with id {id}: {str(e)}")
logger.exception(e)
- raise NotFoundError(f"{cls.table_name} with id {id} not found")
+ raise NotFoundError(f"Object with id {id} not found - {str(e)}")
+
+ @classmethod
+ def _get_class_by_table_name(cls, table_name: str) -> Optional[Type["ObjectModel"]]:
+ """Find the appropriate subclass based on table_name."""
+
+ def get_all_subclasses(c: Type["ObjectModel"]) -> List[Type["ObjectModel"]]:
+ all_subclasses: List[Type["ObjectModel"]] = []
+ for subclass in c.__subclasses__():
+ all_subclasses.append(subclass)
+ all_subclasses.extend(get_all_subclasses(subclass))
+ return all_subclasses
+
+ for subclass in get_all_subclasses(ObjectModel):
+ if hasattr(subclass, "table_name") and subclass.table_name == table_name:
+ return subclass
+ return None
def needs_embedding(self) -> bool:
return False
@@ -68,12 +109,12 @@ class ObjectModel(BaseModel):
return None
def save(self) -> None:
- from open_notebook.config import load_default_models
+ from open_notebook.domain.models import model_manager
+ from open_notebook.models import EmbeddingModel
- DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
+ EMBEDDING_MODEL: EmbeddingModel = model_manager.embedding_model
try:
- logger.debug(f"Validating {self.__class__.__name__}")
self.model_validate(self.model_dump(), strict=True)
data = self._prepare_save_data()
data["updated"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -88,7 +129,11 @@ class ObjectModel(BaseModel):
logger.debug("Creating new record")
repo_result = repo_create(self.__class__.table_name, data)
else:
- data["created"] = self.created.strftime("%Y-%m-%d %H:%M:%S")
+ data["created"] = (
+ self.created.strftime("%Y-%m-%d %H:%M:%S")
+ if isinstance(self.created, datetime)
+ else self.created
+ )
logger.debug(f"Updating record with id {self.id}")
repo_result = repo_update(self.id, data)
@@ -114,8 +159,6 @@ class ObjectModel(BaseModel):
def _prepare_save_data(self) -> Dict[str, Any]:
data = self.model_dump()
- # del data["created"]
- # del data["updated"]
return {key: value for key, value in data.items() if value is not None}
def delete(self) -> bool:
@@ -148,3 +191,24 @@ class ObjectModel(BaseModel):
if isinstance(value, str):
return datetime.fromisoformat(value.replace("Z", "+00:00"))
return value
+
+
+class RecordModel(BaseModel):
+ record_id: ClassVar[str]
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.load()
+
+ def load(self):
+ result = repo_query(f"SELECT * FROM {self.record_id};")
+ if result:
+ result = result[0]
+ for key, value in result.items():
+ if hasattr(self, key):
+ setattr(self, key, value)
+ return self
+
+ def update(self, data):
+ repo_update(self.record_id, data)
+ return self.load()
diff --git a/open_notebook/domain/models.py b/open_notebook/domain/models.py
index 5699147..9cb84ca 100644
--- a/open_notebook/domain/models.py
+++ b/open_notebook/domain/models.py
@@ -1,12 +1,15 @@
-from typing import ClassVar, Optional
+from typing import ClassVar, Dict, Optional
-from pydantic import BaseModel
-
-from open_notebook.database.repository import (
- repo_query,
- repo_update,
+from open_notebook.database.repository import repo_query
+from open_notebook.domain.base import ObjectModel, RecordModel
+from open_notebook.models import (
+ MODEL_CLASS_MAP,
+ EmbeddingModel,
+ LanguageModel,
+ ModelType,
+ SpeechToTextModel,
+ TextToSpeechModel,
)
-from open_notebook.domain.base import ObjectModel
class Model(ObjectModel):
@@ -23,7 +26,9 @@ class Model(ObjectModel):
return [Model(**model) for model in models]
-class DefaultModels(BaseModel):
+class DefaultModels(RecordModel):
+ record_id: ClassVar[str] = "open_notebook:default_models"
+
default_chat_model: Optional[str] = None
default_transformation_model: Optional[str] = None
large_context_model: Optional[str] = None
@@ -31,16 +36,138 @@ class DefaultModels(BaseModel):
default_speech_to_text_model: Optional[str] = None
# default_vision_model: Optional[str] = None
default_embedding_model: Optional[str] = None
+ default_tools_model: Optional[str] = None
- @classmethod
- def load(self):
- result = repo_query("SELECT * FROM open_notebook:default_models;")
- if result:
- result = result[0]
- dm = DefaultModels(**result)
- return dm
- return DefaultModels()
- @classmethod
- def update(self, data):
- repo_update("open_notebook:default_models", data)
+class ModelManager:
+ _instance = None
+
+ def __new__(cls):
+ if cls._instance is None:
+ cls._instance = super(ModelManager, cls).__new__(cls)
+ return cls._instance
+
+ def __init__(self):
+ if not hasattr(self, "_initialized"):
+ self._initialized = True
+ self._model_cache: Dict[str, ModelType] = {}
+ self._default_models = None
+ self.refresh_defaults()
+
+ def get_model(self, model_id: str, **kwargs) -> ModelType:
+ cache_key = f"{model_id}:{str(kwargs)}"
+
+ if cache_key in self._model_cache:
+ cached_model = self._model_cache[cache_key]
+ if not isinstance(
+ cached_model,
+ (LanguageModel, EmbeddingModel, SpeechToTextModel, TextToSpeechModel),
+ ):
+ raise TypeError(
+ f"Cached model is of unexpected type: {type(cached_model)}"
+ )
+ return cached_model
+
+ assert model_id, "Model ID cannot be empty"
+ model: Model = Model.get(model_id)
+
+ if not model:
+ raise ValueError(f"Model with ID {model_id} not found")
+
+ if not model.type or model.type not in MODEL_CLASS_MAP:
+ raise ValueError(f"Invalid model type: {model.type}")
+
+ provider_map = MODEL_CLASS_MAP[model.type]
+ if model.provider not in provider_map:
+ raise ValueError(
+ f"Provider {model.provider} not compatible with {model.type} models"
+ )
+
+ model_class = provider_map[model.provider]
+ model_instance = model_class(model_name=model.name, **kwargs)
+
+ # Special handling for language models that need langchain conversion
+ if model.type == "language":
+ model_instance = model_instance
+
+ self._model_cache[cache_key] = model_instance
+ return model_instance
+
+ def refresh_defaults(self):
+ """Refresh the default models from the database"""
+ self._default_models = DefaultModels()
+
+ @property
+ def defaults(self) -> DefaultModels:
+ """Get the default models configuration"""
+ if not self._default_models:
+ self.refresh_defaults()
+ if not self._default_models:
+ raise RuntimeError("Failed to initialize default models configuration")
+ return self._default_models
+
+ @property
+ def speech_to_text(self, **kwargs) -> SpeechToTextModel:
+ """Get the default speech-to-text model"""
+ model = self.get_default_model("speech_to_text", **kwargs)
+ if not isinstance(model, SpeechToTextModel):
+ raise TypeError(f"Expected SpeechToTextModel but got {type(model)}")
+ return model
+
+ @property
+ def text_to_speech(self, **kwargs) -> TextToSpeechModel:
+ """Get the default text-to-speech model"""
+ model = self.get_default_model("text_to_speech", **kwargs)
+ if not isinstance(model, TextToSpeechModel):
+ raise TypeError(f"Expected TextToSpeechModel but got {type(model)}")
+ return model
+
+ @property
+ def embedding_model(self, **kwargs) -> EmbeddingModel:
+ """Get the default embedding model"""
+ model = self.get_default_model("embedding", **kwargs)
+ if not isinstance(model, EmbeddingModel):
+ raise TypeError(f"Expected EmbeddingModel but got {type(model)}")
+ return model
+
+ def get_default_model(self, model_type: str, **kwargs) -> ModelType:
+ """
+ Get the default model for a specific type.
+
+ Args:
+ model_type: The type of model to retrieve (e.g., 'chat', 'embedding', etc.)
+ **kwargs: Additional arguments to pass to the model constructor
+ """
+ model_id = None
+
+ if model_type == "chat":
+ model_id = self.defaults.default_chat_model
+ elif model_type == "transformation":
+ model_id = (
+ self.defaults.default_transformation_model
+ or self.defaults.default_chat_model
+ )
+ elif model_type == "tools":
+ model_id = (
+ self.defaults.default_tools_model or self.defaults.default_chat_model
+ )
+ elif model_type == "embedding":
+ model_id = self.defaults.default_embedding_model
+ elif model_type == "text_to_speech":
+ model_id = self.defaults.default_text_to_speech_model
+ elif model_type == "speech_to_text":
+ model_id = self.defaults.default_speech_to_text_model
+ elif model_type == "large_context":
+ model_id = self.defaults.large_context_model
+
+ if not model_id:
+ raise ValueError(f"No default model configured for type: {model_type}")
+
+ return self.get_model(model_id, **kwargs)
+
+ def clear_cache(self):
+ """Clear the model cache"""
+ self._model_cache.clear()
+
+
+model_manager = ModelManager()
diff --git a/open_notebook/domain/notebook.py b/open_notebook/domain/notebook.py
index d23ee04..e473193 100644
--- a/open_notebook/domain/notebook.py
+++ b/open_notebook/domain/notebook.py
@@ -1,22 +1,17 @@
-import os
from typing import Any, ClassVar, Dict, List, Literal, Optional
-from langchain_core.runnables.config import RunnableConfig
from loguru import logger
from pydantic import BaseModel, Field, field_validator
-from open_notebook.config import load_default_models
from open_notebook.database.repository import (
- repo_create,
repo_query,
)
from open_notebook.domain.base import ObjectModel
+from open_notebook.domain.models import model_manager
from open_notebook.exceptions import (
DatabaseOperationError,
InvalidInputError,
)
-from open_notebook.graphs.multipattern import graph as pattern_graph
-from open_notebook.graphs.recursive_toc import graph as toc_graph
from open_notebook.utils import split_text, surreal_clean
@@ -71,16 +66,69 @@ class Notebook(ObjectModel):
logger.exception(e)
raise DatabaseOperationError(e)
+ @property
+ def chat_sessions(self) -> List["ChatSession"]:
+ try:
+ srcs = repo_query(f"""
+ select * from (
+ select
+ <- chat_session as chat_session
+ from refers_to
+ where out={self.id}
+ fetch chat_session
+ )
+ order by chat_session.updated desc
+ """)
+ return (
+ [ChatSession(**src["chat_session"][0]) for src in srcs] if srcs else []
+ )
+ except Exception as e:
+ logger.error(f"Error fetching notes for notebook {self.id}: {str(e)}")
+ logger.exception(e)
+ raise DatabaseOperationError(e)
+
class Asset(BaseModel):
file_path: Optional[str] = None
url: Optional[str] = None
+class SourceEmbedding(ObjectModel):
+ table_name: ClassVar[str] = "source_embedding"
+ content: str
+
+ @property
+ def source(self) -> "Source":
+ try:
+ src = repo_query(f"""
+ select source.* from {self.id} fetch source
+
+ """)
+ return Source(**src[0]["source"])
+ except Exception as e:
+ logger.error(f"Error fetching source for embedding {self.id}: {str(e)}")
+ logger.exception(e)
+ raise DatabaseOperationError(e)
+
+
class SourceInsight(ObjectModel):
+ table_name: ClassVar[str] = "source_insight"
insight_type: str
content: str
+ @property
+ def source(self) -> "Source":
+ try:
+ src = repo_query(f"""
+ select source.* from {self.id} fetch source
+
+ """)
+ return Source(**src[0]["source"])
+ except Exception as e:
+ logger.error(f"Error fetching source for insight {self.id}: {str(e)}")
+ logger.exception(e)
+ raise DatabaseOperationError(e)
+
class Source(ObjectModel):
table_name: ClassVar[str] = "source"
@@ -96,12 +144,28 @@ class Source(ObjectModel):
return dict(
id=self.id,
title=self.title,
- insights=self.insights,
+ insights=[insight.model_dump() for insight in self.insights],
full_text=self.full_text,
)
else:
return dict(id=self.id, title=self.title, insights=self.insights)
+ @property
+ def embedded_chunks(self) -> int:
+ try:
+ result = repo_query(
+ f"""
+ select count() as chunks from source_embedding where source={self.id} GROUP ALL
+ """
+ )
+ if len(result) == 0:
+ return 0
+ return result[0]["chunks"]
+ except Exception as e:
+ logger.error(f"Error fetching insights for source {self.id}: {str(e)}")
+ logger.exception(e)
+ raise DatabaseOperationError(f"Failed to count chunks for source: {str(e)}")
+
@property
def insights(self) -> List[SourceInsight]:
try:
@@ -121,34 +185,14 @@ class Source(ObjectModel):
raise InvalidInputError("Notebook ID must be provided")
return self.relate("reference", notebook_id)
- def save_chunks(self, text: str) -> None:
- if not text:
- raise InvalidInputError("Text cannot be empty")
- try:
- chunks = split_text(text, chunk=500000, overlap=1000)
- logger.debug(f"Split into {len(chunks)} chunks")
- for i, chunk in enumerate(chunks):
- logger.debug(f"Saving chunk {i}")
- data = {"source": self.id, "order": i, "content": surreal_clean(chunk)}
- repo_create(
- "source_chunk",
- data,
- )
- except Exception as e:
- logger.exception(e)
- logger.error(f"Error saving chunks for source {self.id}: {str(e)}")
- raise DatabaseOperationError(e)
-
def vectorize(self) -> None:
- DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
+ EMBEDDING_MODEL = model_manager.embedding_model
try:
if not self.full_text:
return
chunks = split_text(
self.full_text,
- chunk=int(os.environ.get("EMBEDDING_CHUNK_SIZE", 1000)),
- overlap=int(os.environ.get("EMBEDDING_CHUNK_OVERLAP", 1000)),
)
logger.debug(f"Split into {len(chunks)} chunks")
@@ -169,29 +213,29 @@ class Source(ObjectModel):
logger.exception(e)
raise DatabaseOperationError(e)
- @classmethod
- def search(cls, query: str) -> List[Dict[str, Any]]:
- if not query:
- raise InvalidInputError("Search query cannot be empty")
- try:
- result = repo_query(
- """
- SELECT * omit full_text
- FROM source
- WHERE string::lowercase(title) CONTAINS $query or title @@ $query
- OR string::lowercase(summary) CONTAINS $query or summary @@ $query
- OR string::lowercase(full_text) CONTAINS $query or full_text @@ $query
- """,
- {"query": query},
- )
- return result
- except Exception as e:
- logger.error(f"Error searching sources: {str(e)}")
- logger.exception(e)
- raise DatabaseOperationError("Failed to search sources")
+ # @classmethod
+ # def search(cls, query: str) -> List[Dict[str, Any]]:
+ # if not query:
+ # raise InvalidInputError("Search query cannot be empty")
+ # try:
+ # result = repo_query(
+ # """
+ # SELECT * omit full_text
+ # FROM source
+ # WHERE string::lowercase(title) CONTAINS $query or title @@ $query
+ # OR string::lowercase(summary) CONTAINS $query or summary @@ $query
+ # OR string::lowercase(full_text) CONTAINS $query or full_text @@ $query
+ # """,
+ # {"query": query},
+ # )
+ # return result
+ # except Exception as e:
+ # logger.error(f"Error searching sources: {str(e)}")
+ # logger.exception(e)
+ # raise DatabaseOperationError("Failed to search sources")
def add_insight(self, insight_type: str, content: str) -> Any:
- DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
+ EMBEDDING_MODEL = model_manager.embedding_model
if not insight_type or not content:
raise InvalidInputError("Insight type and content must be provided")
@@ -211,34 +255,11 @@ class Source(ObjectModel):
logger.error(f"Error adding insight to source {self.id}: {str(e)}")
raise DatabaseOperationError(e)
- # todo: move this to content processing pipeline as a major graph
- def generate_toc_and_title(self) -> "Source":
- DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
-
- try:
- config = RunnableConfig(configurable=dict(thread_id=self.id))
- result = toc_graph.invoke({"content": self.full_text}, config=config)
- self.add_insight("Table of Contents", surreal_clean(result["toc"]))
- if not self.title:
- transformations = [
- "Based on the Table of Contents below, please provide a Title for this content, with max 15 words"
- ]
- output = pattern_graph.invoke(
- dict(content_stack=[result["toc"]], transformations=transformations)
- )
- self.title = surreal_clean(output["output"])
- self.save()
- return self
- except Exception as e:
- logger.error(f"Error summarizing source {self.id}: {str(e)}")
- logger.exception(e)
- raise DatabaseOperationError(e)
-
class Note(ObjectModel):
table_name: ClassVar[str] = "note"
title: Optional[str] = None
- note_type: Optional[Literal["human", "ai"]] = "human"
+ note_type: Optional[Literal["human", "ai"]] = None
content: Optional[str] = None
@field_validator("content")
@@ -272,6 +293,16 @@ class Note(ObjectModel):
return self.content
+class ChatSession(ObjectModel):
+ table_name: ClassVar[str] = "chat_session"
+ title: Optional[str] = None
+
+ def relate_to_notebook(self, notebook_id: str) -> Any:
+ if not notebook_id:
+ raise InvalidInputError("Notebook ID must be provided")
+ return self.relate("refers_to", notebook_id)
+
+
def text_search(keyword: str, results: int, source: bool = True, note: bool = True):
if not keyword:
raise InvalidInputError("Search keyword cannot be empty")
@@ -286,21 +317,142 @@ def text_search(keyword: str, results: int, source: bool = True, note: bool = Tr
except Exception as e:
logger.error(f"Error performing text search: {str(e)}")
logger.exception(e)
- raise DatabaseOperationError("Failed to perform text search")
+ raise DatabaseOperationError(e)
def vector_search(keyword: str, results: int, source: bool = True, note: bool = True):
if not keyword:
raise InvalidInputError("Search keyword cannot be empty")
try:
+ EMBEDDING_MODEL = model_manager.embedding_model
+ embed = EMBEDDING_MODEL.embed(keyword)
results = repo_query(
"""
- SELECT * FROM fn::vector_search($keyword, $results, $source, $note);
+ SELECT * FROM fn::vector_search($embed, $results, $source, $note, 0.15);
""",
- {"keyword": keyword, "results": results, "source": source, "note": note},
+ {"embed": embed, "results": results, "source": source, "note": note},
)
return results
except Exception as e:
logger.error(f"Error performing vector search: {str(e)}")
logger.exception(e)
- raise DatabaseOperationError("Failed to perform vector search")
+ raise DatabaseOperationError(e)
+
+
+def hybrid_search(
+ keyword_search: List[str],
+ embed_search: List[str],
+ results: int = 50,
+ source: bool = True,
+ note: bool = True,
+ max_chunks_per_doc: int = 3,
+ min_results_per_query: int = 3,
+) -> Dict[str, List[Dict]]:
+ if not keyword_search and not embed_search:
+ raise InvalidInputError("At least one search term required")
+
+ # Process keyword searches
+ all_keyword_results = {} # Dictionary to store results per keyword
+ for keyword in keyword_search:
+ try:
+ search_results = text_search(keyword, results, source, note)
+ # Sort results by relevance
+ sorted_results = sorted(
+ search_results, key=lambda x: x.get("relevance", 0), reverse=True
+ )
+ # Group by parent_id and limit chunks per document
+ seen_parent_ids = {}
+ filtered_results = []
+ for result in sorted_results:
+ parent_id = result["parent_id"]
+ if parent_id not in seen_parent_ids:
+ seen_parent_ids[parent_id] = 1
+ filtered_results.append(result)
+ elif seen_parent_ids[parent_id] < max_chunks_per_doc:
+ seen_parent_ids[parent_id] += 1
+ filtered_results.append(result)
+ all_keyword_results[keyword] = filtered_results
+ except Exception as e:
+ logger.warning(f"Error in keyword search for term '{keyword}': {str(e)}")
+ continue
+
+ # Ensure minimum results from each keyword query
+ keyword_results = []
+ remaining_slots = results
+
+ # First pass: add minimum results from each query
+ for keyword, query_results in all_keyword_results.items():
+ keyword_results.extend(query_results[:min_results_per_query])
+ remaining_slots -= min(len(query_results), min_results_per_query)
+
+ # Second pass: fill remaining slots with best results
+ all_remaining = []
+ for keyword, query_results in all_keyword_results.items():
+ all_remaining.extend(query_results[min_results_per_query:])
+
+ # Sort remaining by relevance and add until we hit the limit
+ all_remaining = sorted(
+ all_remaining, key=lambda x: x.get("relevance", 0), reverse=True
+ )
+ seen_ids = {r["id"] for r in keyword_results}
+ for result in all_remaining:
+ if remaining_slots <= 0:
+ break
+ if result["id"] not in seen_ids:
+ keyword_results.append(result)
+ seen_ids.add(result["id"])
+ remaining_slots -= 1
+
+ # Process vector searches with the same approach
+ all_vector_results = {} # Dictionary to store results per embedding
+ for embed in embed_search:
+ try:
+ search_results = vector_search(embed, results, source, note)
+ # Sort results by similarity
+ sorted_results = sorted(
+ search_results, key=lambda x: x.get("similarity", 0), reverse=True
+ )
+ # Group by parent_id and limit chunks per document
+ seen_parent_ids = {}
+ filtered_results = []
+ for result in sorted_results:
+ parent_id = result["parent_id"]
+ if parent_id not in seen_parent_ids:
+ seen_parent_ids[parent_id] = 1
+ filtered_results.append(result)
+ elif seen_parent_ids[parent_id] < max_chunks_per_doc:
+ seen_parent_ids[parent_id] += 1
+ filtered_results.append(result)
+ all_vector_results[embed] = filtered_results
+ except Exception as e:
+ logger.warning(f"Error in vector search for term '{embed}': {str(e)}")
+ continue
+
+ # Ensure minimum results from each vector query
+ vector_results = []
+ remaining_slots = results
+
+ # First pass: add minimum results from each query
+ for embed, query_results in all_vector_results.items():
+ vector_results.extend(query_results[:min_results_per_query])
+ remaining_slots -= min(len(query_results), min_results_per_query)
+
+ # Second pass: fill remaining slots with best results
+ all_remaining = []
+ for embed, query_results in all_vector_results.items():
+ all_remaining.extend(query_results[min_results_per_query:])
+
+ # Sort remaining by similarity and add until we hit the limit
+ all_remaining = sorted(
+ all_remaining, key=lambda x: x.get("similarity", 0), reverse=True
+ )
+ seen_ids = {r["id"] for r in vector_results}
+ for result in all_remaining:
+ if remaining_slots <= 0:
+ break
+ if result["id"] not in seen_ids:
+ vector_results.append(result)
+ seen_ids.add(result["id"])
+ remaining_slots -= 1
+
+ return {"keyword_results": keyword_results, "vector_results": vector_results}
diff --git a/open_notebook/graphs/chat.py b/open_notebook/graphs/chat.py
index 7735cd8..5e3b4ca 100644
--- a/open_notebook/graphs/chat.py
+++ b/open_notebook/graphs/chat.py
@@ -9,11 +9,10 @@ from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict
-from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE, load_default_models
+from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE
from open_notebook.domain.notebook import Notebook
-from open_notebook.graphs.utils import run_pattern
-
-DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
+from open_notebook.graphs.utils import provision_langchain_model
+from open_notebook.prompter import Prompter
class ThreadState(TypedDict):
@@ -24,15 +23,10 @@ class ThreadState(TypedDict):
def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict:
- model_id = config.get("configurable", {}).get(
- "model_id", DEFAULT_MODELS.default_chat_model
- )
- ai_message = run_pattern(
- "chat",
- model_id,
- messages=state["messages"],
- state=state,
- )
+ system_prompt = Prompter(prompt_template="chat").render(data=state)
+ payload = [system_prompt] + state.get("messages", [])
+ model = provision_langchain_model(str(payload), config, "chat", max_tokens=2000)
+ ai_message = model.invoke(payload)
return {"messages": ai_message}
diff --git a/open_notebook/graphs/content_processing/__init__.py b/open_notebook/graphs/content_processing/__init__.py
index 2c772dc..915da23 100644
--- a/open_notebook/graphs/content_processing/__init__.py
+++ b/open_notebook/graphs/content_processing/__init__.py
@@ -48,15 +48,6 @@ def file_type(state: SourceState):
return return_dict
-# def _get_title(url):
-# """
-# Get the content of a URL
-# """
-# response = extract_url(dict(url=url))
-# if "title" in response:
-# return response["title"]
-
-
def file_type_edge(data: SourceState):
assert data.get("identified_type"), "Type not identified"
identified_type = data["identified_type"]
diff --git a/open_notebook/graphs/content_processing/audio.py b/open_notebook/graphs/content_processing/audio.py
index b7c31be..3f99277 100644
--- a/open_notebook/graphs/content_processing/audio.py
+++ b/open_notebook/graphs/content_processing/audio.py
@@ -4,9 +4,10 @@ from math import ceil
from loguru import logger
from pydub import AudioSegment
-from open_notebook.config import load_default_models
+from open_notebook.domain.models import model_manager
from open_notebook.graphs.content_processing.state import SourceState
+# todo: remove reference to model_manager
# future: parallelize the transcription process
@@ -72,7 +73,7 @@ def split_audio(input_file, segment_length_minutes=15, output_prefix=None):
def extract_audio(data: SourceState):
- DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
+ SPEECH_TO_TEXT_MODEL = model_manager.speech_to_text
input_audio_path = data.get("file_path")
audio_files = []
diff --git a/open_notebook/graphs/doc_query.py b/open_notebook/graphs/doc_query.py
deleted file mode 100644
index 7f0fb31..0000000
--- a/open_notebook/graphs/doc_query.py
+++ /dev/null
@@ -1,47 +0,0 @@
-from langchain_core.runnables import (
- RunnableConfig,
-)
-from langgraph.graph import END, START, StateGraph
-from typing_extensions import TypedDict
-
-from open_notebook.config import load_default_models
-from open_notebook.domain.notebook import Note, Notebook, Source
-from open_notebook.graphs.utils import run_pattern
-
-DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
-
-
-class DocQueryState(TypedDict):
- doc_id: str
- doc_content: str
- question: str
- answer: str
- notebook: Notebook
-
-
-def call_model(state: dict, config: RunnableConfig) -> dict:
- model_id = config.get("configurable", {}).get(
- "model_id", DEFAULT_MODELS.default_transformation_model
- )
- return {"answer": run_pattern("doc_query", model_id, state)}
-
-
-# todo: there is probably a better way to do this and avoid repetition
-def get_content(state: DocQueryState) -> dict:
- doc_id = state["doc_id"]
- if "note:" in doc_id:
- doc: Note = Note.get(id=doc_id)
- elif "source:" in doc_id:
- doc: Source = Source.get(id=doc_id)
- doc_content = doc.get_context("long") if doc else None
- return {"doc_content": doc_content}
-
-
-agent_state = StateGraph(DocQueryState)
-agent_state.add_node("get_content", get_content)
-agent_state.add_node("agent", call_model)
-agent_state.add_edge(START, "get_content")
-agent_state.add_edge("get_content", "agent")
-agent_state.add_edge("agent", END)
-
-graph = agent_state.compile()
diff --git a/open_notebook/graphs/multipattern.py b/open_notebook/graphs/multipattern.py
index 75d499a..9b95638 100644
--- a/open_notebook/graphs/multipattern.py
+++ b/open_notebook/graphs/multipattern.py
@@ -7,11 +7,8 @@ from langchain_core.runnables import (
from langgraph.graph import END, START, StateGraph
from typing_extensions import Annotated, TypedDict
-from open_notebook.config import load_default_models
from open_notebook.graphs.utils import run_pattern
-DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
-
class PatternChainState(TypedDict):
content_stack: Annotated[Sequence[str], operator.add]
@@ -20,9 +17,6 @@ class PatternChainState(TypedDict):
def call_model(state: dict, config: RunnableConfig) -> dict:
- model_id = config.get("configurable", {}).get(
- "model_id", DEFAULT_MODELS.default_transformation_model
- )
patterns = state["patterns"]
current_transformation = patterns.pop(0)
if current_transformation.startswith("patterns/"):
@@ -36,7 +30,7 @@ def call_model(state: dict, config: RunnableConfig) -> dict:
transformation_result = run_pattern(
pattern_name=current_transformation,
- model_id=model_id,
+ config=config,
state=input_args,
)
return {
diff --git a/open_notebook/graphs/pattern.py b/open_notebook/graphs/pattern.py
deleted file mode 100644
index 65e9858..0000000
--- a/open_notebook/graphs/pattern.py
+++ /dev/null
@@ -1,36 +0,0 @@
-from langchain_core.runnables import (
- RunnableConfig,
-)
-from langgraph.graph import END, START, StateGraph
-from typing_extensions import TypedDict
-
-from open_notebook.config import load_default_models
-from open_notebook.graphs.utils import run_pattern
-
-DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
-
-
-class PatternState(TypedDict):
- input_text: str
- pattern: str
- output: str
-
-
-def call_model(state: dict, config: RunnableConfig) -> dict:
- model_id = config.get("configurable", {}).get(
- "model_id", DEFAULT_MODELS.default_transformation_model
- )
- return {
- "output": run_pattern(
- pattern_name=state["pattern"],
- model_id=model_id,
- state=state,
- )
- }
-
-
-agent_state = StateGraph(PatternState)
-agent_state.add_node("agent", call_model)
-agent_state.add_edge(START, "agent")
-agent_state.add_edge("agent", END)
-graph = agent_state.compile()
diff --git a/open_notebook/graphs/rag.py b/open_notebook/graphs/rag.py
new file mode 100644
index 0000000..24dc435
--- /dev/null
+++ b/open_notebook/graphs/rag.py
@@ -0,0 +1,44 @@
+from typing import Annotated
+
+from langchain_core.runnables import (
+ RunnableConfig,
+)
+from langgraph.graph import START, StateGraph
+from langgraph.graph.message import add_messages
+from langgraph.prebuilt import ToolNode, tools_condition
+from typing_extensions import TypedDict
+
+from open_notebook.graphs.tools import repository_search
+from open_notebook.graphs.utils import provision_langchain_model
+from open_notebook.prompter import Prompter
+
+tools = [repository_search]
+tool_node = ToolNode(tools)
+
+
+class ThreadState(TypedDict):
+ messages: Annotated[list, add_messages]
+ # notebook: Optional[Notebook]
+ # context: Optional[str]
+ # context_config: Optional[dict]
+
+
+def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict:
+ system_prompt = Prompter(prompt_template="rag").render(data=state)
+ payload = [system_prompt] + state.get("messages", [])
+ model = provision_langchain_model(str(payload), config, "tools", max_tokens=2000)
+ model = model.bind_tools(tools)
+ ai_message = model.invoke(payload)
+ return {"messages": ai_message}
+
+
+agent_state = StateGraph(ThreadState)
+agent_state.add_node("agent", call_model_with_messages)
+agent_state.add_node("tools", tool_node)
+agent_state.add_edge(START, "agent")
+agent_state.add_conditional_edges(
+ "agent",
+ tools_condition,
+)
+agent_state.add_edge("tools", "agent")
+graph = agent_state.compile()
diff --git a/open_notebook/graphs/recursive_toc.py b/open_notebook/graphs/recursive_toc.py
deleted file mode 100644
index 02c5326..0000000
--- a/open_notebook/graphs/recursive_toc.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import os
-from typing import List, Literal
-
-from langchain_core.runnables import (
- RunnableConfig,
-)
-from langgraph.graph import END, START, StateGraph
-from typing_extensions import TypedDict
-
-from open_notebook.config import load_default_models
-from open_notebook.graphs.utils import run_pattern
-from open_notebook.utils import split_text
-
-DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
-
-
-class TocState(TypedDict):
- chunks: List[str]
- content: str
- toc: str
-
-
-def build_chunks(state: TocState) -> dict:
- """
- Split the input text into chunks.
- """
- return {
- "chunks": split_text(
- state["content"],
- chunk=int(os.environ.get("SUMMARY_CHUNK_SIZE", 200000)),
- overlap=int(os.environ.get("SUMMARY_CHUNK_OVERLAP", 1000)),
- )
- }
-
-
-def setup_next_chunk(state: TocState) -> dict:
- """
- Move the next item in the chunk to the processing area
- """
- state["content"] = state["chunks"].pop(0)
- return {"chunks": state["chunks"], "content": state["content"]}
-
-
-def chunk_condition(state: TocState) -> Literal["get_chunk", END]: # type: ignore
- """
- Checks whether there are more chunks to process.
- """
- if len(state["chunks"]) > 0:
- return "get_chunk"
- return END
-
-
-def call_model(state: TocState, config: RunnableConfig) -> dict:
- model_id = config.get("configurable", {}).get(
- "model_id", DEFAULT_MODELS.default_transformation_model
- )
- return {
- "toc": run_pattern(
- pattern_name="recursive_toc",
- model_id=model_id,
- state=state,
- ).content
- }
-
-
-agent_state = StateGraph(TocState)
-agent_state.add_node("setup_chunk", build_chunks)
-agent_state.add_edge(START, "setup_chunk")
-agent_state.add_conditional_edges(
- "setup_chunk",
- chunk_condition,
-)
-agent_state.add_node("get_chunk", setup_next_chunk)
-agent_state.add_node("agent", call_model)
-agent_state.add_edge("get_chunk", "agent")
-agent_state.add_conditional_edges(
- "agent",
- chunk_condition,
-)
-
-graph = agent_state.compile()
diff --git a/open_notebook/graphs/summary.py b/open_notebook/graphs/summary.py
deleted file mode 100644
index d4e0659..0000000
--- a/open_notebook/graphs/summary.py
+++ /dev/null
@@ -1,93 +0,0 @@
-import os
-from typing import List, Literal
-
-from langchain_core.output_parsers import PydanticOutputParser
-from langchain_core.runnables import (
- RunnableConfig,
-)
-from langgraph.graph import END, START, StateGraph
-from pydantic import BaseModel
-from typing_extensions import TypedDict
-
-from open_notebook.config import load_default_models
-from open_notebook.graphs.utils import run_pattern
-from open_notebook.utils import split_text
-
-DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
-
-
-class SummaryResponse(BaseModel):
- """This is schema of your response. Please provide a JSON object with the enclosed keys"""
-
- summary: str
- topics: List[str]
- title: str
-
-
-class SummaryState(TypedDict):
- chunks: List[str]
- content: str
- output: SummaryResponse
-
-
-def build_chunks(state: SummaryState) -> dict:
- """
- Split the input text into chunks.
- """
- return {
- "chunks": split_text(
- state["content"],
- chunk=int(os.environ.get("SUMMARY_CHUNK_SIZE", 200000)),
- overlap=int(os.environ.get("SUMMARY_CHUNK_OVERLAP", 1000)),
- )
- }
-
-
-def setup_next_chunk(state: SummaryState) -> dict:
- """
- Move the next item in the chunk to the processing area
- """
- state["content"] = state["chunks"].pop(0)
- return {"chunks": state["chunks"], "content": state["content"]}
-
-
-def chunk_condition(state: SummaryState) -> Literal["get_chunk", END]: # type: ignore
- """
- Checks whether there are more chunks to process.
- """
- if len(state["chunks"]) > 0:
- return "get_chunk"
- return END
-
-
-def call_model(state: dict, config: RunnableConfig) -> dict:
- model_id = config.get("configurable", {}).get(
- "model_id", DEFAULT_MODELS.default_transformation_model
- )
- parser = PydanticOutputParser(pydantic_object=SummaryResponse)
- return {
- "output": run_pattern(
- pattern_name="summarize",
- model_id=model_id,
- state=state,
- parser=parser,
- )
- }
-
-
-agent_state = StateGraph(SummaryState)
-agent_state.add_node("setup_chunk", build_chunks)
-agent_state.add_edge(START, "setup_chunk")
-agent_state.add_conditional_edges(
- "setup_chunk",
- chunk_condition,
-)
-agent_state.add_node("get_chunk", setup_next_chunk)
-agent_state.add_node("agent", call_model)
-agent_state.add_edge("get_chunk", "agent")
-agent_state.add_conditional_edges(
- "agent",
- chunk_condition,
-)
-
-graph = agent_state.compile()
diff --git a/open_notebook/graphs/tools.py b/open_notebook/graphs/tools.py
index 636e25b..620fac4 100644
--- a/open_notebook/graphs/tools.py
+++ b/open_notebook/graphs/tools.py
@@ -1,8 +1,12 @@
from datetime import datetime
+from typing import List
from langchain.tools import tool
+from open_notebook.domain.notebook import hybrid_search
+
+# todo: turn this into a system prompt variable
@tool
def get_current_timestamp() -> str:
"""
@@ -13,14 +17,11 @@ def get_current_timestamp() -> str:
@tool
-def doc_query(doc_id: str, question: str):
+def repository_search(keyword_searches: List[str], vector_searches: List[str]) -> str:
"""
- name: doc_query
- Use this tool if you need to investigate into a particular document.
- Another LLM will read the document and answer the question that you might have.
- Use this when the user question cannot be answered with the content you have in context.
+ name: repository_search
+ Makes a search in the content repository for the given query.
+ keyword_searches: List[str] - A list of search terms to search for using keyword search.
+ vector_searches: List[str] - A list of search terms to search for using vector search.
"""
- from open_notebook.graphs.doc_query import graph
-
- result = graph.invoke({"doc_id": doc_id, "question": question})
- return result["answer"]
+ return hybrid_search(keyword_searches, vector_searches, 20)
diff --git a/open_notebook/graphs/utils.py b/open_notebook/graphs/utils.py
index 39871ee..07365ea 100644
--- a/open_notebook/graphs/utils.py
+++ b/open_notebook/graphs/utils.py
@@ -1,52 +1,52 @@
-from langchain.output_parsers import OutputFixingParser
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import BaseMessage
from loguru import logger
-from open_notebook.config import load_default_models
-from open_notebook.models import get_model
+from open_notebook.domain.models import model_manager
+from open_notebook.models.llms import LanguageModel
from open_notebook.prompter import Prompter
from open_notebook.utils import token_count
+def provision_langchain_model(content, config, default_type, **kwargs) -> BaseChatModel:
+ """
+ Returns the best model to use based on the context size and on whether there is a specific model being requested in Config.
+ If context > 105_000, returns the large_context_model
+ If model_id is specified in Config, returns that model
+ Otherwise, returns the default model for the given type
+ """
+ tokens = token_count(content)
+
+ if tokens > 105_000:
+ logger.debug(
+ f"Using large context model because the content has {tokens} tokens"
+ )
+ model = model_manager.get_default_model("large_context", **kwargs)
+ elif config.get("configurable", {}).get("model_id"):
+ model = model_manager.get_model(
+ config.get("configurable", {}).get("model_id"), **kwargs
+ )
+ else:
+ model = model_manager.get_default_model(default_type, **kwargs)
+
+ assert isinstance(model, LanguageModel), f"Model is not a LanguageModel: {model}"
+ return model.to_langchain()
+
+
+# todo: turn into a graph
def run_pattern(
pattern_name: str,
- model_id=None,
+ config,
messages=[],
state: dict = {},
parser=None,
- output_fixing_model_id=None,
-) -> dict:
+) -> BaseMessage:
system_prompt = Prompter(prompt_template=pattern_name, parser=parser).render(
data=state
)
- DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
- tokens = token_count(str(system_prompt) + str(messages))
+ payload = [system_prompt] + messages
+ chain = provision_langchain_model(str(payload), config, "transformation")
- if tokens > 105_000:
- model_id = DEFAULT_MODELS.large_context_model
- logger.debug(
- f"Using large context model ({model_id}) because the content has {tokens} tokens"
- )
-
- model_id = (
- model_id
- or DEFAULT_MODELS.default_transformation_model
- or DEFAULT_MODELS.default_chat_model
- )
-
- chain = get_model(model_id, model_type="language")
- if parser:
- chain = chain | parser
-
- if output_fixing_model_id and parser:
- output_fix_model = get_model(output_fixing_model_id, model_type="language")
- chain = chain | OutputFixingParser.from_llm(
- parser=parser,
- llm=output_fix_model,
- )
-
- if len(messages) > 0:
- response = chain.invoke([system_prompt] + messages)
- else:
- response = chain.invoke(system_prompt)
+ response = chain.invoke(payload)
return response
diff --git a/open_notebook/models/__init__.py b/open_notebook/models/__init__.py
index 9473c9e..7e95d9c 100644
--- a/open_notebook/models/__init__.py
+++ b/open_notebook/models/__init__.py
@@ -1,5 +1,7 @@
-from open_notebook.domain.models import Model
+from typing import Dict, Type, Union
+
from open_notebook.models.embedding_models import (
+ EmbeddingModel,
GeminiEmbeddingModel,
OllamaEmbeddingModel,
OpenAIEmbeddingModel,
@@ -8,6 +10,7 @@ from open_notebook.models.embedding_models import (
from open_notebook.models.llms import (
AnthropicLanguageModel,
GeminiLanguageModel,
+ LanguageModel,
LiteLLMLanguageModel,
OllamaLanguageModel,
OpenAILanguageModel,
@@ -15,14 +18,22 @@ from open_notebook.models.llms import (
VertexAILanguageModel,
VertexAnthropicLanguageModel,
)
-from open_notebook.models.speech_to_text_models import OpenAISpeechToTextModel
+from open_notebook.models.speech_to_text_models import (
+ OpenAISpeechToTextModel,
+ SpeechToTextModel,
+)
from open_notebook.models.text_to_speech_models import (
ElevenLabsTextToSpeechModel,
OpenAITextToSpeechModel,
+ TextToSpeechModel,
)
-# Unified model class map with type information
-MODEL_CLASS_MAP = {
+ModelType = Union[LanguageModel, EmbeddingModel, SpeechToTextModel, TextToSpeechModel]
+
+
+ProviderMap = Dict[str, Type[ModelType]]
+
+MODEL_CLASS_MAP: Dict[str, ProviderMap] = {
"language": {
"ollama": OllamaLanguageModel,
"openrouter": OpenRouterLanguageModel,
@@ -48,36 +59,11 @@ MODEL_CLASS_MAP = {
},
}
-
-def get_model(model_id, model_type="language", **kwargs):
- """
- Get a model instance based on model_id and type.
-
- Args:
- model_id: The ID of the model to retrieve
- model_type: Type of model ('language', 'embedding', or 'speech_to_text')
- **kwargs: Additional arguments to pass to the model constructor
- """
- assert model_id, "Model ID cannot be empty"
- model: Model = Model.get(model_id)
-
- if not model:
- raise ValueError(f"Model with ID {model_id} not found")
-
- if model_type not in MODEL_CLASS_MAP:
- raise ValueError(f"Invalid model type: {model_type}")
-
- provider_map = MODEL_CLASS_MAP[model_type]
- if model.provider not in provider_map:
- raise ValueError(
- f"Provider {model.provider} not compatible with {model_type} models"
- )
-
- model_class = provider_map[model.provider]
- model_instance = model_class(model_name=model.name, **kwargs)
-
- # Special handling for language models that need langchain conversion
- if model_type == "language":
- return model_instance.to_langchain()
-
- return model_instance
+__all__ = [
+ "MODEL_CLASS_MAP",
+ "EmbeddingModel",
+ "LanguageModel",
+ "SpeechToTextModel",
+ "TextToSpeechModel",
+ "ModelType",
+]
diff --git a/open_notebook/utils.py b/open_notebook/utils.py
index 340762e..b25a79c 100644
--- a/open_notebook/utils.py
+++ b/open_notebook/utils.py
@@ -5,30 +5,11 @@ from urllib.parse import urlparse
import requests
import tomli
-from langchain_text_splitters import CharacterTextSplitter
+from langchain_text_splitters import RecursiveCharacterTextSplitter
from packaging.version import parse as parse_version
-def split_text(txt: str, chunk=1000, overlap=0, separator=" "):
- """
- Split the input text into chunks.
-
- Args:
- txt (str): The input text to be split.
- chunk (int): The size of each chunk. Default is 1000.
- overlap (int): The number of characters to overlap between chunks. Default is 0.
- separator (str): The separator to use when splitting the text. Default is " ".
-
- Returns:
- list: A list of text chunks.
- """
- text_splitter = CharacterTextSplitter(
- chunk_size=chunk, chunk_overlap=overlap, separator=separator
- )
- return text_splitter.split_text(txt)
-
-
-def token_count(input_string):
+def token_count(input_string) -> int:
"""
Count the number of tokens in the input string using the 'o200k_base' encoding.
@@ -46,7 +27,7 @@ def token_count(input_string):
return token_count
-def token_cost(token_count, cost_per_million=0.150):
+def token_cost(token_count, cost_per_million=0.150) -> float:
"""
Calculate the cost of tokens based on the token count and cost per million tokens.
@@ -60,21 +41,60 @@ def token_cost(token_count, cost_per_million=0.150):
return cost_per_million * (token_count / 1_000_000)
-def remove_non_ascii(text):
+def split_text(txt: str, chunk_size=500):
+ """
+ Split the input text into chunks.
+
+ Args:
+ txt (str): The input text to be split.
+ chunk (int): The size of each chunk. Default is 1000.
+ overlap (int): The number of characters to overlap between chunks. Default is 0.
+ separator (str): The separator to use when splitting the text. Default is " ".
+
+ Returns:
+ list: A list of text chunks.
+ """
+ overlap = int(chunk_size * 0.15)
+ text_splitter = RecursiveCharacterTextSplitter(
+ chunk_size=chunk_size,
+ chunk_overlap=overlap,
+ length_function=token_count,
+ separators=[
+ "\n\n",
+ "\n",
+ ".",
+ ",",
+ " ",
+ "\u200b", # Zero-width space
+ "\uff0c", # Fullwidth comma
+ "\u3001", # Ideographic comma
+ "\uff0e", # Fullwidth full stop
+ "\u3002", # Ideographic full stop
+ "",
+ ],
+ )
+ return text_splitter.split_text(txt)
+
+
+def remove_non_ascii(text) -> str:
return re.sub(r"[^\x00-\x7F]+", "", text)
-def remove_non_printable(text):
+def remove_non_printable(text) -> str:
+ # Replace any special Unicode whitespace characters with a regular space
+ text = re.sub(r"[\u2000-\u200B\u202F\u205F\u3000]", " ", text)
+
# Remove control characters, except newlines and tabs
text = "".join(
char for char in text if unicodedata.category(char)[0] != "C" or char in "\n\t"
)
+ # Replace non-breaking spaces with regular spaces
text = text.replace("\xa0", " ").strip()
# Keep letters (including accented ones), numbers, spaces, newlines, tabs, and basic punctuation
return re.sub(r"[^\w\s.,!?\-\n\t]", "", text, flags=re.UNICODE)
-def surreal_clean(text):
+def surreal_clean(text) -> str:
"""
Clean the input text by removing non-ASCII and non-printable characters,
and adjusting colon placement for SurrealDB compatibility.
diff --git a/pages/2_π_Notebooks.py b/pages/2_π_Notebooks.py
index b0a6c07..f1e3354 100644
--- a/pages/2_π_Notebooks.py
+++ b/pages/2_π_Notebooks.py
@@ -1,25 +1,23 @@
import streamlit as st
from humanize import naturaltime
-from open_notebook.config import load_default_models
from open_notebook.domain.notebook import Notebook
-from stream_app.chat import chat_sidebar
-from stream_app.note import add_note, note_card
-from stream_app.source import add_source, source_card
-from stream_app.utils import setup_stream_state, version_sidebar
+from pages.stream_app.chat import chat_sidebar
+from pages.stream_app.note import add_note, note_card
+from pages.stream_app.source import add_source, source_card
+from pages.stream_app.utils import setup_page, setup_stream_state
-st.set_page_config(
- layout="wide", page_title="π Open Notebook", initial_sidebar_state="expanded"
-)
-
-version_sidebar()
+setup_page("π Open Notebook")
-def notebook_header(current_notebook):
+def notebook_header(current_notebook: Notebook):
+ """
+ Defines the header of the notebook page, including the ability to edit the notebook's name and description.
+ """
c1, c2, c3 = st.columns([8, 2, 2])
c1.header(current_notebook.name)
if c2.button("Back to the list", icon="π"):
- st.session_state["current_notebook"] = None
+ st.session_state["current_notebook_id"] = None
st.rerun()
if c3.button("Refresh", icon="π"):
@@ -54,26 +52,23 @@ def notebook_header(current_notebook):
st.toast("Notebook unarchived", icon="ποΈ")
if c3.button("Delete forever", type="primary", icon="β οΈ"):
current_notebook.delete()
- st.session_state["current_notebook"] = None
+ st.session_state["current_notebook_id"] = None
st.rerun()
-def notebook_page(current_notebook_id):
- current_notebook: Notebook = Notebook.get(current_notebook_id)
- if not current_notebook:
- st.error("Notebook not found")
- return
- if current_notebook_id not in st.session_state.keys():
- st.session_state[current_notebook_id] = current_notebook
+def notebook_page(current_notebook: Notebook):
+ # Guarantees that we have an entry for this notebook in the session state
+ if current_notebook.id not in st.session_state:
+ st.session_state[current_notebook.id] = {"notebook": current_notebook}
+
+ # sets up the active session
+ current_session = setup_stream_state(
+ current_notebook=current_notebook,
+ )
- session_id = st.session_state["active_session"]
- st.session_state[session_id]["notebook"] = current_notebook
sources = current_notebook.sources
notes = current_notebook.notes
- # Load the default models dynamically
- DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
-
notebook_header(current_notebook)
work_tab, chat_tab = st.columns([4, 2])
@@ -82,18 +77,18 @@ def notebook_page(current_notebook_id):
with sources_tab:
with st.container(border=True):
if st.button("Add Source", icon="β"):
- add_source(session_id)
+ add_source(current_notebook.id)
for source in sources:
- source_card(session_id=session_id, source=source)
+ source_card(source=source, notebook_id=current_notebook.id)
with notes_tab:
with st.container(border=True):
if st.button("Write a Note", icon="π"):
- add_note(session_id)
+ add_note(current_notebook.id)
for note in notes:
- note_card(session_id=session_id, note=note)
+ note_card(note=note, notebook_id=current_notebook.id)
with chat_tab:
- chat_sidebar(session_id=session_id)
+ chat_sidebar(current_notebook=current_notebook, current_session=current_session)
def notebook_list_item(notebook):
@@ -104,40 +99,50 @@ def notebook_list_item(notebook):
)
st.write(notebook.description)
if st.button("Open", key=f"open_notebook_{notebook.id}"):
- setup_stream_state(notebook.id)
- st.session_state["current_notebook"] = notebook.id
+ st.session_state["current_notebook_id"] = notebook.id
st.rerun()
-if "current_notebook" not in st.session_state:
- st.session_state["current_notebook"] = None
+if "current_notebook_id" not in st.session_state:
+ st.session_state["current_notebook_id"] = None
-if st.session_state["current_notebook"]:
- notebook_page(st.session_state["current_notebook"])
+# todo: get the notebook, check if it exists and if it's archived
+if st.session_state["current_notebook_id"]:
+ current_notebook: Notebook = Notebook.get(st.session_state["current_notebook_id"])
+ if not current_notebook:
+ st.error("Notebook not found")
+ st.stop()
+ notebook_page(current_notebook)
st.stop()
st.title("π My Notebooks")
-st.caption("Here are all your notebooks")
+st.caption(
+ "Notebooks are a great way to organize your thoughts, ideas, and sources. You can create notebooks for different research topics and projects, to create new articles, etc. "
+)
+
+with st.expander("β **New Notebook**"):
+ new_notebook_title = st.text_input("New Notebook Name")
+ new_notebook_description = st.text_area(
+ "Description",
+ placeholder="Explain the purpose of this notebook. The more details the better.",
+ )
+ if st.button("Create a new Notebook", icon="β"):
+ notebook = Notebook(
+ name=new_notebook_title, description=new_notebook_description
+ )
+ notebook.save()
+ st.toast("Notebook created successfully", icon="π")
notebooks = Notebook.get_all(order_by="updated desc")
+archived_notebooks = [nb for nb in notebooks if nb.archived]
for notebook in notebooks:
if notebook.archived:
continue
notebook_list_item(notebook)
-with st.expander("β **New Notebook**"):
- new_notebook_title = st.text_input("New Notebook Name")
- new_notebook_description = st.text_area("Description")
- if st.button("Create a new Notebook", icon="β"):
- notebook = Notebook(
- name=new_notebook_title, description=new_notebook_description
- )
- notebook.save()
- st.rerun()
-
-archived_notebooks = [nb for nb in notebooks if nb.archived]
if len(archived_notebooks) > 0:
with st.expander(f"**ποΈ {len(archived_notebooks)} archived Notebooks**"):
+ st.write("βΉ Archived Notebooks can still be accessed and used in search.")
for notebook in archived_notebooks:
notebook_list_item(notebook)
diff --git a/pages/3_π_Ask_and_Search.py b/pages/3_π_Ask_and_Search.py
new file mode 100644
index 0000000..c085867
--- /dev/null
+++ b/pages/3_π_Ask_and_Search.py
@@ -0,0 +1,68 @@
+import streamlit as st
+
+from open_notebook.domain.models import Model
+from open_notebook.domain.notebook import text_search, vector_search
+from open_notebook.graphs.rag import graph as rag_graph
+from pages.stream_app.utils import convert_source_references, setup_page
+
+setup_page("π Search")
+
+ask_tab, search_tab = st.tabs(["Ask Your Knowledge Base (beta)", "Search"])
+
+if "search_results" not in st.session_state:
+ st.session_state["search_results"] = []
+
+
+def results_card(item):
+ score = item.get("relevance", item.get("similarity", item.get("score", 0)))
+ with st.expander(f"[{score:.2f}] **{item['title']}**"):
+ st.markdown(f"**{item['content']}**")
+ st.write(item["id"])
+ st.write(item["parent_id"])
+
+
+with ask_tab:
+ st.subheader("Ask Your Knowledge Base (beta)")
+ st.caption(
+ "The LLM will answer your query based on the documents in your knowledge base. "
+ )
+ st.warning(
+ "This functionality requires the use of Tools and, at this moment, works well with Open AI and Anthropic models only."
+ )
+ question = st.text_input("Question", "")
+ models = Model.get_models_by_type("language")
+ model: Model = st.selectbox("Model", models, format_func=lambda x: x.name)
+ if st.button("Ask"):
+ st.write(f"Searching for {question}")
+ messages = [question]
+ rag_results = rag_graph.invoke(
+ dict(
+ messages=messages,
+ ),
+ config=dict(configurable=dict(model_id=model.id)),
+ )
+ st.markdown(convert_source_references(rag_results["messages"][-1].content))
+ with st.expander("Details (for debugging)"):
+ st.json(rag_results)
+
+with search_tab:
+ with st.container(border=True):
+ st.subheader("π Search")
+ st.caption("Search your knowledge base for specific keywords or concepts")
+ search_term = st.text_input("Search", "")
+ search_type = st.radio("Search Type", ["Text Search", "Vector Search"])
+ search_sources = st.checkbox("Search Sources", value=True)
+ search_notes = st.checkbox("Search Notes", value=True)
+ if st.button("Search"):
+ if search_type == "Text Search":
+ st.write(f"Searching for {search_term}")
+ st.session_state["search_results"] = text_search(
+ search_term, 100, search_sources, search_notes
+ )
+ elif search_type == "Vector Search":
+ st.write(f"Searching for {search_term}")
+ st.session_state["search_results"] = vector_search(
+ search_term, 100, search_sources, search_notes
+ )
+ for item in st.session_state["search_results"]:
+ results_card(item)
diff --git a/pages/3_π_Search.py b/pages/3_π_Search.py
deleted file mode 100644
index cdfa4fc..0000000
--- a/pages/3_π_Search.py
+++ /dev/null
@@ -1,69 +0,0 @@
-import streamlit as st
-
-from open_notebook.config import load_default_models
-from open_notebook.domain.notebook import text_search, vector_search
-from stream_app.note import note_list_item
-from stream_app.source import source_list_item
-from stream_app.utils import version_sidebar
-
-st.set_page_config(
- layout="wide", page_title="π Search", initial_sidebar_state="expanded"
-)
-version_sidebar()
-
-DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
-
-# search_tab, ask_tab = st.tabs(["Search", "Ask"])
-# notebooks = Notebook.get_all()
-
-if "search_results" not in st.session_state:
- st.session_state["search_results"] = []
-
-# with search_tab:
-with st.container(border=True):
- st.subheader("π Search")
- st.caption("Search your knowledge base for specific keywords or concepts")
- search_term = st.text_input("Search", "")
- search_type = st.radio("Search Type", ["Text Search", "Vector Search"])
- search_sources = st.checkbox("Search Sources", value=True)
- search_notes = st.checkbox("Search Notes", value=True)
- if st.button("Search"):
- if search_type == "Text Search":
- st.write(f"Searching for {search_term}")
- st.session_state["search_results"] = text_search(
- search_term, 100, search_sources, search_notes
- )
- elif search_type == "Vector Search":
- st.write(f"Searching for {search_term}")
- embed_query = EMBEDDING_MODEL.embed(search_term)
- st.session_state["search_results"] = vector_search(
- embed_query, 100, search_sources, search_notes
- )
- for item in st.session_state["search_results"]:
- score = item.get("relevance", item.get("similarity", 0))
- if item.get("item_id"):
- if "source:" in item["item_id"]:
- source_list_item(item["item_id"], score)
- elif "note:" in item["item_id"]:
- note_list_item(item["item_id"], score)
-
-# coming soon
-# with ask_tab:
-# with st.form(key="ask_form"):
-# st.subheader("Ask Your Knowledge Base")
-# st.caption("Let the LLM formulate an answer based on your query")
-# question = st.text_input("Your question", "")
-
-# notebooks = st.multiselect(
-# "Notebooks",
-# notebooks,
-# notebooks,
-# format_func=lambda x: x.name,
-# )
-# search_sources = st.multiselect(
-# "Use Sources",
-# ["Sources", "Notes"],
-# ["Sources", "Notes"],
-# )
-# if st.form_submit_button("Search"):
-# st.write(f"Searching for {search_term}")
diff --git a/pages/5_ποΈ_Podcasts.py b/pages/5_ποΈ_Podcasts.py
index db1273c..d15559e 100644
--- a/pages/5_ποΈ_Podcasts.py
+++ b/pages/5_ποΈ_Podcasts.py
@@ -12,13 +12,9 @@ from open_notebook.plugins.podcasts import (
engagement_techniques,
participant_roles,
)
-from stream_app.utils import version_sidebar
+from pages.stream_app.utils import setup_page
-st.set_page_config(
- layout="wide", page_title="ποΈ Podcasts", initial_sidebar_state="expanded"
-)
-
-version_sidebar()
+setup_page("ποΈ Podcasts")
text_to_speech_models = Model.get_models_by_type("text_to_speech")
diff --git a/pages/7_βοΈ_Settings.py b/pages/7_βοΈ_Settings.py
index 209814c..aad2132 100644
--- a/pages/7_βοΈ_Settings.py
+++ b/pages/7_βοΈ_Settings.py
@@ -2,14 +2,11 @@ import os
import streamlit as st
-from open_notebook.domain.models import DefaultModels, Model
+from open_notebook.domain.models import DefaultModels, Model, model_manager
from open_notebook.models import MODEL_CLASS_MAP
-from stream_app.utils import version_sidebar
+from pages.stream_app.utils import setup_page
-st.set_page_config(
- layout="wide", page_title="βοΈ Settings", initial_sidebar_state="expanded"
-)
-version_sidebar()
+setup_page("βοΈ Settings")
st.title("βοΈ Settings")
@@ -121,7 +118,7 @@ def get_selected_index(models, model_id, default=0):
with model_defaults_tab:
- default_models = DefaultModels.load().model_dump()
+ default_models = DefaultModels().model_dump()
all_models = Model.get_all()
text_generation_models = [model for model in all_models if model.type == "language"]
@@ -157,7 +154,16 @@ with model_defaults_tab:
text_generation_models, default_models.get("default_transformation_model")
),
)
- st.caption("You can override this model on individual transformations")
+ st.divider()
+ defs["default_tools_model"] = st.selectbox(
+ "Default Tools Model",
+ text_generation_models,
+ format_func=lambda x: x.name,
+ help="This model will be used for calling tools. Currently, it's best to use Open AI and Anthropic for this.",
+ index=get_selected_index(
+ text_generation_models, default_models.get("default_tools_model")
+ ),
+ )
st.divider()
defs["large_context_model"] = st.selectbox(
"Large Context Model",
@@ -219,4 +225,5 @@ with model_defaults_tab:
for k, v in defs.items():
if v:
defs[k] = v.id
- DefaultModels.update(defs)
+ DefaultModels().update(defs)
+ model_manager.refresh_defaults()
diff --git a/pages/8_π_Playground.py b/pages/8_π_Playground.py
index 53de8f9..5bcac7a 100644
--- a/pages/8_π_Playground.py
+++ b/pages/8_π_Playground.py
@@ -3,12 +3,9 @@ import yaml
from open_notebook.domain.models import Model
from open_notebook.graphs.multipattern import graph as pattern_graph
-from stream_app.utils import version_sidebar
+from pages.stream_app.utils import setup_page
-st.set_page_config(
- layout="wide", page_title="π Playground", initial_sidebar_state="expanded"
-)
-version_sidebar()
+setup_page("π Playground")
st.title("π Playground")
with open("transformations.yaml", "r") as file:
diff --git a/pages/components/__init__.py b/pages/components/__init__.py
new file mode 100644
index 0000000..d6561b5
--- /dev/null
+++ b/pages/components/__init__.py
@@ -0,0 +1,11 @@
+from pages.components.note_panel import note_panel
+from pages.components.source_embedding_panel import source_embedding_panel
+from pages.components.source_insight import source_insight_panel
+from pages.components.source_panel import source_panel
+
+__all__ = [
+ "note_panel",
+ "source_embedding_panel",
+ "source_insight_panel",
+ "source_panel",
+]
diff --git a/pages/components/note_panel.py b/pages/components/note_panel.py
new file mode 100644
index 0000000..d1cba31
--- /dev/null
+++ b/pages/components/note_panel.py
@@ -0,0 +1,30 @@
+import streamlit as st
+from loguru import logger
+from streamlit_monaco import st_monaco # type: ignore
+
+from open_notebook.domain.notebook import Note
+
+
+def note_panel(note_id, notebook_id=None):
+ note: Note = Note.get(note_id)
+ if not note:
+ raise ValueError(f"Note not fonud {note_id}")
+ t_preview, t_edit = st.tabs(["Preview", "Edit"])
+ with t_preview:
+ st.subheader(note.title)
+ st.markdown(note.content)
+ with t_edit:
+ note.title = st.text_input("Title", value=note.title)
+ note.content = st_monaco(
+ value=note.content, height="600px", language="markdown"
+ )
+ if st.button("Save", key=f"pn_edit_note_{note.id or 'new'}"):
+ logger.debug("Editing note")
+ note.save()
+ if not note.id and notebook_id:
+ note.add_to_notebook(notebook_id)
+ st.rerun()
+ if st.button("Delete", type="primary", key=f"delete_note_{note.id or 'new'}"):
+ logger.debug("Deleting note")
+ note.delete()
+ st.rerun()
diff --git a/pages/components/source_embedding_panel.py b/pages/components/source_embedding_panel.py
new file mode 100644
index 0000000..4ef4a29
--- /dev/null
+++ b/pages/components/source_embedding_panel.py
@@ -0,0 +1,17 @@
+import streamlit as st
+
+from open_notebook.domain.notebook import SourceEmbedding
+
+
+def source_embedding_panel(source_embedding_id):
+ si: SourceEmbedding = SourceEmbedding.get(source_embedding_id)
+ if not si:
+ raise ValueError(f"Embedding not found {source_embedding_id}")
+ with st.container(border=True):
+ url = f"Navigator?object_id={si.source.id}"
+ st.markdown("**Original Source**")
+ st.markdown(f"{si.source.title} [link](%s)" % url)
+ st.markdown(si.content)
+ if st.button("Delete", type="primary", key=f"delete_embedding_{si.id or 'new'}"):
+ si.delete()
+ st.rerun()
diff --git a/pages/components/source_insight.py b/pages/components/source_insight.py
new file mode 100644
index 0000000..ad38793
--- /dev/null
+++ b/pages/components/source_insight.py
@@ -0,0 +1,18 @@
+import streamlit as st
+
+from open_notebook.domain.notebook import SourceInsight
+
+
+def source_insight_panel(source, notebook_id=None):
+ si: SourceInsight = SourceInsight.get(source)
+ if not si:
+ raise ValueError(f"insight not found {source}")
+ st.subheader(si.insight_type)
+ with st.container(border=True):
+ url = f"Navigator?object_id={si.source.id}"
+ st.markdown("**Original Source**")
+ st.markdown(f"{si.source.title} [link](%s)" % url)
+ st.markdown(si.content)
+ if st.button("Delete", type="primary", key=f"delete_insight_{si.id or 'new'}"):
+ si.delete()
+ st.rerun()
diff --git a/pages/components/source_panel.py b/pages/components/source_panel.py
new file mode 100644
index 0000000..8fa130d
--- /dev/null
+++ b/pages/components/source_panel.py
@@ -0,0 +1,84 @@
+import streamlit as st
+import streamlit_scrollable_textbox as stx # type: ignore
+import yaml
+from humanize import naturaltime
+
+from open_notebook.domain.notebook import Source
+from open_notebook.utils import surreal_clean
+from pages.stream_app.utils import run_patterns
+
+
+def source_panel(source_id: str, modal=False):
+ source: Source = Source.get(source_id)
+ if not source:
+ raise ValueError(f"Source not found: {source_id}")
+
+ current_title = source.title if source.title else "No Title"
+ source.title = st.text_input("Title", value=current_title)
+ if source.title != current_title:
+ st.toast("Saved new Title")
+ source.save()
+
+ process_tab, source_tab = st.tabs(["Process", "Source"])
+ with process_tab:
+ c1, c2 = st.columns([3, 1])
+ with c1:
+ title = st.empty()
+ if source.title:
+ title.subheader(source.title)
+ if source.asset and source.asset.url:
+ from_src = f"from URL: {source.asset.url}"
+ elif source.asset and source.asset.file_path:
+ from_src = f"from file: {source.asset.file_path}"
+ else:
+ from_src = "from text"
+ st.caption(f"Created {naturaltime(source.created)}, {from_src}")
+ for insight in source.insights:
+ with st.expander(f"**{insight.insight_type}**"):
+ st.markdown(insight.content)
+ if st.button(
+ "Delete", type="primary", key=f"delete_insight_{insight.id}"
+ ):
+ insight.delete()
+ st.rerun(scope="fragment" if modal else "app")
+
+ with c2:
+ with open("transformations.yaml", "r") as file:
+ transformations = yaml.safe_load(file)
+ for transformation in transformations["source_insights"]:
+ if st.button(
+ transformation["name"], help=transformation["description"]
+ ):
+ result = run_patterns(
+ source.full_text, transformation["patterns"]
+ )
+ source.add_insight(
+ transformation["insight_type"], surreal_clean(result)
+ )
+ st.rerun(scope="fragment" if modal else "app")
+
+ if st.button(
+ "Embed vectors",
+ icon="π¦Ύ",
+ disabled=source.embedded_chunks > 0,
+ help="This will generate your embedding vectors on the database for powerful search capabilities",
+ ):
+ source.vectorize()
+ st.success("Embedding complete")
+
+ chk_delete = st.checkbox(
+ "ποΈ Delete source", key=f"delete_source_{source.id}", value=False
+ )
+ if chk_delete:
+ st.warning(
+ "Source will be deleted with all its insights and embeddings"
+ )
+ if st.button(
+ "Delete", type="primary", key=f"bt_delete_source_{source.id}"
+ ):
+ source.delete()
+ st.rerun()
+
+ with source_tab:
+ st.subheader("Content")
+ stx.scrollableTextbox(source.full_text, height=300)
diff --git a/stream_app/__init__.py b/pages/stream_app/__init__.py
similarity index 100%
rename from stream_app/__init__.py
rename to pages/stream_app/__init__.py
diff --git a/pages/stream_app/chat.py b/pages/stream_app/chat.py
new file mode 100644
index 0000000..71c0437
--- /dev/null
+++ b/pages/stream_app/chat.py
@@ -0,0 +1,172 @@
+from typing import Union
+
+import humanize
+import streamlit as st
+from langchain_core.runnables import RunnableConfig
+
+from open_notebook.domain.base import ObjectModel
+from open_notebook.domain.notebook import ChatSession, Note, Notebook, Source
+from open_notebook.graphs.chat import graph as chat_graph
+from open_notebook.plugins.podcasts import PodcastConfig
+from open_notebook.utils import token_count
+from pages.stream_app.utils import (
+ convert_source_references,
+ create_session_for_notebook,
+)
+
+from .note import make_note_from_chat
+
+
+# todo: build a smarter, more robust context manager function
+def build_context(notebook_id):
+ st.session_state[notebook_id]["context"] = dict(note=[], source=[])
+
+ for id, status in st.session_state[notebook_id]["context_config"].items():
+ if not id:
+ continue
+
+ item_type, item_id = id.split(":")
+ if item_type not in ["note", "source"]:
+ continue
+
+ if "not in" in status:
+ continue
+
+ item: Union[Note, Source] = ObjectModel.get(id)
+
+ if not item:
+ continue
+ if "summary" in status:
+ st.session_state[notebook_id]["context"][item_type] += [
+ item.get_context(context_size="short")
+ ]
+ elif "full content" in status:
+ st.session_state[notebook_id]["context"][item_type] += [
+ item.get_context(context_size="long")
+ ]
+
+ return st.session_state[notebook_id]["context"]
+
+
+def execute_chat(txt_input, context, current_session):
+ current_state = st.session_state[current_session.id]
+ current_state["messages"] += [txt_input]
+ current_state["context"] = context
+ result = chat_graph.invoke(
+ input=current_state,
+ config=RunnableConfig(configurable={"thread_id": current_session.id}),
+ )
+ current_session.save()
+ return result
+
+
+def chat_sidebar(current_notebook: Notebook, current_session: ChatSession):
+ context = build_context(notebook_id=current_notebook.id)
+ tokens = token_count(
+ str(context) + str(st.session_state[current_session.id]["messages"])
+ )
+ chat_tab, podcast_tab = st.tabs(["Chat", "Podcast"])
+ with st.expander(f"Context ({tokens} tokens), {len(str(context))} chars"):
+ st.json(context)
+ with podcast_tab:
+ with st.container(border=True):
+ podcast_configs = PodcastConfig.get_all()
+ podcast_config_names = [pd.name for pd in podcast_configs]
+ if len(podcast_configs) == 0:
+ st.warning("No podcast configurations found")
+ else:
+ template = st.selectbox("Pick a template", podcast_config_names)
+ selected_template = next(
+ filter(lambda x: x.name == template, podcast_configs)
+ )
+ episode_name = st.text_input("Episode Name")
+ instructions = st.text_area(
+ "Instructions", value=selected_template.user_instructions
+ )
+ if len(context.get("note", [])) + len(context.get("source", [])) == 0:
+ st.warning(
+ "No notes or sources found in context. You don't want a boring podcast, right? So, add some context first."
+ )
+ else:
+ if st.button("Generate"):
+ with st.spinner("Go grab a coffee, almost there..."):
+ selected_template.generate_episode(
+ episode_name=episode_name,
+ text=context,
+ instructions=instructions,
+ )
+ st.success("Episode generated successfully")
+ st.page_link("pages/5_ποΈ_Podcasts.py", label="ποΈ Go to Podcasts")
+ with chat_tab:
+ with st.expander(
+ f"**Session:** {current_session.title} - {humanize.naturaltime(current_session.updated)}"
+ ):
+ new_session_name = st.text_input(
+ "Current Session",
+ key="new_session_name",
+ value=current_session.title,
+ )
+ c1, c2 = st.columns(2)
+ if c1.button("Rename", key="rename_session"):
+ current_session.title = new_session_name
+ current_session.save()
+ st.rerun()
+ if c2.button("Delete", key="delete_session_1"):
+ current_session.delete()
+ st.session_state[current_notebook.id]["active_session"] = None
+ st.rerun()
+ st.divider()
+ new_session_name = st.text_input(
+ "New Session Name",
+ key="new_session_name_f",
+ placeholder="Enter a name for the new session...",
+ )
+ st.caption("If no name provided, we'll use the current date.")
+ if st.button("Create New Session", key="create_new_session"):
+ new_session = create_session_for_notebook(
+ notebook_id=current_notebook.id, session_name=new_session_name
+ )
+ st.session_state[current_notebook.id]["active_session"] = new_session.id
+ st.rerun()
+ st.divider()
+ sessions = current_notebook.chat_sessions
+ if len(sessions) > 1:
+ st.markdown("**Other Sessions:**")
+ for session in sessions:
+ if session.id == current_session.id:
+ continue
+
+ st.markdown(
+ f"{session.title} - {humanize.naturaltime(session.updated)}"
+ )
+ if st.button(label="Load", key=f"load_session_{session.id}"):
+ st.session_state[current_notebook.id]["active_session"] = (
+ session.id
+ )
+ st.rerun()
+ with st.container(border=True):
+ request = st.chat_input("Enter your question")
+ # removing for now since it's not multi-model capable right now
+ if request:
+ response = execute_chat(
+ txt_input=request,
+ context=context,
+ current_session=current_session,
+ )
+ st.session_state[current_session.id]["messages"] = response["messages"]
+
+ for msg in st.session_state[current_session.id]["messages"][::-1]:
+ if msg.type not in ["human", "ai"]:
+ continue
+ if not msg.content:
+ continue
+
+ with st.chat_message(name=msg.type):
+ st.markdown(convert_source_references(msg.content))
+ if msg.type == "ai":
+ if st.button("πΎ New Note", key=f"render_save_{msg.id}"):
+ make_note_from_chat(
+ content=msg.content,
+ notebook_id=current_notebook.id,
+ )
+ st.rerun()
diff --git a/stream_app/consts.py b/pages/stream_app/consts.py
similarity index 100%
rename from stream_app/consts.py
rename to pages/stream_app/consts.py
diff --git a/stream_app/note.py b/pages/stream_app/note.py
similarity index 62%
rename from stream_app/note.py
rename to pages/stream_app/note.py
index 2cf063e..3dc96d4 100644
--- a/stream_app/note.py
+++ b/pages/stream_app/note.py
@@ -1,53 +1,32 @@
+from typing import Optional
+
import streamlit as st
from humanize import naturaltime
from loguru import logger
-from streamlit_monaco import st_monaco # type: ignore
from open_notebook.domain.notebook import Note
from open_notebook.graphs.multipattern import graph as pattern_graph
from open_notebook.utils import surreal_clean
+from pages.components import note_panel
from .consts import context_icons
@st.dialog("Write a Note", width="large")
-def add_note(session_id):
+def add_note(notebook_id):
note_title = st.text_input("Title")
note_content = st.text_area("Content")
if st.button("Save", key="add_note"):
logger.debug("Adding note")
note = Note(title=note_title, content=note_content, note_type="human")
note.save()
- note.add_to_notebook(st.session_state[session_id]["notebook"].id)
+ note.add_to_notebook(notebook_id)
st.rerun()
@st.dialog("Add a Source", width="large")
-def note_panel(session_id=None, note_id=None):
- if note_id:
- note: Note = Note.get(note_id)
- else:
- note: Note = Note()
-
- t_preview, t_edit = st.tabs(["Preview", "Edit"])
- with t_preview:
- st.subheader(note.title)
- st.markdown(note.content)
- with t_edit:
- note.title = st.text_input("Title", value=note.title)
- note.content = st_monaco(
- value=note.content, height="600px", language="markdown"
- )
- if st.button("Save", key=f"pn_edit_note_{note_id}"):
- logger.debug("Editing note")
- note.save()
- if not note.id:
- note.add_to_notebook(st.session_state[session_id]["notebook"].id)
- st.rerun()
- if st.button("Delete", type="primary", key=f"delete_note_{note_id}"):
- logger.debug("Deleting note")
- note.delete()
- st.rerun()
+def note_panel_dialog(note: Optional[Note] = None, notebook_id=None):
+ note_panel(note_id=note.id, notebook_id=notebook_id)
def make_note_from_chat(content, notebook_id=None):
@@ -70,7 +49,7 @@ def make_note_from_chat(content, notebook_id=None):
st.rerun()
-def note_card(session_id, note):
+def note_card(note, notebook_id):
if note.note_type == "human":
icon = "π€΅"
else:
@@ -88,13 +67,12 @@ def note_card(session_id, note):
st.caption(f"Updated: {naturaltime(note.updated)}")
if st.button("Expand", icon="π", key=f"edit_note_{note.id}"):
- note_panel(session_id, note.id)
+ note_panel_dialog(notebook_id=notebook_id, note=note)
- st.session_state[session_id]["context_config"][note.id] = context_state
+ st.session_state[notebook_id]["context_config"][note.id] = context_state
def note_list_item(note_id, score=None):
- logger.debug(note_id)
note: Note = Note.get(note_id)
if note.note_type == "human":
icon = "π€΅"
@@ -106,4 +84,4 @@ def note_list_item(note_id, score=None):
):
st.write(note.content)
if st.button("Edit Note", icon="π", key=f"x_edit_note_{note.id}"):
- note_panel(note_id=note.id)
+ note_panel_dialog(note=note)
diff --git a/stream_app/source.py b/pages/stream_app/source.py
similarity index 52%
rename from stream_app/source.py
rename to pages/stream_app/source.py
index 63c25f8..fdc56c5 100644
--- a/stream_app/source.py
+++ b/pages/stream_app/source.py
@@ -2,106 +2,47 @@ import os
from pathlib import Path
import streamlit as st
-import streamlit_scrollable_textbox as stx # type: ignore
-import yaml
from humanize import naturaltime
from loguru import logger
-from open_notebook.config import UPLOADS_FOLDER, load_default_models
+from open_notebook.config import UPLOADS_FOLDER
from open_notebook.domain.notebook import Asset, Source
from open_notebook.exceptions import UnsupportedTypeException
from open_notebook.graphs.content_processing import graph
-from open_notebook.graphs.multipattern import graph as transform_graph
from open_notebook.utils import surreal_clean
+from pages.components import source_panel
+from pages.stream_app.utils import run_patterns
from .consts import context_icons
-DEFAULT_MODELS, EMBEDDING_MODEL, SPEECH_TO_TEXT_MODEL = load_default_models()
-
-def run_patterns(input_text, patterns):
- output = transform_graph.invoke(dict(content_stack=[input_text], patterns=patterns))
- return output["output"]
+# moved it here to replace it with the pipeline on 0.1.0
+def generate_toc_and_title(source) -> "Source":
+ try:
+ patterns = ["patterns/default/toc"]
+ result = run_patterns(source.full_text, patterns=patterns)
+ source.add_insight("Table of Contents", surreal_clean(result))
+ if not source.title:
+ patterns = [
+ "Based on the Table of Contents below, please provide a Title for this content, with max 15 words"
+ ]
+ output = run_patterns(result, patterns=patterns)
+ source.title = surreal_clean(output)
+ source.save()
+ return source
+ except Exception as e:
+ logger.error(f"Error summarizing source {source.id}: {str(e)}")
+ logger.exception(e)
+ raise
@st.dialog("Source", width="large")
-def source_panel(source_id):
- source: Source = Source.get(source_id)
- if not source:
- st.error("Source not found")
- return
- current_title = source.title if source.title else "No Title"
- source.title = st.text_input("Title", value=current_title)
- if source.title != current_title:
- st.toast("Saved new Title")
- source.save()
-
- process_tab, source_tab = st.tabs(["Process", "Source"])
- with process_tab:
- c1, c2 = st.columns([3, 1])
- with c1:
- title = st.empty()
- if source.title:
- title.subheader(source.title)
- if source.asset.url:
- from_src = f"from URL: {source.asset.url}"
- elif source.asset.file_path:
- from_src = f"from file: {source.asset.file_path}"
- else:
- from_src = "from text"
- st.caption(f"Created {naturaltime(source.created)}, {from_src}")
- for insight in source.insights:
- with st.expander(f"**{insight.insight_type}**"):
- st.markdown(insight.content)
- if st.button(
- "Delete", type="primary", key=f"delete_insight_{insight.id}"
- ):
- insight.delete()
- st.rerun(scope="fragment")
-
- with c2:
- with open("transformations.yaml", "r") as file:
- transformations = yaml.safe_load(file)
- for transformation in transformations["source_insights"]:
- if st.button(
- transformation["name"], help=transformation["description"]
- ):
- result = run_patterns(
- source.full_text, transformation["patterns"]
- )
- source.add_insight(
- transformation["insight_type"], surreal_clean(result)
- )
- st.rerun(scope="fragment")
-
- if st.button(
- "Embed vectors",
- icon="π¦Ύ",
- help="This will generate your embedding vectors on the database for powerful search capabilities",
- ):
- source.vectorize()
- st.success("Embedding complete")
-
- chk_delete = st.checkbox(
- "ποΈ Delete source", key=f"delete_source_{source.id}", value=False
- )
- if chk_delete:
- st.warning(
- "Source will be deleted with all its insights and embeddings"
- )
- if st.button(
- "Delete", type="primary", key=f"bt_delete_source_{source.id}"
- ):
- source.delete()
- st.rerun()
-
- with source_tab:
- st.subheader("Content")
- stx.scrollableTextbox(source.full_text, height=300)
+def source_panel_dialog(source_id):
+ source_panel(source_id)
@st.dialog("Add a Source", width="large")
-def add_source(session_id):
+def add_source(notebook_id):
source_link = None
source_file = None
source_text = None
@@ -149,9 +90,9 @@ def add_source(session_id):
title=result.get("title"),
)
source.save()
- source.add_to_notebook(st.session_state[session_id]["notebook"].id)
+ source.add_to_notebook(notebook_id)
st.write("Summarizing...")
- source.generate_toc_and_title()
+ generate_toc_and_title(source)
except UnsupportedTypeException as e:
st.warning(
"This type of content is not supported yet. If you think it should be, let us know on the project Issues's page"
@@ -170,7 +111,7 @@ def add_source(session_id):
st.rerun()
-def source_card(session_id, source):
+def source_card(source, notebook_id):
# todo: more descriptive icons
icon = "π"
@@ -188,9 +129,9 @@ def source_card(session_id, source):
f"Updated: {naturaltime(source.updated)}, **{len(source.insights)}** insights"
)
if st.button("Expand", icon="π", key=source.id):
- source_panel(source.id)
+ source_panel_dialog(source.id)
- st.session_state[session_id]["context_config"][source.id] = context_state
+ st.session_state[notebook_id]["context_config"][source.id] = context_state
def source_list_item(source_id, score=None):
@@ -207,4 +148,4 @@ def source_list_item(source_id, score=None):
st.markdown(f"**{insight.insight_type}**")
st.write(insight.content)
if st.button("Edit source", icon="π", key=f"x_edit_source_{source.id}"):
- source_panel(source_id=source.id)
+ source_panel_dialog(source_id=source.id)
diff --git a/pages/stream_app/utils.py b/pages/stream_app/utils.py
new file mode 100644
index 0000000..2caae54
--- /dev/null
+++ b/pages/stream_app/utils.py
@@ -0,0 +1,203 @@
+import re
+from datetime import datetime
+from typing import List, Union
+
+import streamlit as st
+from loguru import logger
+
+from open_notebook.database.migrate import MigrationManager
+from open_notebook.domain.models import model_manager
+from open_notebook.domain.notebook import ChatSession, Notebook
+from open_notebook.graphs.chat import ThreadState, graph
+from open_notebook.graphs.multipattern import graph as transform_graph
+from open_notebook.utils import (
+ compare_versions,
+ get_installed_version,
+ get_version_from_github,
+)
+
+
+def run_patterns(input_text, patterns):
+ output = transform_graph.invoke(dict(content_stack=[input_text], patterns=patterns))
+ return output["output"]
+
+
+def version_sidebar():
+ with st.sidebar:
+ try:
+ current_version = get_installed_version("open-notebook")
+ except Exception:
+ # Fallback to reading directly from pyproject.toml
+ import tomli
+
+ with open("pyproject.toml", "rb") as f:
+ pyproject = tomli.load(f)
+ current_version = pyproject["tool"]["poetry"]["version"]
+
+ latest_version = get_version_from_github(
+ "https://www.github.com/lfnovo/open-notebook", "main"
+ )
+ st.write(f"Open Notebook: {current_version}")
+ if compare_versions(current_version, latest_version) < 0:
+ st.warning(
+ f"New version {latest_version} available. [Click here for upgrade instructions](https://github.com/lfnovo/open-notebook/blob/main/docs/SETUP.md#upgrading-open-notebook)"
+ )
+
+
+def create_session_for_notebook(notebook_id: str, session_name: str = None):
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+ title = f"Chat Session {current_time}" if not session_name else session_name
+ chat_session = ChatSession(title=title)
+ chat_session.save()
+ chat_session.relate_to_notebook(notebook_id)
+ return chat_session
+
+
+def setup_stream_state(current_notebook: Notebook) -> ChatSession:
+ """
+ Sets the value of the current session_id for langgraph thread state.
+ If there is no existing thread state for this session_id, it creates a new one.
+ Finally, it acquires the existing state for the session from Langgraph state and sets it in the streamlit session state.
+ """
+ assert (
+ current_notebook is not None and current_notebook.id
+ ), "Current Notebook not selected properly"
+
+ if "context_config" not in st.session_state[current_notebook.id]:
+ st.session_state[current_notebook.id]["context_config"] = {}
+
+ current_session_id = st.session_state[current_notebook.id].get("active_session")
+
+ # gets the chat session if provided
+ chat_session: Union[ChatSession, None] = (
+ ChatSession.get(current_session_id) if current_session_id else None
+ )
+
+ # if there is no chat session, create one or get the first one
+ if not chat_session:
+ sessions: List[ChatSession] = current_notebook.chat_sessions
+ if not sessions or len(sessions) == 0:
+ logger.debug("Creating new chat session")
+ chat_session = create_session_for_notebook(current_notebook.id)
+ else:
+ logger.debug("Getting last updated session")
+ chat_session = sessions[0]
+
+ if not chat_session or chat_session.id is None:
+ raise ValueError("Problem acquiring chat session")
+ # sets the active session for the notebook
+ st.session_state[current_notebook.id]["active_session"] = chat_session.id
+
+ # gets the existing state for the session from Langgraph state
+ existing_state = graph.get_state(
+ {"configurable": {"thread_id": chat_session.id}}
+ ).values
+ if not existing_state or len(existing_state.keys()) == 0:
+ st.session_state[chat_session.id] = ThreadState(
+ messages=[], context=None, notebook=None, context_config={}
+ )
+ else:
+ st.session_state[chat_session.id] = existing_state
+
+ st.session_state[current_notebook.id]["active_session"] = chat_session.id
+ return chat_session
+
+
+def check_migration():
+ mm = MigrationManager()
+ if mm.needs_migration:
+ st.warning("The Open Notebook database needs a migration to run properly.")
+ if st.button("Run Migration"):
+ mm.run_migration_up()
+ st.success("Migration successful")
+ st.rerun()
+ st.stop()
+
+
+def check_models():
+ default_models = model_manager.defaults
+ if (
+ not default_models.default_chat_model
+ or not default_models.default_transformation_model
+ ):
+ st.warning(
+ "You don't have default chat and transformation models selected. Please, select them on the settings page."
+ )
+ st.stop()
+ elif not default_models.default_embedding_model:
+ st.warning(
+ "You don't have a default embedding model selected. Vector search will not be possible and your assistant will be less able to answer your queries. Please, select one on the settings page."
+ )
+ st.stop()
+ elif not default_models.default_speech_to_text_model:
+ st.warning(
+ "You don't have a default speech to text model selected. Your assistant will not be able to transcribe audio. Please, select one on the settings page."
+ )
+ st.stop()
+ elif not default_models.default_text_to_speech_model:
+ st.warning(
+ "You don't have a default text to speech model selected. Your assistant will not be able to generate audio and podcasts. Please, select one on the settings page."
+ )
+ st.stop()
+ elif not default_models.large_context_model:
+ st.warning(
+ "You don't have a large context model selected. Your assistant will not be able to process large documents. Please, select one on the settings page."
+ )
+ st.stop()
+
+
+def handle_error(func):
+ """Decorator for consistent error handling"""
+
+ def wrapper(*args, **kwargs):
+ try:
+ return func(*args, **kwargs)
+ except Exception as e:
+ logger.error(f"Error in {func.__name__}: {str(e)}")
+ logger.exception(e)
+ st.error(f"An error occurred: {str(e)}")
+
+ return wrapper
+
+
+def setup_page(title: str, layout="wide", sidebar_state="expanded"):
+ """Common page setup for all pages"""
+ st.set_page_config(
+ page_title=title, layout=layout, initial_sidebar_state=sidebar_state
+ )
+ check_migration()
+ check_models()
+ version_sidebar()
+
+
+def convert_source_references(text):
+ """
+ Converts source references in brackets to markdown-style links.
+
+ Matches patterns like [source_insight:id], [note:id], [source:id], or [source_embedding:id]
+ and converts them to markdown links.
+
+ Args:
+ text (str): The input text containing source references
+
+ Returns:
+ str: Text with source references converted to markdown links
+
+ Example:
+ >>> text = "Here is a reference [source_insight:abc123]"
+ >>> convert_source_references(text)
+ 'Here is a reference [source_insight:abc123](/?object_id=source_insight:abc123)'
+ """
+
+ # Pattern matches [type:id] where type can be source_insight, note, source, or source_embedding
+ pattern = r"\[((?:source_insight|note|source|source_embedding):[\w\d]+)\]"
+
+ def replace_match(match):
+ """Helper function to create the markdown link"""
+ source_ref = match.group(1) # Gets the content inside brackets
+ return f"[[{source_ref}]](/?object_id={source_ref})"
+
+ # Replace all matches in the text
+ converted_text = re.sub(pattern, replace_match, text)
+
+ return converted_text
diff --git a/prompts/chat.jinja b/prompts/chat.jinja
index 0ba3471..dab2eda 100644
--- a/prompts/chat.jinja
+++ b/prompts/chat.jinja
@@ -5,9 +5,8 @@ You are a cognitive study assistant that helps users research and learn by engag
- Access to project information and selected documents (CONTEXT)
- Can engage in natural dialogue while maintaining academic rigor
-# FORMULATE YOUR DATA
-- Generate your answer based on the CONTEXT information
-- Ensure that your response is accurate and relevant to the user's query
+# YOUR OPERATING METHOD
+Whenever a user asks you a question, you need to identify the query context and the user intent. The user might be continuing a previous conversation or asking a new question. Looking at the CONTEXT will probably give you a hint of what the user is looking for. Once you identify the user intent, formulate your answer accordingly paying attention to the CITING INSTRUCTIONS below.
{% if notebook %}
# PROJECT INFORMATION
@@ -18,5 +17,26 @@ You are a cognitive study assistant that helps users research and learn by engag
{% if context %}
# CONTEXT
+The user has selected this context to help you with your response:
+
{{context}}
-{% endif %}
\ No newline at end of file
+{% endif %}
+
+# CITING INSTRUCTIONS
+
+If your answer is based off of any item in the context, it's very important that your response contains references to the searched documents so the user can follow-up and read more about the topic. The way you do that is by adding the id of the specific document in between brackets like this: [document_id].
+
+## EXAMPLE
+
+User: Can you tell me more about the concept of "Deep Learning"?
+
+Assistant: Deep learning is a subset of machine learning in artificial intelligence (AI) that enables networks to learn unsupervised from unstructured or unlabeled data. [note:iuiodadalknda]. It can also be categorized into three main types: supervised, unsupervised, and reinforcement learning. [insight:adadadadadadad].
+
+Please note, "note:iuiodadalknda" and "insight:adadadadadadad" are examples of document IDs with different prefixes. You should not make up document IDs or copy the IDs from this example. You should use the IDs of the documents that you have access to through the search tool.
+
+## IMPORTANT
+
+- Do not make up documents or document ids. Only use the ids of the documents that you have access through the query you made.
+- The ID is composed of the type of document and a random string, such as "source:randomstring", "note:randomstring", or "insight:randomstring". There are various types of documents, including notes, insights, and sources. **Always use the complete ID exactly as it is provided, including its type prefix. Do not add, remove, or modify any part of the ID.**
+- Do not assume or change the type prefix of any document ID. If a document ID is "note:xyz", use it exactly as "note:xyz". Do not change it to "source:xyz" or any other variation.
+- **Use document IDs exactly as they are returned from the search tool. Do not add any prefixes or modify them in any way.**
diff --git a/prompts/doc_query.jinja b/prompts/doc_query.jinja
deleted file mode 100644
index 4212bff..0000000
--- a/prompts/doc_query.jinja
+++ /dev/null
@@ -1,26 +0,0 @@
-
-# BACKGROUND
-
-Your are a cognitive assistant that helps me study and research.
-
-# OUR WORKING FRAMEWORK
-
-You have access to some information about the project I am working on
-as well as the content of a specific item I am interested about.
-
-Your goal is to respond to the question using purely the content in your CONTEXT.
-
-If the content in CONTEXT is not enough to answer the question, do not make up any information and just reply that you can't answer that.
-Kindly tell the user what sort of things you'd be able to talk about.
-
-# PROJECT INFO
-
-{{ notebook }}
-
-# CONTENT
-
-{{ doc_content }}
-
-# QUESTION
-
-{{ question}}
\ No newline at end of file
diff --git a/prompts/recursive_toc.jinja b/prompts/patterns/default/toc.jinja
similarity index 55%
rename from prompts/recursive_toc.jinja
rename to prompts/patterns/default/toc.jinja
index b92512b..23b84f0 100644
--- a/prompts/recursive_toc.jinja
+++ b/prompts/patterns/default/toc.jinja
@@ -8,17 +8,8 @@ Analyze the provided content and create a Table of Contents:
- Captures the core topics included in the text
- Gives a small description of what is covered
-# INSTRUCTIONS FOR LARGE DOCUMENTS
+# INPUT
-If you see a PREVIOUS TOC section below, it means that this request is a continuation of a previous request. Most likely to handle context length issues.
-Every time, you should replace the previous toc with the new one, and append the new content to the previous content.
+{{input_text}}
-{% if toc %}
-# PREVIOUS TOC
-
-{{toc}}
-{% endif %}
-
-# CONTENT
-
-{{content}}
+# OUTPUT
\ No newline at end of file
diff --git a/prompts/rag.jinja b/prompts/rag.jinja
new file mode 100644
index 0000000..3d8d057
--- /dev/null
+++ b/prompts/rag.jinja
@@ -0,0 +1,59 @@
+# SYSTEM ROLE
+You are a cognitive study assistant that helps users research and learn by engaging in focused discussions about documents in their workspace.
+
+You have access to a search tool that you can use in order to reply to the user query.
+
+The tool accepts 2 arrays as parameters:
+
+- keyword_searches: List[str] - A list of search terms to search for using keyword search.
+- vector_searches: List[str] - A list of search terms to search for using vector search.
+
+It's very important that your response contains references to the searched documents so the user can follow-up and read more about the topic. The way you do that is by adding the id of the specific document in between brackets like this: [document_id].
+
+# EXAMPLE
+
+User: Can you tell me more about the concept of "Deep Learning"?
+
+Assistant: Deep learning is a subset of machine learning in artificial intelligence (AI) that enables networks to learn unsupervised from unstructured or unlabeled data. [note:iuiodadalknda]. It can also be categorized into three main types: supervised, unsupervised, and reinforcement learning. [insight:adadadadadadad].
+
+Please note, "note:iuiodadalknda" and "insight:adadadadadadad" are examples of document IDs with different prefixes. You should not make up document IDs or copy the IDs from this example. You should use the IDs of the documents that you have access to through the search tool.
+
+# IMPORTANT
+
+- Do not make up documents or document ids. Only use the ids of the documents that you have access through the query you made.
+- The ID is composed of the type of document and a random string, such as "source:randomstring", "note:randomstring", or "insight:randomstring". There are various types of documents, including notes, insights, and sources. **Always use the complete ID exactly as it is provided, including its type prefix. Do not add, remove, or modify any part of the ID.**
+- Do not assume or change the type prefix of any document ID. If a document ID is "note:xyz", use it exactly as "note:xyz". Do not change it to "source:xyz" or any other variation.
+- **Use document IDs exactly as they are returned from the search tool. Do not add any prefixes or modify them in any way.**
+
+
+{#
+You are a cognitive study assistant designed to help users research and learn by engaging in focused discussions about documents in their workspace. Your primary goal is to provide informative, accurate responses to user queries while properly citing relevant documents from the available search tool.
+
+To answer this question effectively, you have access to a search tool with the following parameters:
+- keyword_searches: List[str] - A list of search terms for keyword search
+- vector_searches: List[str] - A list of search terms for vector search
+
+Follow these steps to formulate your response:
+
+1. Analyze the user's question and determine appropriate search terms.
+2. Use the search tool to find relevant information.
+3. Carefully review the search results, paying close attention to document IDs and content relevance.
+4. Compose a clear, informative response that directly addresses the user's question.
+5. Include relevant document citations using the exact document IDs provided by the search tool.
+6. Review your response for accuracy and relevance before delivering it to the user.
+
+Important guidelines:
+- Always use the complete document ID as provided by the search tool, including its type prefix (e.g., "note:", "insight:", "source:").
+- Do not make up or modify document IDs in any way.
+- Ensure that each citation is directly relevant to the information it supports.
+- Prioritize accuracy and relevance in your search strategy and response composition.
+
+Before composing your final response, wrap your thought process in tags to analyze the question, plan your search strategy, and evaluate the search results. This will help ensure that you retrieve the most relevant information and use the correct document IDs in your citations. Include the following steps:
+a. Analyze the question and identify key concepts
+b. Plan search strategy (both keyword and vector searches)
+c. Evaluate search results and note relevant document IDs
+d. Outline the main points for the response
+
+Your final response should be conversational in tone, directly addressing the user's question while seamlessly incorporating document citations. Use square brackets with the full document ID for each citation, like this: [document_id].
+
+Remember, the quality and accuracy of your response, including proper document citations, are crucial for helping the user in their research and learning process. #}
\ No newline at end of file
diff --git a/prompts/summarize.jinja b/prompts/summarize.jinja
deleted file mode 100644
index f8b65ab..0000000
--- a/prompts/summarize.jinja
+++ /dev/null
@@ -1,33 +0,0 @@
-
-# SYSTEM ROLE
-You are a content summarization assistant that creates dense, information-rich summaries optimized for machine understanding. Your summaries should capture key concepts with minimal words while maintaining complete, clear sentences.
-
-# TASK
-Analyze the provided content and create a summary that:
-- Captures the core concepts and key information
-- Uses clear, direct language
-- Maintains context from any previous summaries
-- Includes relevant topics/tags
-- Creates an appropriate title
-
-# OUTPUT SCHEMA
-{'summary': {'type': 'string'},
- 'topics': {'items': {'type': 'string'}, 'type': 'array'},
- 'title': {'type': 'string'}}
-
-# OUTPUT EXAMPLE
-{
- "title": "The title of the content",
- "topics": ["topic1", "topic2"],
- "summary": "The summary of the content"
-}
-
-# CONTENT
-
-{{content}}
-
-{% if summary %}
-# PREVIOUS SUMMARY
-
-{{summary}}
-{% endif %}
diff --git a/pyproject.toml b/pyproject.toml
index 811bec9..26f5b06 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "open-notebook"
-version = "0.0.8"
+version = "0.0.9"
description = "An open source implementation of a research assistant, inspired by Google Notebook LM"
authors = ["Luis Novo "]
license = "MIT"
diff --git a/stream_app/chat.py b/stream_app/chat.py
deleted file mode 100644
index 00d8bd5..0000000
--- a/stream_app/chat.py
+++ /dev/null
@@ -1,115 +0,0 @@
-import streamlit as st
-from langchain_core.runnables import RunnableConfig
-
-from open_notebook.domain.notebook import Note, Source
-from open_notebook.graphs.chat import graph as chat_graph
-from open_notebook.plugins.podcasts import PodcastConfig
-from open_notebook.utils import token_count
-from stream_app.note import make_note_from_chat
-
-
-# todo: build a smarter, more robust context manager function
-def build_context(session_id):
- st.session_state[session_id]["context"] = dict(note=[], source=[])
-
- for id, status in st.session_state[session_id]["context_config"].items():
- if not id:
- continue
-
- item_type, item_id = id.split(":")
- if item_type not in ["note", "source"]:
- continue
-
- if "not in" in status:
- continue
-
- if item_type == "note":
- item: Note = Note.get(id)
- elif item_type == "source":
- item: Source = Source.get(id)
- else:
- continue
-
- if not item:
- continue
- if "summary" in status:
- st.session_state[session_id]["context"][item_type] += [
- item.get_context(context_size="short")
- ]
- elif "full content" in status:
- st.session_state[session_id]["context"][item_type] += [
- item.get_context(context_size="long")
- ]
-
- return st.session_state[session_id]["context"]
-
-
-def execute_chat(txt_input, session_id):
- current_state = st.session_state[session_id]
- current_state["messages"] += [txt_input]
- result = chat_graph.invoke(
- input=current_state,
- config=RunnableConfig(configurable={"thread_id": session_id}),
- )
- return result
-
-
-def chat_sidebar(session_id):
- context = build_context(session_id=session_id)
- tokens = token_count(str(context) + str(st.session_state[session_id]["messages"]))
- chat_tab, podcast_tab = st.tabs(["Chat", "Podcast"])
- with st.expander(f"Context ({tokens} tokens), {len(str(context))} chars"):
- st.json(context)
- with podcast_tab:
- with st.container(border=True):
- podcast_configs = PodcastConfig.get_all()
- podcast_config_names = [pd.name for pd in podcast_configs]
- if len(podcast_configs) == 0:
- st.warning("No podcast configurations found")
- else:
- template = st.selectbox("Pick a template", podcast_config_names)
- selected_template = next(
- filter(lambda x: x.name == template, podcast_configs)
- )
- episode_name = st.text_input("Episode Name")
- instructions = st.text_area(
- "Instructions", value=selected_template.user_instructions
- )
- if len(context.get("note", [])) + len(context.get("source", [])) == 0:
- st.warning(
- "No notes or sources found in context. You don't want a boring podcast, right? So, add some context first."
- )
- else:
- if st.button("Generate"):
- with st.spinner("Go grab a coffee, almost there..."):
- selected_template.generate_episode(
- episode_name=episode_name,
- text=context,
- instructions=instructions,
- )
- st.success("Episode generated successfully")
- st.page_link("pages/5_ποΈ_Podcasts.py", label="ποΈ Go to Podcasts")
- with chat_tab:
- with st.container(border=True):
- request = st.chat_input("Enter your question")
- # removing for now since it's not multi-model capable right now
- st.caption(f"Total tokens: {tokens}")
- if request:
- response = execute_chat(txt_input=request, session_id=session_id)
- st.session_state[session_id]["messages"] = response["messages"]
-
- for msg in st.session_state[session_id]["messages"][::-1]:
- if msg.type not in ["human", "ai"]:
- continue
- if not msg.content:
- continue
-
- with st.chat_message(name=msg.type):
- st.write(msg.content)
- if msg.type == "ai":
- if st.button("πΎ New Note", key=f"render_save_{msg.id}"):
- make_note_from_chat(
- content=msg.content,
- notebook_id=st.session_state[session_id]["notebook"].id,
- )
- st.rerun()
diff --git a/stream_app/utils.py b/stream_app/utils.py
deleted file mode 100644
index 55d9db1..0000000
--- a/stream_app/utils.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import streamlit as st
-
-from open_notebook.graphs.chat import ThreadState, graph
-from open_notebook.utils import (
- compare_versions,
- get_installed_version,
- get_version_from_github,
-)
-
-
-def version_sidebar():
- with st.sidebar:
- try:
- current_version = get_installed_version(
- "open-notebook"
- ) # Note the hyphen instead of underscore
- except Exception:
- # Fallback to reading directly from pyproject.toml
- import tomli
-
- with open("pyproject.toml", "rb") as f:
- pyproject = tomli.load(f)
- current_version = pyproject["tool"]["poetry"]["version"]
-
- latest_version = get_version_from_github(
- "https://www.github.com/lfnovo/open-notebook", "main"
- )
- st.write(f"Open Notebook: {current_version}")
- if compare_versions(current_version, latest_version) < 0:
- st.warning(
- f"New version {latest_version} available. [Click here for upgrade instructions](https://github.com/lfnovo/open-notebook/blob/main/docs/SETUP.md#upgrading-open-notebook)"
- )
-
-
-def setup_stream_state(session_id) -> None:
- """
- Sets the value of the current session_id for langgraph thread state.
- If there is no existing thread state for this session_id, it creates a new one.
- """
- existing_state = graph.get_state({"configurable": {"thread_id": session_id}}).values
- if len(existing_state.keys()) == 0:
- st.session_state[session_id] = ThreadState(
- messages=[], context=None, notebook=None, context_config={}, response=None
- )
- else:
- st.session_state[session_id] = existing_state
- st.session_state["active_session"] = session_id
- st.session_state["active_session"] = session_id
diff --git a/transformations.yaml b/transformations.yaml
index b0d3a95..435fb3c 100644
--- a/transformations.yaml
+++ b/transformations.yaml
@@ -16,6 +16,11 @@ source_insights:
description: "Create a dense representation of the content"
patterns:
- patterns/default/makeitdense
+ - name: "Table of Contents"
+ insight_type: "Table of Contents"
+ description: "Analyzes the content and returns a ToC"
+ patterns:
+ - patterns/default/analyze_paper
- name: "Analyze Paper"
insight_type: "Paper Analysis"
description: "Analyze the paper and provide a quick summary"