make rag async

This commit is contained in:
LUIS NOVO 2024-11-09 16:03:41 -03:00
parent e589c7b8aa
commit d5be2b0d5b
2 changed files with 91 additions and 21 deletions

View file

@ -49,7 +49,7 @@ class ThreadState(TypedDict):
final_answer: str
def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict:
async def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict:
parser = PydanticOutputParser(pydantic_object=Strategy)
system_prompt = Prompter(prompt_template="ask/entry", parser=parser).render(
data=state
@ -65,7 +65,7 @@ def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict
return {"strategy": ai_message}
def trigger_queries(state: ThreadState, config: RunnableConfig):
async def trigger_queries(state: ThreadState, config: RunnableConfig):
return [
Send(
"provide_answer",
@ -80,7 +80,7 @@ def trigger_queries(state: ThreadState, config: RunnableConfig):
]
def provide_answer(state: SubGraphState, config: RunnableConfig) -> dict:
async def provide_answer(state: SubGraphState, config: RunnableConfig) -> dict:
payload = state
if state["type"] == "text":
results = text_search(state["term"], 10, True, True)
@ -100,7 +100,7 @@ def provide_answer(state: SubGraphState, config: RunnableConfig) -> dict:
return {"answers": [ai_message.content]}
def write_final_answer(state: ThreadState, config: RunnableConfig) -> dict:
async def write_final_answer(state: ThreadState, config: RunnableConfig) -> dict:
system_prompt = Prompter(prompt_template="ask/final_answer").render(data=state)
model = provision_langchain_model(
system_prompt,

View file

@ -1,7 +1,9 @@
import asyncio
import streamlit as st
from open_notebook.domain.models import Model
from open_notebook.domain.notebook import text_search, vector_search
from open_notebook.domain.notebook import Note, Notebook, text_search, vector_search
from open_notebook.graphs.ask import graph as ask_graph
from pages.stream_app.utils import convert_source_references, setup_page
@ -12,6 +14,40 @@ ask_tab, search_tab = st.tabs(["Ask Your Knowledge Base (beta)", "Search"])
if "search_results" not in st.session_state:
st.session_state["search_results"] = []
if "ask_results" not in st.session_state:
st.session_state["ask_results"] = {}
async def process_ask_query(question, strategy_model, answer_model, final_answer_model):
async for chunk in ask_graph.astream(
input=dict(
question=question,
),
config=dict(
configurable=dict(
strategy_model=strategy_model.id,
answer_model=answer_model.id,
final_answer_model=final_answer_model.id,
)
),
stream_mode="updates",
):
yield (chunk)
# result = await ask_graph.ainvoke(
# dict(
# question=question,
# ),
# config=dict(
# configurable=dict(
# strategy_model=strategy_model.id,
# answer_model=answer_model.id,
# final_answer_model=final_answer_model.id,
# )
# ),
# )
# return result
def results_card(item):
score = item.get("relevance", item.get("similarity", item.get("score", 0)))
@ -49,23 +85,57 @@ with ask_tab:
format_func=lambda x: x.name,
help="This is the LLM that will be responsible for processing the final answer",
)
if st.button("Ask"):
st.write(f"Searching for {question}")
rag_results = ask_graph.invoke(
dict(
question=question,
),
config=dict(
configurable=dict(
strategy_model=strategy_model.id,
answer_model=answer_model.id,
final_answer_model=final_answer_model.id,
ask_bt = st.button("Ask")
placeholder = st.container()
async def stream_results():
async for chunk in process_ask_query(
question, strategy_model, answer_model, final_answer_model
):
if "agent" in chunk:
with placeholder.expander(
f"Agent Strategy: {chunk['agent']['strategy'].reasoning}"
):
for search in chunk["agent"]["strategy"].searches:
st.markdown(f"**{search.type} - {search.term}**")
st.markdown(f"Instructions: {search.instructions}")
elif "provide_answer" in chunk:
for answer in chunk["provide_answer"]["answers"]:
with placeholder.expander("Answer"):
st.markdown(convert_source_references(answer))
elif "write_final_answer" in chunk:
st.session_state["ask_results"]["answer"] = chunk["write_final_answer"][
"final_answer"
]
with placeholder.container(border=True):
st.markdown(
convert_source_references(
chunk["write_final_answer"]["final_answer"]
)
)
if ask_bt:
placeholder.write(f"Searching for {question}")
st.session_state["ask_results"]["question"] = question
st.session_state["ask_results"]["answer"] = None
asyncio.run(stream_results())
if st.session_state["ask_results"].get("answer"):
with st.container(border=True):
with st.form("save_note_form"):
notebook = st.selectbox(
"Notebook", Notebook.get_all(), format_func=lambda x: x.name
)
),
)
st.markdown(convert_source_references(rag_results["final_answer"]))
with st.expander("Details (for debugging)"):
st.json(rag_results)
if st.form_submit_button("Save Answer as Note"):
note = Note(
title=st.session_state["ask_results"]["question"],
content=st.session_state["ask_results"]["answer"],
)
note.save()
note.add_to_notebook(notebook.id)
st.success("Note saved successfully")
with search_tab:
with st.container(border=True):