enable multiple chat sessions
This commit is contained in:
parent
3be1ecae8a
commit
0f2216207b
8 changed files with 276 additions and 133 deletions
|
|
@ -1,4 +1,11 @@
|
|||
REMOVE FUNCTION fn::vector_search;
|
||||
|
||||
DEFINE TABLE IF NOT EXISTS chat_session SCHEMALESS;
|
||||
|
||||
DEFINE TABLE IF NOT EXISTS refers_to
|
||||
TYPE RELATION
|
||||
FROM chat_session TO notebook;
|
||||
|
||||
REMOVE FUNCTION IF EXISTS fn::vector_search;
|
||||
|
||||
DEFINE FUNCTION IF NOT EXISTS fn::vector_search($query: array<float>, $match_count: int, $sources: bool, $show_notes: bool, $min_similarity: float) {
|
||||
let $source_embedding_search =
|
||||
|
|
@ -16,7 +23,6 @@ DEFINE FUNCTION IF NOT EXISTS fn::vector_search($query: array<float>, $match_cou
|
|||
)}
|
||||
ELSE { [] };
|
||||
|
||||
-- Busca em source_insight com threshold
|
||||
let $source_insight_search =
|
||||
IF $sources {(
|
||||
SELECT
|
||||
|
|
@ -67,10 +73,10 @@ DEFINE FUNCTION IF NOT EXISTS fn::vector_search($query: array<float>, $match_cou
|
|||
};
|
||||
|
||||
|
||||
REMOVE FUNCTION fn::text_search;
|
||||
REMOVE FUNCTION IF EXISTS fn::text_search;
|
||||
|
||||
|
||||
DEFINE FUNCTION IF NOT EXISTS fn::text_search($query_text: string, $match_count: int, $sources:bool, $show_notes:bool) {
|
||||
DEFINE FUNCTION IF NOT EXISTS fn::text_search($query_text: string, $match_count: int, $sources:bool, $show_notes:bool) {
|
||||
|
||||
let $source_title_search =
|
||||
IF $sources {(
|
||||
|
|
|
|||
|
|
@ -1,3 +1,8 @@
|
|||
REMOVE TABLE IF EXISTS chat_session;
|
||||
|
||||
REMOVE TABLE IF EXISTS refers_to;
|
||||
|
||||
|
||||
REMOVE FUNCTION fn::vector_search;
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,9 @@
|
|||
import os
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from open_notebook.database.repository import (
|
||||
repo_create,
|
||||
repo_query,
|
||||
)
|
||||
from open_notebook.domain.base import ObjectModel
|
||||
|
|
@ -68,6 +66,27 @@ class Notebook(ObjectModel):
|
|||
logger.exception(e)
|
||||
raise DatabaseOperationError(e)
|
||||
|
||||
@property
|
||||
def chat_sessions(self) -> List["ChatSession"]:
|
||||
try:
|
||||
srcs = repo_query(f"""
|
||||
select * from (
|
||||
select
|
||||
<- chat_session as chat_session
|
||||
from refers_to
|
||||
where out={self.id}
|
||||
fetch chat_session
|
||||
)
|
||||
order by chat_session.updated desc
|
||||
""")
|
||||
return (
|
||||
[ChatSession(**src["chat_session"][0]) for src in srcs] if srcs else []
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching notes for notebook {self.id}: {str(e)}")
|
||||
logger.exception(e)
|
||||
raise DatabaseOperationError(e)
|
||||
|
||||
|
||||
class Asset(BaseModel):
|
||||
file_path: Optional[str] = None
|
||||
|
|
@ -99,6 +118,22 @@ class Source(ObjectModel):
|
|||
else:
|
||||
return dict(id=self.id, title=self.title, insights=self.insights)
|
||||
|
||||
@property
|
||||
def embedded_chunks(self) -> int:
|
||||
try:
|
||||
result = repo_query(
|
||||
f"""
|
||||
select count() as chunks from source_embedding where source={self.id} GROUP ALL
|
||||
"""
|
||||
)
|
||||
if len(result) == 0:
|
||||
return 0
|
||||
return result[0]["chunks"]
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching insights for source {self.id}: {str(e)}")
|
||||
logger.exception(e)
|
||||
raise DatabaseOperationError(f"Failed to count chunks for source: {str(e)}")
|
||||
|
||||
@property
|
||||
def insights(self) -> List[SourceInsight]:
|
||||
try:
|
||||
|
|
@ -118,24 +153,6 @@ class Source(ObjectModel):
|
|||
raise InvalidInputError("Notebook ID must be provided")
|
||||
return self.relate("reference", notebook_id)
|
||||
|
||||
def save_chunks(self, text: str) -> None:
|
||||
if not text:
|
||||
raise InvalidInputError("Text cannot be empty")
|
||||
try:
|
||||
chunks = split_text(text, chunk=500000, overlap=1000)
|
||||
logger.debug(f"Split into {len(chunks)} chunks")
|
||||
for i, chunk in enumerate(chunks):
|
||||
logger.debug(f"Saving chunk {i}")
|
||||
data = {"source": self.id, "order": i, "content": surreal_clean(chunk)}
|
||||
repo_create(
|
||||
"source_chunk",
|
||||
data,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
logger.error(f"Error saving chunks for source {self.id}: {str(e)}")
|
||||
raise DatabaseOperationError(e)
|
||||
|
||||
def vectorize(self) -> None:
|
||||
EMBEDDING_MODEL = model_manager.embedding_model
|
||||
|
||||
|
|
@ -144,8 +161,6 @@ class Source(ObjectModel):
|
|||
return
|
||||
chunks = split_text(
|
||||
self.full_text,
|
||||
chunk=int(os.environ.get("EMBEDDING_CHUNK_SIZE", 1000)),
|
||||
overlap=int(os.environ.get("EMBEDDING_CHUNK_OVERLAP", 1000)),
|
||||
)
|
||||
logger.debug(f"Split into {len(chunks)} chunks")
|
||||
|
||||
|
|
@ -166,26 +181,26 @@ class Source(ObjectModel):
|
|||
logger.exception(e)
|
||||
raise DatabaseOperationError(e)
|
||||
|
||||
@classmethod
|
||||
def search(cls, query: str) -> List[Dict[str, Any]]:
|
||||
if not query:
|
||||
raise InvalidInputError("Search query cannot be empty")
|
||||
try:
|
||||
result = repo_query(
|
||||
"""
|
||||
SELECT * omit full_text
|
||||
FROM source
|
||||
WHERE string::lowercase(title) CONTAINS $query or title @@ $query
|
||||
OR string::lowercase(summary) CONTAINS $query or summary @@ $query
|
||||
OR string::lowercase(full_text) CONTAINS $query or full_text @@ $query
|
||||
""",
|
||||
{"query": query},
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching sources: {str(e)}")
|
||||
logger.exception(e)
|
||||
raise DatabaseOperationError("Failed to search sources")
|
||||
# @classmethod
|
||||
# def search(cls, query: str) -> List[Dict[str, Any]]:
|
||||
# if not query:
|
||||
# raise InvalidInputError("Search query cannot be empty")
|
||||
# try:
|
||||
# result = repo_query(
|
||||
# """
|
||||
# SELECT * omit full_text
|
||||
# FROM source
|
||||
# WHERE string::lowercase(title) CONTAINS $query or title @@ $query
|
||||
# OR string::lowercase(summary) CONTAINS $query or summary @@ $query
|
||||
# OR string::lowercase(full_text) CONTAINS $query or full_text @@ $query
|
||||
# """,
|
||||
# {"query": query},
|
||||
# )
|
||||
# return result
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error searching sources: {str(e)}")
|
||||
# logger.exception(e)
|
||||
# raise DatabaseOperationError("Failed to search sources")
|
||||
|
||||
def add_insight(self, insight_type: str, content: str) -> Any:
|
||||
EMBEDDING_MODEL = model_manager.embedding_model
|
||||
|
|
@ -246,6 +261,16 @@ class Note(ObjectModel):
|
|||
return self.content
|
||||
|
||||
|
||||
class ChatSession(ObjectModel):
|
||||
table_name: ClassVar[str] = "chat_session"
|
||||
title: Optional[str] = None
|
||||
|
||||
def relate_to_notebook(self, notebook_id: str) -> Any:
|
||||
if not notebook_id:
|
||||
raise InvalidInputError("Notebook ID must be provided")
|
||||
return self.relate("refers_to", notebook_id)
|
||||
|
||||
|
||||
def text_search(keyword: str, results: int, source: bool = True, note: bool = True):
|
||||
if not keyword:
|
||||
raise InvalidInputError("Search keyword cannot be empty")
|
||||
|
|
@ -263,18 +288,6 @@ def text_search(keyword: str, results: int, source: bool = True, note: bool = Tr
|
|||
raise DatabaseOperationError(e)
|
||||
|
||||
|
||||
# def hybrid_search(
|
||||
# keyword_search: List[str],
|
||||
# embed_search: List[str],
|
||||
# results: int = 50,
|
||||
# source: bool = True,
|
||||
# note: bool = True,
|
||||
# ):
|
||||
# EMBEDDING_MODEL = model_manager.embedding_model
|
||||
# embed1_vector = EMBEDDING_MODEL.embed(embed1) if embed1 else None
|
||||
|
||||
|
||||
# todo: mover o embedding pra ca
|
||||
def vector_search(keyword: str, results: int, source: bool = True, note: bool = True):
|
||||
if not keyword:
|
||||
raise InvalidInputError("Search keyword cannot be empty")
|
||||
|
|
|
|||
|
|
@ -10,11 +10,14 @@ from pages.stream_app.utils import setup_page, setup_stream_state
|
|||
setup_page("📒 Open Notebook")
|
||||
|
||||
|
||||
def notebook_header(current_notebook):
|
||||
def notebook_header(current_notebook: Notebook):
|
||||
"""
|
||||
Defines the header of the notebook page, including the ability to edit the notebook's name and description.
|
||||
"""
|
||||
c1, c2, c3 = st.columns([8, 2, 2])
|
||||
c1.header(current_notebook.name)
|
||||
if c2.button("Back to the list", icon="🔙"):
|
||||
st.session_state["current_notebook"] = None
|
||||
st.session_state["current_notebook_id"] = None
|
||||
st.rerun()
|
||||
|
||||
if c3.button("Refresh", icon="🔄"):
|
||||
|
|
@ -49,20 +52,20 @@ def notebook_header(current_notebook):
|
|||
st.toast("Notebook unarchived", icon="🗃️")
|
||||
if c3.button("Delete forever", type="primary", icon="☠️"):
|
||||
current_notebook.delete()
|
||||
st.session_state["current_notebook"] = None
|
||||
st.session_state["current_notebook_id"] = None
|
||||
st.rerun()
|
||||
|
||||
|
||||
def notebook_page(current_notebook_id):
|
||||
current_notebook: Notebook = Notebook.get(current_notebook_id)
|
||||
if not current_notebook:
|
||||
st.error("Notebook not found")
|
||||
return
|
||||
if current_notebook_id not in st.session_state.keys():
|
||||
st.session_state[current_notebook_id] = current_notebook
|
||||
def notebook_page(current_notebook: Notebook):
|
||||
# Guarantees that we have an entry for this notebook in the session state
|
||||
if current_notebook.id not in st.session_state:
|
||||
st.session_state[current_notebook.id] = {"notebook": current_notebook}
|
||||
|
||||
# sets up the active session
|
||||
current_session = setup_stream_state(
|
||||
current_notebook=current_notebook,
|
||||
)
|
||||
|
||||
session_id = st.session_state["active_session"]
|
||||
st.session_state[session_id]["notebook"] = current_notebook
|
||||
sources = current_notebook.sources
|
||||
notes = current_notebook.notes
|
||||
|
||||
|
|
@ -74,18 +77,18 @@ def notebook_page(current_notebook_id):
|
|||
with sources_tab:
|
||||
with st.container(border=True):
|
||||
if st.button("Add Source", icon="➕"):
|
||||
add_source(session_id)
|
||||
add_source(current_notebook.id)
|
||||
for source in sources:
|
||||
source_card(session_id=session_id, source=source)
|
||||
source_card(source=source, notebook_id=current_notebook.id)
|
||||
|
||||
with notes_tab:
|
||||
with st.container(border=True):
|
||||
if st.button("Write a Note", icon="📝"):
|
||||
add_note(session_id)
|
||||
add_note(current_notebook.id)
|
||||
for note in notes:
|
||||
note_card(session_id=session_id, note=note)
|
||||
note_card(note=note, notebook_id=current_notebook.id)
|
||||
with chat_tab:
|
||||
chat_sidebar(session_id=session_id)
|
||||
chat_sidebar(current_notebook=current_notebook, current_session=current_session)
|
||||
|
||||
|
||||
def notebook_list_item(notebook):
|
||||
|
|
@ -96,40 +99,50 @@ def notebook_list_item(notebook):
|
|||
)
|
||||
st.write(notebook.description)
|
||||
if st.button("Open", key=f"open_notebook_{notebook.id}"):
|
||||
setup_stream_state(notebook.id)
|
||||
st.session_state["current_notebook"] = notebook.id
|
||||
st.session_state["current_notebook_id"] = notebook.id
|
||||
st.rerun()
|
||||
|
||||
|
||||
if "current_notebook" not in st.session_state:
|
||||
st.session_state["current_notebook"] = None
|
||||
if "current_notebook_id" not in st.session_state:
|
||||
st.session_state["current_notebook_id"] = None
|
||||
|
||||
if st.session_state["current_notebook"]:
|
||||
notebook_page(st.session_state["current_notebook"])
|
||||
# todo: get the notebook, check if it exists and if it's archived
|
||||
if st.session_state["current_notebook_id"]:
|
||||
current_notebook: Notebook = Notebook.get(st.session_state["current_notebook_id"])
|
||||
if not current_notebook:
|
||||
st.error("Notebook not found")
|
||||
st.stop()
|
||||
notebook_page(current_notebook)
|
||||
st.stop()
|
||||
|
||||
st.title("📒 My Notebooks")
|
||||
st.caption("Here are all your notebooks")
|
||||
st.caption(
|
||||
"Notebooks are a great way to organize your thoughts, ideas, and sources. You can create notebooks for different research topics and projects, to create new articles, etc. "
|
||||
)
|
||||
|
||||
with st.expander("➕ **New Notebook**"):
|
||||
new_notebook_title = st.text_input("New Notebook Name")
|
||||
new_notebook_description = st.text_area(
|
||||
"Description",
|
||||
placeholder="Explain the purpose of this notebook. The more details the better.",
|
||||
)
|
||||
if st.button("Create a new Notebook", icon="➕"):
|
||||
notebook = Notebook(
|
||||
name=new_notebook_title, description=new_notebook_description
|
||||
)
|
||||
notebook.save()
|
||||
st.toast("Notebook created successfully", icon="📒")
|
||||
|
||||
notebooks = Notebook.get_all(order_by="updated desc")
|
||||
archived_notebooks = [nb for nb in notebooks if nb.archived]
|
||||
|
||||
for notebook in notebooks:
|
||||
if notebook.archived:
|
||||
continue
|
||||
notebook_list_item(notebook)
|
||||
|
||||
with st.expander("➕ **New Notebook**"):
|
||||
new_notebook_title = st.text_input("New Notebook Name")
|
||||
new_notebook_description = st.text_area("Description")
|
||||
if st.button("Create a new Notebook", icon="➕"):
|
||||
notebook = Notebook(
|
||||
name=new_notebook_title, description=new_notebook_description
|
||||
)
|
||||
notebook.save()
|
||||
st.rerun()
|
||||
|
||||
archived_notebooks = [nb for nb in notebooks if nb.archived]
|
||||
if len(archived_notebooks) > 0:
|
||||
with st.expander(f"**🗃️ {len(archived_notebooks)} archived Notebooks**"):
|
||||
st.write("ℹ Archived Notebooks can still be accessed and used in search.")
|
||||
for notebook in archived_notebooks:
|
||||
notebook_list_item(notebook)
|
||||
|
|
|
|||
|
|
@ -1,19 +1,21 @@
|
|||
import humanize
|
||||
import streamlit as st
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from open_notebook.domain.notebook import Note, Source
|
||||
from open_notebook.domain.notebook import ChatSession, Note, Notebook, Source
|
||||
from open_notebook.graphs.chat import graph as chat_graph
|
||||
from open_notebook.plugins.podcasts import PodcastConfig
|
||||
from open_notebook.utils import token_count
|
||||
from pages.stream_app.utils import create_session_for_notebook
|
||||
|
||||
from .note import make_note_from_chat
|
||||
|
||||
|
||||
# todo: build a smarter, more robust context manager function
|
||||
def build_context(session_id):
|
||||
st.session_state[session_id]["context"] = dict(note=[], source=[])
|
||||
def build_context(notebook_id):
|
||||
st.session_state[notebook_id]["context"] = dict(note=[], source=[])
|
||||
|
||||
for id, status in st.session_state[session_id]["context_config"].items():
|
||||
for id, status in st.session_state[notebook_id]["context_config"].items():
|
||||
if not id:
|
||||
continue
|
||||
|
||||
|
|
@ -24,6 +26,7 @@ def build_context(session_id):
|
|||
if "not in" in status:
|
||||
continue
|
||||
|
||||
# todo: there is problably a better way to handle this
|
||||
if item_type == "note":
|
||||
item: Note = Note.get(id)
|
||||
elif item_type == "source":
|
||||
|
|
@ -34,30 +37,33 @@ def build_context(session_id):
|
|||
if not item:
|
||||
continue
|
||||
if "summary" in status:
|
||||
st.session_state[session_id]["context"][item_type] += [
|
||||
st.session_state[notebook_id]["context"][item_type] += [
|
||||
item.get_context(context_size="short")
|
||||
]
|
||||
elif "full content" in status:
|
||||
st.session_state[session_id]["context"][item_type] += [
|
||||
st.session_state[notebook_id]["context"][item_type] += [
|
||||
item.get_context(context_size="long")
|
||||
]
|
||||
|
||||
return st.session_state[session_id]["context"]
|
||||
return st.session_state[notebook_id]["context"]
|
||||
|
||||
|
||||
def execute_chat(txt_input, session_id):
|
||||
current_state = st.session_state[session_id]
|
||||
def execute_chat(txt_input, current_session):
|
||||
current_state = st.session_state[current_session.id]
|
||||
current_state["messages"] += [txt_input]
|
||||
result = chat_graph.invoke(
|
||||
input=current_state,
|
||||
config=RunnableConfig(configurable={"thread_id": session_id}),
|
||||
config=RunnableConfig(configurable={"thread_id": current_session.id}),
|
||||
)
|
||||
current_session.save()
|
||||
return result
|
||||
|
||||
|
||||
def chat_sidebar(session_id):
|
||||
context = build_context(session_id=session_id)
|
||||
tokens = token_count(str(context) + str(st.session_state[session_id]["messages"]))
|
||||
def chat_sidebar(current_notebook: Notebook, current_session: ChatSession):
|
||||
context = build_context(notebook_id=current_notebook.id)
|
||||
tokens = token_count(
|
||||
str(context) + str(st.session_state[current_session.id]["messages"])
|
||||
)
|
||||
chat_tab, podcast_tab = st.tabs(["Chat", "Podcast"])
|
||||
with st.expander(f"Context ({tokens} tokens), {len(str(context))} chars"):
|
||||
st.json(context)
|
||||
|
|
@ -91,15 +97,64 @@ def chat_sidebar(session_id):
|
|||
st.success("Episode generated successfully")
|
||||
st.page_link("pages/5_🎙️_Podcasts.py", label="🎙️ Go to Podcasts")
|
||||
with chat_tab:
|
||||
with st.expander(
|
||||
f"**Session:** {current_session.title} - {humanize.naturaltime(current_session.updated)}"
|
||||
):
|
||||
new_session_name = st.text_input(
|
||||
"Current Session",
|
||||
key="new_session_name",
|
||||
value=current_session.title,
|
||||
)
|
||||
c1, c2 = st.columns(2)
|
||||
if c1.button("Rename", key="rename_session"):
|
||||
current_session.title = new_session_name
|
||||
current_session.save()
|
||||
st.rerun()
|
||||
if c2.button("Delete", key="delete_session_1"):
|
||||
current_session.delete()
|
||||
st.session_state[current_notebook.id]["active_session"] = None
|
||||
st.rerun()
|
||||
st.divider()
|
||||
new_session_name = st.text_input(
|
||||
"New Session Name",
|
||||
key="new_session_name_f",
|
||||
placeholder="Enter a name for the new session...",
|
||||
)
|
||||
st.caption("If no name provided, we'll use the current date.")
|
||||
if st.button("Create New Session", key="create_new_session"):
|
||||
new_session = create_session_for_notebook(
|
||||
notebook_id=current_notebook.id, session_name=new_session_name
|
||||
)
|
||||
st.session_state[current_notebook.id]["active_session"] = new_session.id
|
||||
st.rerun()
|
||||
st.divider()
|
||||
sessions = current_notebook.chat_sessions
|
||||
if len(sessions) > 1:
|
||||
st.markdown("**Other Sessions:**")
|
||||
for session in sessions:
|
||||
if session.id == current_session.id:
|
||||
continue
|
||||
|
||||
st.markdown(
|
||||
f"{session.title} - {humanize.naturaltime(session.updated)}"
|
||||
)
|
||||
if st.button(label="Load", key=f"load_session_{session.id}"):
|
||||
st.session_state[current_notebook.id]["active_session"] = (
|
||||
session.id
|
||||
)
|
||||
st.rerun()
|
||||
with st.container(border=True):
|
||||
request = st.chat_input("Enter your question")
|
||||
# removing for now since it's not multi-model capable right now
|
||||
st.caption(f"Total tokens: {tokens}")
|
||||
if request:
|
||||
response = execute_chat(txt_input=request, session_id=session_id)
|
||||
st.session_state[session_id]["messages"] = response["messages"]
|
||||
response = execute_chat(
|
||||
txt_input=request,
|
||||
current_session=current_session,
|
||||
)
|
||||
st.session_state[current_session.id]["messages"] = response["messages"]
|
||||
|
||||
for msg in st.session_state[session_id]["messages"][::-1]:
|
||||
for msg in st.session_state[current_session.id]["messages"][::-1]:
|
||||
if msg.type not in ["human", "ai"]:
|
||||
continue
|
||||
if not msg.content:
|
||||
|
|
@ -111,6 +166,6 @@ def chat_sidebar(session_id):
|
|||
if st.button("💾 New Note", key=f"render_save_{msg.id}"):
|
||||
make_note_from_chat(
|
||||
content=msg.content,
|
||||
notebook_id=st.session_state[session_id]["notebook"].id,
|
||||
notebook_id=current_notebook.id,
|
||||
)
|
||||
st.rerun()
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from typing import Optional
|
||||
|
||||
import streamlit as st
|
||||
from humanize import naturaltime
|
||||
from loguru import logger
|
||||
|
|
@ -11,22 +13,20 @@ from .consts import context_icons
|
|||
|
||||
|
||||
@st.dialog("Write a Note", width="large")
|
||||
def add_note(session_id):
|
||||
def add_note(notebook_id):
|
||||
note_title = st.text_input("Title")
|
||||
note_content = st.text_area("Content")
|
||||
if st.button("Save", key="add_note"):
|
||||
logger.debug("Adding note")
|
||||
note = Note(title=note_title, content=note_content, note_type="human")
|
||||
note.save()
|
||||
note.add_to_notebook(st.session_state[session_id]["notebook"].id)
|
||||
note.add_to_notebook(notebook_id)
|
||||
st.rerun()
|
||||
|
||||
|
||||
@st.dialog("Add a Source", width="large")
|
||||
def note_panel(session_id=None, note_id=None):
|
||||
if note_id:
|
||||
note: Note = Note.get(note_id)
|
||||
else:
|
||||
def note_panel(notebook_id=None, note: Optional[Note] = None):
|
||||
if not note:
|
||||
note: Note = Note(note_type="human")
|
||||
|
||||
t_preview, t_edit = st.tabs(["Preview", "Edit"])
|
||||
|
|
@ -38,13 +38,13 @@ def note_panel(session_id=None, note_id=None):
|
|||
note.content = st_monaco(
|
||||
value=note.content, height="600px", language="markdown"
|
||||
)
|
||||
if st.button("Save", key=f"pn_edit_note_{note_id}"):
|
||||
if st.button("Save", key=f"pn_edit_note_{note.id or 'new'}"):
|
||||
logger.debug("Editing note")
|
||||
note.save()
|
||||
if not note.id:
|
||||
note.add_to_notebook(st.session_state[session_id]["notebook"].id)
|
||||
note.add_to_notebook(notebook_id)
|
||||
st.rerun()
|
||||
if st.button("Delete", type="primary", key=f"delete_note_{note_id}"):
|
||||
if st.button("Delete", type="primary", key=f"delete_note_{note.id or 'new'}"):
|
||||
logger.debug("Deleting note")
|
||||
note.delete()
|
||||
st.rerun()
|
||||
|
|
@ -70,7 +70,7 @@ def make_note_from_chat(content, notebook_id=None):
|
|||
st.rerun()
|
||||
|
||||
|
||||
def note_card(session_id, note):
|
||||
def note_card(note, notebook_id):
|
||||
if note.note_type == "human":
|
||||
icon = "🤵"
|
||||
else:
|
||||
|
|
@ -88,9 +88,9 @@ def note_card(session_id, note):
|
|||
st.caption(f"Updated: {naturaltime(note.updated)}")
|
||||
|
||||
if st.button("Expand", icon="📝", key=f"edit_note_{note.id}"):
|
||||
note_panel(session_id, note.id)
|
||||
note_panel(notebook_id=notebook_id, note=note)
|
||||
|
||||
st.session_state[session_id]["context_config"][note.id] = context_state
|
||||
st.session_state[notebook_id]["context_config"][note.id] = context_state
|
||||
|
||||
|
||||
def note_list_item(note_id, score=None):
|
||||
|
|
@ -105,4 +105,4 @@ def note_list_item(note_id, score=None):
|
|||
):
|
||||
st.write(note.content)
|
||||
if st.button("Edit Note", icon="📝", key=f"x_edit_note_{note.id}"):
|
||||
note_panel(note_id=note.id)
|
||||
note_panel(note=note)
|
||||
|
|
|
|||
|
|
@ -95,6 +95,7 @@ def source_panel(source_id):
|
|||
if st.button(
|
||||
"Embed vectors",
|
||||
icon="🦾",
|
||||
disabled=source.embedded_chunks > 0,
|
||||
help="This will generate your embedding vectors on the database for powerful search capabilities",
|
||||
):
|
||||
source.vectorize()
|
||||
|
|
@ -119,7 +120,7 @@ def source_panel(source_id):
|
|||
|
||||
|
||||
@st.dialog("Add a Source", width="large")
|
||||
def add_source(session_id):
|
||||
def add_source(notebook_id):
|
||||
source_link = None
|
||||
source_file = None
|
||||
source_text = None
|
||||
|
|
@ -167,7 +168,7 @@ def add_source(session_id):
|
|||
title=result.get("title"),
|
||||
)
|
||||
source.save()
|
||||
source.add_to_notebook(st.session_state[session_id]["notebook"].id)
|
||||
source.add_to_notebook(notebook_id)
|
||||
st.write("Summarizing...")
|
||||
generate_toc_and_title(source)
|
||||
except UnsupportedTypeException as e:
|
||||
|
|
@ -188,7 +189,7 @@ def add_source(session_id):
|
|||
st.rerun()
|
||||
|
||||
|
||||
def source_card(session_id, source):
|
||||
def source_card(source, notebook_id):
|
||||
# todo: more descriptive icons
|
||||
icon = "🔗"
|
||||
|
||||
|
|
@ -208,7 +209,7 @@ def source_card(session_id, source):
|
|||
if st.button("Expand", icon="📝", key=source.id):
|
||||
source_panel(source.id)
|
||||
|
||||
st.session_state[session_id]["context_config"][source.id] = context_state
|
||||
st.session_state[notebook_id]["context_config"][source.id] = context_state
|
||||
|
||||
|
||||
def source_list_item(source_id, score=None):
|
||||
|
|
|
|||
|
|
@ -1,8 +1,12 @@
|
|||
from datetime import datetime
|
||||
from typing import List, Union
|
||||
|
||||
import streamlit as st
|
||||
from loguru import logger
|
||||
|
||||
from open_notebook.database.migrate import MigrationManager
|
||||
from open_notebook.domain.models import model_manager
|
||||
from open_notebook.domain.notebook import ChatSession, Notebook
|
||||
from open_notebook.graphs.chat import ThreadState, graph
|
||||
from open_notebook.utils import (
|
||||
compare_versions,
|
||||
|
|
@ -33,19 +37,65 @@ def version_sidebar():
|
|||
)
|
||||
|
||||
|
||||
def setup_stream_state(session_id) -> None:
|
||||
def create_session_for_notebook(notebook_id: str, session_name: str = None):
|
||||
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
title = f"Chat Session {current_time}" if not session_name else session_name
|
||||
chat_session = ChatSession(title=title)
|
||||
chat_session.save()
|
||||
chat_session.relate_to_notebook(notebook_id)
|
||||
return chat_session
|
||||
|
||||
|
||||
def setup_stream_state(current_notebook: Notebook) -> ChatSession:
|
||||
"""
|
||||
Sets the value of the current session_id for langgraph thread state.
|
||||
If there is no existing thread state for this session_id, it creates a new one.
|
||||
Finally, it acquires the existing state for the session from Langgraph state and sets it in the streamlit session state.
|
||||
"""
|
||||
existing_state = graph.get_state({"configurable": {"thread_id": session_id}}).values
|
||||
if len(existing_state.keys()) == 0:
|
||||
st.session_state[session_id] = ThreadState(
|
||||
assert (
|
||||
current_notebook is not None and current_notebook.id
|
||||
), "Current Notebook not selected properly"
|
||||
|
||||
if "context_config" not in st.session_state[current_notebook.id]:
|
||||
st.session_state[current_notebook.id]["context_config"] = {}
|
||||
|
||||
current_session_id = st.session_state[current_notebook.id].get("active_session")
|
||||
|
||||
# gets the chat session if provided
|
||||
chat_session: Union[ChatSession, None] = (
|
||||
ChatSession.get(current_session_id) if current_session_id else None
|
||||
)
|
||||
|
||||
# if there is no chat session, create one or get the first one
|
||||
if not chat_session:
|
||||
sessions: List[ChatSession] = current_notebook.chat_sessions
|
||||
if not sessions or len(sessions) == 0:
|
||||
logger.debug("Creating new chat session")
|
||||
chat_session = create_session_for_notebook(current_notebook.id)
|
||||
else:
|
||||
logger.debug("Getting last updated session")
|
||||
chat_session = sessions[0]
|
||||
|
||||
logger.debug(f"Chat session: {chat_session}")
|
||||
|
||||
if not chat_session or chat_session.id is None:
|
||||
raise ValueError("Problem acquiring chat session")
|
||||
# sets the active session for the notebook
|
||||
st.session_state[current_notebook.id]["active_session"] = chat_session.id
|
||||
|
||||
# gets the existing state for the session from Langgraph state
|
||||
existing_state = graph.get_state(
|
||||
{"configurable": {"thread_id": chat_session.id}}
|
||||
).values
|
||||
if not existing_state or len(existing_state.keys()) == 0:
|
||||
st.session_state[chat_session.id] = ThreadState(
|
||||
messages=[], context=None, notebook=None, context_config={}
|
||||
)
|
||||
else:
|
||||
st.session_state[session_id] = existing_state
|
||||
st.session_state["active_session"] = session_id
|
||||
st.session_state[chat_session.id] = existing_state
|
||||
|
||||
st.session_state[current_notebook.id]["active_session"] = chat_session.id
|
||||
return chat_session
|
||||
|
||||
|
||||
def check_migration():
|
||||
|
|
|
|||
Loading…
Reference in a new issue