make rag async
This commit is contained in:
parent
e589c7b8aa
commit
d5be2b0d5b
2 changed files with 91 additions and 21 deletions
|
|
@ -49,7 +49,7 @@ class ThreadState(TypedDict):
|
|||
final_answer: str
|
||||
|
||||
|
||||
def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict:
|
||||
async def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict:
|
||||
parser = PydanticOutputParser(pydantic_object=Strategy)
|
||||
system_prompt = Prompter(prompt_template="ask/entry", parser=parser).render(
|
||||
data=state
|
||||
|
|
@ -65,7 +65,7 @@ def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict
|
|||
return {"strategy": ai_message}
|
||||
|
||||
|
||||
def trigger_queries(state: ThreadState, config: RunnableConfig):
|
||||
async def trigger_queries(state: ThreadState, config: RunnableConfig):
|
||||
return [
|
||||
Send(
|
||||
"provide_answer",
|
||||
|
|
@ -80,7 +80,7 @@ def trigger_queries(state: ThreadState, config: RunnableConfig):
|
|||
]
|
||||
|
||||
|
||||
def provide_answer(state: SubGraphState, config: RunnableConfig) -> dict:
|
||||
async def provide_answer(state: SubGraphState, config: RunnableConfig) -> dict:
|
||||
payload = state
|
||||
if state["type"] == "text":
|
||||
results = text_search(state["term"], 10, True, True)
|
||||
|
|
@ -100,7 +100,7 @@ def provide_answer(state: SubGraphState, config: RunnableConfig) -> dict:
|
|||
return {"answers": [ai_message.content]}
|
||||
|
||||
|
||||
def write_final_answer(state: ThreadState, config: RunnableConfig) -> dict:
|
||||
async def write_final_answer(state: ThreadState, config: RunnableConfig) -> dict:
|
||||
system_prompt = Prompter(prompt_template="ask/final_answer").render(data=state)
|
||||
model = provision_langchain_model(
|
||||
system_prompt,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import asyncio
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from open_notebook.domain.models import Model
|
||||
from open_notebook.domain.notebook import text_search, vector_search
|
||||
from open_notebook.domain.notebook import Note, Notebook, text_search, vector_search
|
||||
from open_notebook.graphs.ask import graph as ask_graph
|
||||
from pages.stream_app.utils import convert_source_references, setup_page
|
||||
|
||||
|
|
@ -12,6 +14,40 @@ ask_tab, search_tab = st.tabs(["Ask Your Knowledge Base (beta)", "Search"])
|
|||
if "search_results" not in st.session_state:
|
||||
st.session_state["search_results"] = []
|
||||
|
||||
if "ask_results" not in st.session_state:
|
||||
st.session_state["ask_results"] = {}
|
||||
|
||||
|
||||
async def process_ask_query(question, strategy_model, answer_model, final_answer_model):
|
||||
async for chunk in ask_graph.astream(
|
||||
input=dict(
|
||||
question=question,
|
||||
),
|
||||
config=dict(
|
||||
configurable=dict(
|
||||
strategy_model=strategy_model.id,
|
||||
answer_model=answer_model.id,
|
||||
final_answer_model=final_answer_model.id,
|
||||
)
|
||||
),
|
||||
stream_mode="updates",
|
||||
):
|
||||
yield (chunk)
|
||||
|
||||
# result = await ask_graph.ainvoke(
|
||||
# dict(
|
||||
# question=question,
|
||||
# ),
|
||||
# config=dict(
|
||||
# configurable=dict(
|
||||
# strategy_model=strategy_model.id,
|
||||
# answer_model=answer_model.id,
|
||||
# final_answer_model=final_answer_model.id,
|
||||
# )
|
||||
# ),
|
||||
# )
|
||||
# return result
|
||||
|
||||
|
||||
def results_card(item):
|
||||
score = item.get("relevance", item.get("similarity", item.get("score", 0)))
|
||||
|
|
@ -49,23 +85,57 @@ with ask_tab:
|
|||
format_func=lambda x: x.name,
|
||||
help="This is the LLM that will be responsible for processing the final answer",
|
||||
)
|
||||
if st.button("Ask"):
|
||||
st.write(f"Searching for {question}")
|
||||
rag_results = ask_graph.invoke(
|
||||
dict(
|
||||
question=question,
|
||||
),
|
||||
config=dict(
|
||||
configurable=dict(
|
||||
strategy_model=strategy_model.id,
|
||||
answer_model=answer_model.id,
|
||||
final_answer_model=final_answer_model.id,
|
||||
ask_bt = st.button("Ask")
|
||||
placeholder = st.container()
|
||||
|
||||
async def stream_results():
|
||||
async for chunk in process_ask_query(
|
||||
question, strategy_model, answer_model, final_answer_model
|
||||
):
|
||||
if "agent" in chunk:
|
||||
with placeholder.expander(
|
||||
f"Agent Strategy: {chunk['agent']['strategy'].reasoning}"
|
||||
):
|
||||
for search in chunk["agent"]["strategy"].searches:
|
||||
st.markdown(f"**{search.type} - {search.term}**")
|
||||
st.markdown(f"Instructions: {search.instructions}")
|
||||
elif "provide_answer" in chunk:
|
||||
for answer in chunk["provide_answer"]["answers"]:
|
||||
with placeholder.expander("Answer"):
|
||||
st.markdown(convert_source_references(answer))
|
||||
elif "write_final_answer" in chunk:
|
||||
st.session_state["ask_results"]["answer"] = chunk["write_final_answer"][
|
||||
"final_answer"
|
||||
]
|
||||
with placeholder.container(border=True):
|
||||
st.markdown(
|
||||
convert_source_references(
|
||||
chunk["write_final_answer"]["final_answer"]
|
||||
)
|
||||
)
|
||||
|
||||
if ask_bt:
|
||||
placeholder.write(f"Searching for {question}")
|
||||
st.session_state["ask_results"]["question"] = question
|
||||
st.session_state["ask_results"]["answer"] = None
|
||||
|
||||
asyncio.run(stream_results())
|
||||
|
||||
if st.session_state["ask_results"].get("answer"):
|
||||
with st.container(border=True):
|
||||
with st.form("save_note_form"):
|
||||
notebook = st.selectbox(
|
||||
"Notebook", Notebook.get_all(), format_func=lambda x: x.name
|
||||
)
|
||||
),
|
||||
)
|
||||
st.markdown(convert_source_references(rag_results["final_answer"]))
|
||||
with st.expander("Details (for debugging)"):
|
||||
st.json(rag_results)
|
||||
if st.form_submit_button("Save Answer as Note"):
|
||||
note = Note(
|
||||
title=st.session_state["ask_results"]["question"],
|
||||
content=st.session_state["ask_results"]["answer"],
|
||||
)
|
||||
note.save()
|
||||
note.add_to_notebook(notebook.id)
|
||||
st.success("Note saved successfully")
|
||||
|
||||
|
||||
with search_tab:
|
||||
with st.container(border=True):
|
||||
|
|
|
|||
Loading…
Reference in a new issue