diff --git a/open_notebook/graphs/ask.py b/open_notebook/graphs/ask.py index a2746db..4586872 100644 --- a/open_notebook/graphs/ask.py +++ b/open_notebook/graphs/ask.py @@ -49,7 +49,7 @@ class ThreadState(TypedDict): final_answer: str -def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict: +async def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict: parser = PydanticOutputParser(pydantic_object=Strategy) system_prompt = Prompter(prompt_template="ask/entry", parser=parser).render( data=state @@ -65,7 +65,7 @@ def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict return {"strategy": ai_message} -def trigger_queries(state: ThreadState, config: RunnableConfig): +async def trigger_queries(state: ThreadState, config: RunnableConfig): return [ Send( "provide_answer", @@ -80,7 +80,7 @@ def trigger_queries(state: ThreadState, config: RunnableConfig): ] -def provide_answer(state: SubGraphState, config: RunnableConfig) -> dict: +async def provide_answer(state: SubGraphState, config: RunnableConfig) -> dict: payload = state if state["type"] == "text": results = text_search(state["term"], 10, True, True) @@ -100,7 +100,7 @@ def provide_answer(state: SubGraphState, config: RunnableConfig) -> dict: return {"answers": [ai_message.content]} -def write_final_answer(state: ThreadState, config: RunnableConfig) -> dict: +async def write_final_answer(state: ThreadState, config: RunnableConfig) -> dict: system_prompt = Prompter(prompt_template="ask/final_answer").render(data=state) model = provision_langchain_model( system_prompt, diff --git a/pages/3_🔍_Ask_and_Search.py b/pages/3_🔍_Ask_and_Search.py index e2da0c2..0b0869a 100644 --- a/pages/3_🔍_Ask_and_Search.py +++ b/pages/3_🔍_Ask_and_Search.py @@ -1,7 +1,9 @@ +import asyncio + import streamlit as st from open_notebook.domain.models import Model -from open_notebook.domain.notebook import text_search, vector_search +from open_notebook.domain.notebook import Note, Notebook, text_search, vector_search from open_notebook.graphs.ask import graph as ask_graph from pages.stream_app.utils import convert_source_references, setup_page @@ -12,6 +14,40 @@ ask_tab, search_tab = st.tabs(["Ask Your Knowledge Base (beta)", "Search"]) if "search_results" not in st.session_state: st.session_state["search_results"] = [] +if "ask_results" not in st.session_state: + st.session_state["ask_results"] = {} + + +async def process_ask_query(question, strategy_model, answer_model, final_answer_model): + async for chunk in ask_graph.astream( + input=dict( + question=question, + ), + config=dict( + configurable=dict( + strategy_model=strategy_model.id, + answer_model=answer_model.id, + final_answer_model=final_answer_model.id, + ) + ), + stream_mode="updates", + ): + yield (chunk) + + # result = await ask_graph.ainvoke( + # dict( + # question=question, + # ), + # config=dict( + # configurable=dict( + # strategy_model=strategy_model.id, + # answer_model=answer_model.id, + # final_answer_model=final_answer_model.id, + # ) + # ), + # ) + # return result + def results_card(item): score = item.get("relevance", item.get("similarity", item.get("score", 0))) @@ -49,23 +85,57 @@ with ask_tab: format_func=lambda x: x.name, help="This is the LLM that will be responsible for processing the final answer", ) - if st.button("Ask"): - st.write(f"Searching for {question}") - rag_results = ask_graph.invoke( - dict( - question=question, - ), - config=dict( - configurable=dict( - strategy_model=strategy_model.id, - answer_model=answer_model.id, - final_answer_model=final_answer_model.id, + ask_bt = st.button("Ask") + placeholder = st.container() + + async def stream_results(): + async for chunk in process_ask_query( + question, strategy_model, answer_model, final_answer_model + ): + if "agent" in chunk: + with placeholder.expander( + f"Agent Strategy: {chunk['agent']['strategy'].reasoning}" + ): + for search in chunk["agent"]["strategy"].searches: + st.markdown(f"**{search.type} - {search.term}**") + st.markdown(f"Instructions: {search.instructions}") + elif "provide_answer" in chunk: + for answer in chunk["provide_answer"]["answers"]: + with placeholder.expander("Answer"): + st.markdown(convert_source_references(answer)) + elif "write_final_answer" in chunk: + st.session_state["ask_results"]["answer"] = chunk["write_final_answer"][ + "final_answer" + ] + with placeholder.container(border=True): + st.markdown( + convert_source_references( + chunk["write_final_answer"]["final_answer"] + ) + ) + + if ask_bt: + placeholder.write(f"Searching for {question}") + st.session_state["ask_results"]["question"] = question + st.session_state["ask_results"]["answer"] = None + + asyncio.run(stream_results()) + + if st.session_state["ask_results"].get("answer"): + with st.container(border=True): + with st.form("save_note_form"): + notebook = st.selectbox( + "Notebook", Notebook.get_all(), format_func=lambda x: x.name ) - ), - ) - st.markdown(convert_source_references(rag_results["final_answer"])) - with st.expander("Details (for debugging)"): - st.json(rag_results) + if st.form_submit_button("Save Answer as Note"): + note = Note( + title=st.session_state["ask_results"]["question"], + content=st.session_state["ask_results"]["answer"], + ) + note.save() + note.add_to_notebook(notebook.id) + st.success("Note saved successfully") + with search_tab: with st.container(border=True):