enable multiple chat sessions

This commit is contained in:
LUIS NOVO 2024-11-04 15:08:14 -03:00
parent 3be1ecae8a
commit 0f2216207b
8 changed files with 276 additions and 133 deletions

View file

@ -1,4 +1,11 @@
REMOVE FUNCTION fn::vector_search;
DEFINE TABLE IF NOT EXISTS chat_session SCHEMALESS;
DEFINE TABLE IF NOT EXISTS refers_to
TYPE RELATION
FROM chat_session TO notebook;
REMOVE FUNCTION IF EXISTS fn::vector_search;
DEFINE FUNCTION IF NOT EXISTS fn::vector_search($query: array<float>, $match_count: int, $sources: bool, $show_notes: bool, $min_similarity: float) {
let $source_embedding_search =
@ -16,7 +23,6 @@ DEFINE FUNCTION IF NOT EXISTS fn::vector_search($query: array<float>, $match_cou
)}
ELSE { [] };
-- Busca em source_insight com threshold
let $source_insight_search =
IF $sources {(
SELECT
@ -67,10 +73,10 @@ DEFINE FUNCTION IF NOT EXISTS fn::vector_search($query: array<float>, $match_cou
};
REMOVE FUNCTION fn::text_search;
REMOVE FUNCTION IF EXISTS fn::text_search;
DEFINE FUNCTION IF NOT EXISTS fn::text_search($query_text: string, $match_count: int, $sources:bool, $show_notes:bool) {
DEFINE FUNCTION IF NOT EXISTS fn::text_search($query_text: string, $match_count: int, $sources:bool, $show_notes:bool) {
let $source_title_search =
IF $sources {(

View file

@ -1,3 +1,8 @@
REMOVE TABLE IF EXISTS chat_session;
REMOVE TABLE IF EXISTS refers_to;
REMOVE FUNCTION fn::vector_search;

View file

@ -1,11 +1,9 @@
import os
from typing import Any, ClassVar, Dict, List, Literal, Optional
from loguru import logger
from pydantic import BaseModel, Field, field_validator
from open_notebook.database.repository import (
repo_create,
repo_query,
)
from open_notebook.domain.base import ObjectModel
@ -68,6 +66,27 @@ class Notebook(ObjectModel):
logger.exception(e)
raise DatabaseOperationError(e)
@property
def chat_sessions(self) -> List["ChatSession"]:
try:
srcs = repo_query(f"""
select * from (
select
<- chat_session as chat_session
from refers_to
where out={self.id}
fetch chat_session
)
order by chat_session.updated desc
""")
return (
[ChatSession(**src["chat_session"][0]) for src in srcs] if srcs else []
)
except Exception as e:
logger.error(f"Error fetching notes for notebook {self.id}: {str(e)}")
logger.exception(e)
raise DatabaseOperationError(e)
class Asset(BaseModel):
file_path: Optional[str] = None
@ -99,6 +118,22 @@ class Source(ObjectModel):
else:
return dict(id=self.id, title=self.title, insights=self.insights)
@property
def embedded_chunks(self) -> int:
try:
result = repo_query(
f"""
select count() as chunks from source_embedding where source={self.id} GROUP ALL
"""
)
if len(result) == 0:
return 0
return result[0]["chunks"]
except Exception as e:
logger.error(f"Error fetching insights for source {self.id}: {str(e)}")
logger.exception(e)
raise DatabaseOperationError(f"Failed to count chunks for source: {str(e)}")
@property
def insights(self) -> List[SourceInsight]:
try:
@ -118,24 +153,6 @@ class Source(ObjectModel):
raise InvalidInputError("Notebook ID must be provided")
return self.relate("reference", notebook_id)
def save_chunks(self, text: str) -> None:
if not text:
raise InvalidInputError("Text cannot be empty")
try:
chunks = split_text(text, chunk=500000, overlap=1000)
logger.debug(f"Split into {len(chunks)} chunks")
for i, chunk in enumerate(chunks):
logger.debug(f"Saving chunk {i}")
data = {"source": self.id, "order": i, "content": surreal_clean(chunk)}
repo_create(
"source_chunk",
data,
)
except Exception as e:
logger.exception(e)
logger.error(f"Error saving chunks for source {self.id}: {str(e)}")
raise DatabaseOperationError(e)
def vectorize(self) -> None:
EMBEDDING_MODEL = model_manager.embedding_model
@ -144,8 +161,6 @@ class Source(ObjectModel):
return
chunks = split_text(
self.full_text,
chunk=int(os.environ.get("EMBEDDING_CHUNK_SIZE", 1000)),
overlap=int(os.environ.get("EMBEDDING_CHUNK_OVERLAP", 1000)),
)
logger.debug(f"Split into {len(chunks)} chunks")
@ -166,26 +181,26 @@ class Source(ObjectModel):
logger.exception(e)
raise DatabaseOperationError(e)
@classmethod
def search(cls, query: str) -> List[Dict[str, Any]]:
if not query:
raise InvalidInputError("Search query cannot be empty")
try:
result = repo_query(
"""
SELECT * omit full_text
FROM source
WHERE string::lowercase(title) CONTAINS $query or title @@ $query
OR string::lowercase(summary) CONTAINS $query or summary @@ $query
OR string::lowercase(full_text) CONTAINS $query or full_text @@ $query
""",
{"query": query},
)
return result
except Exception as e:
logger.error(f"Error searching sources: {str(e)}")
logger.exception(e)
raise DatabaseOperationError("Failed to search sources")
# @classmethod
# def search(cls, query: str) -> List[Dict[str, Any]]:
# if not query:
# raise InvalidInputError("Search query cannot be empty")
# try:
# result = repo_query(
# """
# SELECT * omit full_text
# FROM source
# WHERE string::lowercase(title) CONTAINS $query or title @@ $query
# OR string::lowercase(summary) CONTAINS $query or summary @@ $query
# OR string::lowercase(full_text) CONTAINS $query or full_text @@ $query
# """,
# {"query": query},
# )
# return result
# except Exception as e:
# logger.error(f"Error searching sources: {str(e)}")
# logger.exception(e)
# raise DatabaseOperationError("Failed to search sources")
def add_insight(self, insight_type: str, content: str) -> Any:
EMBEDDING_MODEL = model_manager.embedding_model
@ -246,6 +261,16 @@ class Note(ObjectModel):
return self.content
class ChatSession(ObjectModel):
table_name: ClassVar[str] = "chat_session"
title: Optional[str] = None
def relate_to_notebook(self, notebook_id: str) -> Any:
if not notebook_id:
raise InvalidInputError("Notebook ID must be provided")
return self.relate("refers_to", notebook_id)
def text_search(keyword: str, results: int, source: bool = True, note: bool = True):
if not keyword:
raise InvalidInputError("Search keyword cannot be empty")
@ -263,18 +288,6 @@ def text_search(keyword: str, results: int, source: bool = True, note: bool = Tr
raise DatabaseOperationError(e)
# def hybrid_search(
# keyword_search: List[str],
# embed_search: List[str],
# results: int = 50,
# source: bool = True,
# note: bool = True,
# ):
# EMBEDDING_MODEL = model_manager.embedding_model
# embed1_vector = EMBEDDING_MODEL.embed(embed1) if embed1 else None
# todo: mover o embedding pra ca
def vector_search(keyword: str, results: int, source: bool = True, note: bool = True):
if not keyword:
raise InvalidInputError("Search keyword cannot be empty")

View file

@ -10,11 +10,14 @@ from pages.stream_app.utils import setup_page, setup_stream_state
setup_page("📒 Open Notebook")
def notebook_header(current_notebook):
def notebook_header(current_notebook: Notebook):
"""
Defines the header of the notebook page, including the ability to edit the notebook's name and description.
"""
c1, c2, c3 = st.columns([8, 2, 2])
c1.header(current_notebook.name)
if c2.button("Back to the list", icon="🔙"):
st.session_state["current_notebook"] = None
st.session_state["current_notebook_id"] = None
st.rerun()
if c3.button("Refresh", icon="🔄"):
@ -49,20 +52,20 @@ def notebook_header(current_notebook):
st.toast("Notebook unarchived", icon="🗃️")
if c3.button("Delete forever", type="primary", icon="☠️"):
current_notebook.delete()
st.session_state["current_notebook"] = None
st.session_state["current_notebook_id"] = None
st.rerun()
def notebook_page(current_notebook_id):
current_notebook: Notebook = Notebook.get(current_notebook_id)
if not current_notebook:
st.error("Notebook not found")
return
if current_notebook_id not in st.session_state.keys():
st.session_state[current_notebook_id] = current_notebook
def notebook_page(current_notebook: Notebook):
# Guarantees that we have an entry for this notebook in the session state
if current_notebook.id not in st.session_state:
st.session_state[current_notebook.id] = {"notebook": current_notebook}
# sets up the active session
current_session = setup_stream_state(
current_notebook=current_notebook,
)
session_id = st.session_state["active_session"]
st.session_state[session_id]["notebook"] = current_notebook
sources = current_notebook.sources
notes = current_notebook.notes
@ -74,18 +77,18 @@ def notebook_page(current_notebook_id):
with sources_tab:
with st.container(border=True):
if st.button("Add Source", icon=""):
add_source(session_id)
add_source(current_notebook.id)
for source in sources:
source_card(session_id=session_id, source=source)
source_card(source=source, notebook_id=current_notebook.id)
with notes_tab:
with st.container(border=True):
if st.button("Write a Note", icon="📝"):
add_note(session_id)
add_note(current_notebook.id)
for note in notes:
note_card(session_id=session_id, note=note)
note_card(note=note, notebook_id=current_notebook.id)
with chat_tab:
chat_sidebar(session_id=session_id)
chat_sidebar(current_notebook=current_notebook, current_session=current_session)
def notebook_list_item(notebook):
@ -96,40 +99,50 @@ def notebook_list_item(notebook):
)
st.write(notebook.description)
if st.button("Open", key=f"open_notebook_{notebook.id}"):
setup_stream_state(notebook.id)
st.session_state["current_notebook"] = notebook.id
st.session_state["current_notebook_id"] = notebook.id
st.rerun()
if "current_notebook" not in st.session_state:
st.session_state["current_notebook"] = None
if "current_notebook_id" not in st.session_state:
st.session_state["current_notebook_id"] = None
if st.session_state["current_notebook"]:
notebook_page(st.session_state["current_notebook"])
# todo: get the notebook, check if it exists and if it's archived
if st.session_state["current_notebook_id"]:
current_notebook: Notebook = Notebook.get(st.session_state["current_notebook_id"])
if not current_notebook:
st.error("Notebook not found")
st.stop()
notebook_page(current_notebook)
st.stop()
st.title("📒 My Notebooks")
st.caption("Here are all your notebooks")
st.caption(
"Notebooks are a great way to organize your thoughts, ideas, and sources. You can create notebooks for different research topics and projects, to create new articles, etc. "
)
with st.expander(" **New Notebook**"):
new_notebook_title = st.text_input("New Notebook Name")
new_notebook_description = st.text_area(
"Description",
placeholder="Explain the purpose of this notebook. The more details the better.",
)
if st.button("Create a new Notebook", icon=""):
notebook = Notebook(
name=new_notebook_title, description=new_notebook_description
)
notebook.save()
st.toast("Notebook created successfully", icon="📒")
notebooks = Notebook.get_all(order_by="updated desc")
archived_notebooks = [nb for nb in notebooks if nb.archived]
for notebook in notebooks:
if notebook.archived:
continue
notebook_list_item(notebook)
with st.expander(" **New Notebook**"):
new_notebook_title = st.text_input("New Notebook Name")
new_notebook_description = st.text_area("Description")
if st.button("Create a new Notebook", icon=""):
notebook = Notebook(
name=new_notebook_title, description=new_notebook_description
)
notebook.save()
st.rerun()
archived_notebooks = [nb for nb in notebooks if nb.archived]
if len(archived_notebooks) > 0:
with st.expander(f"**🗃️ {len(archived_notebooks)} archived Notebooks**"):
st.write(" Archived Notebooks can still be accessed and used in search.")
for notebook in archived_notebooks:
notebook_list_item(notebook)

View file

@ -1,19 +1,21 @@
import humanize
import streamlit as st
from langchain_core.runnables import RunnableConfig
from open_notebook.domain.notebook import Note, Source
from open_notebook.domain.notebook import ChatSession, Note, Notebook, Source
from open_notebook.graphs.chat import graph as chat_graph
from open_notebook.plugins.podcasts import PodcastConfig
from open_notebook.utils import token_count
from pages.stream_app.utils import create_session_for_notebook
from .note import make_note_from_chat
# todo: build a smarter, more robust context manager function
def build_context(session_id):
st.session_state[session_id]["context"] = dict(note=[], source=[])
def build_context(notebook_id):
st.session_state[notebook_id]["context"] = dict(note=[], source=[])
for id, status in st.session_state[session_id]["context_config"].items():
for id, status in st.session_state[notebook_id]["context_config"].items():
if not id:
continue
@ -24,6 +26,7 @@ def build_context(session_id):
if "not in" in status:
continue
# todo: there is problably a better way to handle this
if item_type == "note":
item: Note = Note.get(id)
elif item_type == "source":
@ -34,30 +37,33 @@ def build_context(session_id):
if not item:
continue
if "summary" in status:
st.session_state[session_id]["context"][item_type] += [
st.session_state[notebook_id]["context"][item_type] += [
item.get_context(context_size="short")
]
elif "full content" in status:
st.session_state[session_id]["context"][item_type] += [
st.session_state[notebook_id]["context"][item_type] += [
item.get_context(context_size="long")
]
return st.session_state[session_id]["context"]
return st.session_state[notebook_id]["context"]
def execute_chat(txt_input, session_id):
current_state = st.session_state[session_id]
def execute_chat(txt_input, current_session):
current_state = st.session_state[current_session.id]
current_state["messages"] += [txt_input]
result = chat_graph.invoke(
input=current_state,
config=RunnableConfig(configurable={"thread_id": session_id}),
config=RunnableConfig(configurable={"thread_id": current_session.id}),
)
current_session.save()
return result
def chat_sidebar(session_id):
context = build_context(session_id=session_id)
tokens = token_count(str(context) + str(st.session_state[session_id]["messages"]))
def chat_sidebar(current_notebook: Notebook, current_session: ChatSession):
context = build_context(notebook_id=current_notebook.id)
tokens = token_count(
str(context) + str(st.session_state[current_session.id]["messages"])
)
chat_tab, podcast_tab = st.tabs(["Chat", "Podcast"])
with st.expander(f"Context ({tokens} tokens), {len(str(context))} chars"):
st.json(context)
@ -91,15 +97,64 @@ def chat_sidebar(session_id):
st.success("Episode generated successfully")
st.page_link("pages/5_🎙_Podcasts.py", label="🎙️ Go to Podcasts")
with chat_tab:
with st.expander(
f"**Session:** {current_session.title} - {humanize.naturaltime(current_session.updated)}"
):
new_session_name = st.text_input(
"Current Session",
key="new_session_name",
value=current_session.title,
)
c1, c2 = st.columns(2)
if c1.button("Rename", key="rename_session"):
current_session.title = new_session_name
current_session.save()
st.rerun()
if c2.button("Delete", key="delete_session_1"):
current_session.delete()
st.session_state[current_notebook.id]["active_session"] = None
st.rerun()
st.divider()
new_session_name = st.text_input(
"New Session Name",
key="new_session_name_f",
placeholder="Enter a name for the new session...",
)
st.caption("If no name provided, we'll use the current date.")
if st.button("Create New Session", key="create_new_session"):
new_session = create_session_for_notebook(
notebook_id=current_notebook.id, session_name=new_session_name
)
st.session_state[current_notebook.id]["active_session"] = new_session.id
st.rerun()
st.divider()
sessions = current_notebook.chat_sessions
if len(sessions) > 1:
st.markdown("**Other Sessions:**")
for session in sessions:
if session.id == current_session.id:
continue
st.markdown(
f"{session.title} - {humanize.naturaltime(session.updated)}"
)
if st.button(label="Load", key=f"load_session_{session.id}"):
st.session_state[current_notebook.id]["active_session"] = (
session.id
)
st.rerun()
with st.container(border=True):
request = st.chat_input("Enter your question")
# removing for now since it's not multi-model capable right now
st.caption(f"Total tokens: {tokens}")
if request:
response = execute_chat(txt_input=request, session_id=session_id)
st.session_state[session_id]["messages"] = response["messages"]
response = execute_chat(
txt_input=request,
current_session=current_session,
)
st.session_state[current_session.id]["messages"] = response["messages"]
for msg in st.session_state[session_id]["messages"][::-1]:
for msg in st.session_state[current_session.id]["messages"][::-1]:
if msg.type not in ["human", "ai"]:
continue
if not msg.content:
@ -111,6 +166,6 @@ def chat_sidebar(session_id):
if st.button("💾 New Note", key=f"render_save_{msg.id}"):
make_note_from_chat(
content=msg.content,
notebook_id=st.session_state[session_id]["notebook"].id,
notebook_id=current_notebook.id,
)
st.rerun()

View file

@ -1,3 +1,5 @@
from typing import Optional
import streamlit as st
from humanize import naturaltime
from loguru import logger
@ -11,22 +13,20 @@ from .consts import context_icons
@st.dialog("Write a Note", width="large")
def add_note(session_id):
def add_note(notebook_id):
note_title = st.text_input("Title")
note_content = st.text_area("Content")
if st.button("Save", key="add_note"):
logger.debug("Adding note")
note = Note(title=note_title, content=note_content, note_type="human")
note.save()
note.add_to_notebook(st.session_state[session_id]["notebook"].id)
note.add_to_notebook(notebook_id)
st.rerun()
@st.dialog("Add a Source", width="large")
def note_panel(session_id=None, note_id=None):
if note_id:
note: Note = Note.get(note_id)
else:
def note_panel(notebook_id=None, note: Optional[Note] = None):
if not note:
note: Note = Note(note_type="human")
t_preview, t_edit = st.tabs(["Preview", "Edit"])
@ -38,13 +38,13 @@ def note_panel(session_id=None, note_id=None):
note.content = st_monaco(
value=note.content, height="600px", language="markdown"
)
if st.button("Save", key=f"pn_edit_note_{note_id}"):
if st.button("Save", key=f"pn_edit_note_{note.id or 'new'}"):
logger.debug("Editing note")
note.save()
if not note.id:
note.add_to_notebook(st.session_state[session_id]["notebook"].id)
note.add_to_notebook(notebook_id)
st.rerun()
if st.button("Delete", type="primary", key=f"delete_note_{note_id}"):
if st.button("Delete", type="primary", key=f"delete_note_{note.id or 'new'}"):
logger.debug("Deleting note")
note.delete()
st.rerun()
@ -70,7 +70,7 @@ def make_note_from_chat(content, notebook_id=None):
st.rerun()
def note_card(session_id, note):
def note_card(note, notebook_id):
if note.note_type == "human":
icon = "🤵"
else:
@ -88,9 +88,9 @@ def note_card(session_id, note):
st.caption(f"Updated: {naturaltime(note.updated)}")
if st.button("Expand", icon="📝", key=f"edit_note_{note.id}"):
note_panel(session_id, note.id)
note_panel(notebook_id=notebook_id, note=note)
st.session_state[session_id]["context_config"][note.id] = context_state
st.session_state[notebook_id]["context_config"][note.id] = context_state
def note_list_item(note_id, score=None):
@ -105,4 +105,4 @@ def note_list_item(note_id, score=None):
):
st.write(note.content)
if st.button("Edit Note", icon="📝", key=f"x_edit_note_{note.id}"):
note_panel(note_id=note.id)
note_panel(note=note)

View file

@ -95,6 +95,7 @@ def source_panel(source_id):
if st.button(
"Embed vectors",
icon="🦾",
disabled=source.embedded_chunks > 0,
help="This will generate your embedding vectors on the database for powerful search capabilities",
):
source.vectorize()
@ -119,7 +120,7 @@ def source_panel(source_id):
@st.dialog("Add a Source", width="large")
def add_source(session_id):
def add_source(notebook_id):
source_link = None
source_file = None
source_text = None
@ -167,7 +168,7 @@ def add_source(session_id):
title=result.get("title"),
)
source.save()
source.add_to_notebook(st.session_state[session_id]["notebook"].id)
source.add_to_notebook(notebook_id)
st.write("Summarizing...")
generate_toc_and_title(source)
except UnsupportedTypeException as e:
@ -188,7 +189,7 @@ def add_source(session_id):
st.rerun()
def source_card(session_id, source):
def source_card(source, notebook_id):
# todo: more descriptive icons
icon = "🔗"
@ -208,7 +209,7 @@ def source_card(session_id, source):
if st.button("Expand", icon="📝", key=source.id):
source_panel(source.id)
st.session_state[session_id]["context_config"][source.id] = context_state
st.session_state[notebook_id]["context_config"][source.id] = context_state
def source_list_item(source_id, score=None):

View file

@ -1,8 +1,12 @@
from datetime import datetime
from typing import List, Union
import streamlit as st
from loguru import logger
from open_notebook.database.migrate import MigrationManager
from open_notebook.domain.models import model_manager
from open_notebook.domain.notebook import ChatSession, Notebook
from open_notebook.graphs.chat import ThreadState, graph
from open_notebook.utils import (
compare_versions,
@ -33,19 +37,65 @@ def version_sidebar():
)
def setup_stream_state(session_id) -> None:
def create_session_for_notebook(notebook_id: str, session_name: str = None):
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
title = f"Chat Session {current_time}" if not session_name else session_name
chat_session = ChatSession(title=title)
chat_session.save()
chat_session.relate_to_notebook(notebook_id)
return chat_session
def setup_stream_state(current_notebook: Notebook) -> ChatSession:
"""
Sets the value of the current session_id for langgraph thread state.
If there is no existing thread state for this session_id, it creates a new one.
Finally, it acquires the existing state for the session from Langgraph state and sets it in the streamlit session state.
"""
existing_state = graph.get_state({"configurable": {"thread_id": session_id}}).values
if len(existing_state.keys()) == 0:
st.session_state[session_id] = ThreadState(
assert (
current_notebook is not None and current_notebook.id
), "Current Notebook not selected properly"
if "context_config" not in st.session_state[current_notebook.id]:
st.session_state[current_notebook.id]["context_config"] = {}
current_session_id = st.session_state[current_notebook.id].get("active_session")
# gets the chat session if provided
chat_session: Union[ChatSession, None] = (
ChatSession.get(current_session_id) if current_session_id else None
)
# if there is no chat session, create one or get the first one
if not chat_session:
sessions: List[ChatSession] = current_notebook.chat_sessions
if not sessions or len(sessions) == 0:
logger.debug("Creating new chat session")
chat_session = create_session_for_notebook(current_notebook.id)
else:
logger.debug("Getting last updated session")
chat_session = sessions[0]
logger.debug(f"Chat session: {chat_session}")
if not chat_session or chat_session.id is None:
raise ValueError("Problem acquiring chat session")
# sets the active session for the notebook
st.session_state[current_notebook.id]["active_session"] = chat_session.id
# gets the existing state for the session from Langgraph state
existing_state = graph.get_state(
{"configurable": {"thread_id": chat_session.id}}
).values
if not existing_state or len(existing_state.keys()) == 0:
st.session_state[chat_session.id] = ThreadState(
messages=[], context=None, notebook=None, context_config={}
)
else:
st.session_state[session_id] = existing_state
st.session_state["active_session"] = session_id
st.session_state[chat_session.id] = existing_state
st.session_state[current_notebook.id]["active_session"] = chat_session.id
return chat_session
def check_migration():