diff --git a/open_notebook/domain.py b/open_notebook/domain.py index 44c1fef..4011473 100644 --- a/open_notebook/domain.py +++ b/open_notebook/domain.py @@ -11,7 +11,8 @@ from open_notebook.exceptions import ( InvalidInputError, NotFoundError, ) -from open_notebook.graphs.summary import graph as summarizer +from open_notebook.graphs.multipattern import graph as pattern_graph +from open_notebook.graphs.recursive_toc import graph as toc_graph from open_notebook.repository import ( repo_create, repo_delete, @@ -239,8 +240,7 @@ class Source(ObjectModel): def vectorize(self) -> None: try: - full_text = self.full_text - if not full_text: + if not self.full_text: return chunks = split_text( self.full_text, @@ -306,15 +306,20 @@ class Source(ObjectModel): logger.error(f"Error adding insight to source {self.id}: {str(e)}") raise DatabaseOperationError(e) - def summarize(self) -> "Source": + def generate_toc_and_title(self) -> "Source": try: config = RunnableConfig(configurable=dict(thread_id=self.id)) - result = summarizer.invoke({"content": self.full_text}, config=config)[ - "output" + result = toc_graph.invoke({"content": self.full_text}, config=config) + logger.warning(result["toc"]) + self.add_insight("Table of Contents", surreal_clean(result["toc"])) + transformations = [ + "Based on the Table of Contents below, please provide a Title for this content, with max 15 words" ] - self.add_insight("summary", surreal_clean(result.summary)) - self.title = surreal_clean(result.title) - self.topics = result.topics + output = pattern_graph.invoke( + dict(content_stack=[result["toc"]], transformations=transformations) + ) + logger.warning(output["output"]) + self.title = surreal_clean(output["output"]) self.save() return self except Exception as e: diff --git a/open_notebook/graphs/recursive_toc.py b/open_notebook/graphs/recursive_toc.py new file mode 100644 index 0000000..9cffc5d --- /dev/null +++ b/open_notebook/graphs/recursive_toc.py @@ -0,0 +1,78 @@ +import os +from typing import List, Literal + +from langchain_core.runnables import ( + RunnableConfig, +) +from langgraph.graph import END, START, StateGraph +from typing_extensions import TypedDict + +from open_notebook.graphs.utils import run_pattern +from open_notebook.utils import split_text + + +class TocState(TypedDict): + chunks: List[str] + content: str + toc: str + + +def build_chunks(state: TocState) -> dict: + """ + Split the input text into chunks. + """ + return { + "chunks": split_text( + state["content"], + chunk=int(os.environ.get("SUMMARY_CHUNK_SIZE", 200000)), + overlap=int(os.environ.get("SUMMARY_CHUNK_OVERLAP", 1000)), + ) + } + + +def setup_next_chunk(state: TocState) -> dict: + """ + Move the next item in the chunk to the processing area + """ + state["content"] = state["chunks"].pop(0) + return {"chunks": state["chunks"], "content": state["content"]} + + +def chunk_condition(state: TocState) -> Literal["get_chunk", END]: # type: ignore + """ + Checks whether there are more chunks to process. + """ + if len(state["chunks"]) > 0: + return "get_chunk" + return END + + +def call_model(state: TocState, config: RunnableConfig) -> dict: + model_name = config.get("configurable", {}).get( + "model_name", os.environ.get("SUMMARIZATION_MODEL") + ) + return { + "toc": run_pattern( + pattern_name="recursive_toc", + model_name=model_name, + state=state, + ).content + } + + +agent_state = StateGraph(TocState) +agent_state.add_node("setup_chunk", build_chunks) +agent_state.add_edge(START, "setup_chunk") +agent_state.add_conditional_edges( + "setup_chunk", + chunk_condition, +) +agent_state.add_node("get_chunk", setup_next_chunk) +agent_state.add_node("agent", call_model) +agent_state.add_edge("get_chunk", "agent") +agent_state.add_conditional_edges( + "agent", + chunk_condition, +) + +graph = agent_state.compile() diff --git a/prompts/recursive_toc.jinja b/prompts/recursive_toc.jinja new file mode 100644 index 0000000..b92512b --- /dev/null +++ b/prompts/recursive_toc.jinja @@ -0,0 +1,24 @@ + +# SYSTEM ROLE +You are a content analysis assistant that reads through documents and provides a Table of Contents (ToC) to help users identify what the document covers more easily. +Your ToC should capture all major topics and transitions in the content and should mention them in the order theh appear. + +# TASK +Analyze the provided content and create a Table of Contents: +- Captures the core topics included in the text +- Gives a small description of what is covered + +# INSTRUCTIONS FOR LARGE DOCUMENTS + +If you see a PREVIOUS TOC section below, it means that this request is a continuation of a previous request. Most likely to handle context length issues. +Every time, you should replace the previous toc with the new one, and append the new content to the previous content. + +{% if toc %} +# PREVIOUS TOC + +{{toc}} +{% endif %} + +# CONTENT + +{{content}} diff --git a/stream_app/chat.py b/stream_app/chat.py index 05759d6..0c68b41 100644 --- a/stream_app/chat.py +++ b/stream_app/chat.py @@ -3,7 +3,7 @@ from langchain_core.runnables import RunnableConfig from open_notebook.domain import Note, Source from open_notebook.graphs.chat import graph as chat_graph -from open_notebook.utils import token_cost, token_count +from open_notebook.utils import token_count # todo: build a smarter, more robust context manager function @@ -56,11 +56,11 @@ def execute_chat(txt_input, session_id): # seria bom ter um total de tokens no admin em algum lugar def chat_sidebar(session_id): context = build_context(session_id=session_id) - tokens = token_count(str(context)) - cost = token_cost(tokens) + tokens = token_count(str(context) + str(st.session_state[session_id]["messages"])) with st.container(border=True): request = st.chat_input("Enter your question") - st.caption(f"Total tokens: {tokens}, cost: ${cost:.4f}") + # removing for now since it's not multi-model capable right now + st.caption(f"Total tokens: {tokens}") if request: response = execute_chat(txt_input=request, session_id=session_id) st.session_state[session_id]["messages"] = response["messages"] diff --git a/stream_app/source.py b/stream_app/source.py index 1232268..bd30713 100644 --- a/stream_app/source.py +++ b/stream_app/source.py @@ -132,7 +132,7 @@ def add_source(session_id): source.save() source.add_to_notebook(st.session_state[session_id]["notebook"].id) st.write("Summarizing...") - source.summarize() + source.generate_toc_and_title() st.rerun()