From 281abdf01b0701a30e702af772ae904551645f34 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Wed, 13 Nov 2024 11:55:38 -0300 Subject: [PATCH] improve the accuracy of ids in the citations --- open_notebook/graphs/ask.py | 2 ++ open_notebook/graphs/chat.py | 3 ++- pages/3_🔍_Ask_and_Search.py | 14 -------------- prompts/ask/query_process.jinja | 4 ++++ 4 files changed, 8 insertions(+), 15 deletions(-) diff --git a/open_notebook/graphs/ask.py b/open_notebook/graphs/ask.py index 4586872..8201783 100644 --- a/open_notebook/graphs/ask.py +++ b/open_notebook/graphs/ask.py @@ -89,6 +89,8 @@ async def provide_answer(state: SubGraphState, config: RunnableConfig) -> dict: if len(results) == 0: return {"answers": []} payload["results"] = results + ids = [r["id"] for r in results] + payload["ids"] = ids system_prompt = Prompter(prompt_template="ask/query_process").render(data=payload) model = provision_langchain_model( system_prompt, diff --git a/open_notebook/graphs/chat.py b/open_notebook/graphs/chat.py index 7342ca3..8d6835a 100644 --- a/open_notebook/graphs/chat.py +++ b/open_notebook/graphs/chat.py @@ -1,6 +1,7 @@ import sqlite3 from typing import Annotated, Optional +from langchain_core.messages import SystemMessage from langchain_core.runnables import ( RunnableConfig, ) @@ -24,7 +25,7 @@ class ThreadState(TypedDict): def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict: system_prompt = Prompter(prompt_template="chat").render(data=state) - payload = [system_prompt] + state.get("messages", []) + payload = [SystemMessage(content=system_prompt)] + state.get("messages", []) model = provision_langchain_model( str(payload), config.get("configurable", {}).get("model_id"), diff --git a/pages/3_🔍_Ask_and_Search.py b/pages/3_🔍_Ask_and_Search.py index 0b0869a..b07e8e5 100644 --- a/pages/3_🔍_Ask_and_Search.py +++ b/pages/3_🔍_Ask_and_Search.py @@ -34,20 +34,6 @@ async def process_ask_query(question, strategy_model, answer_model, final_answer ): yield (chunk) - # result = await ask_graph.ainvoke( - # dict( - # question=question, - # ), - # config=dict( - # configurable=dict( - # strategy_model=strategy_model.id, - # answer_model=answer_model.id, - # final_answer_model=final_answer_model.id, - # ) - # ), - # ) - # return result - def results_card(item): score = item.get("relevance", item.get("similarity", item.get("score", 0))) diff --git a/prompts/ask/query_process.jinja b/prompts/ask/query_process.jinja index 17b0d4d..e787fab 100644 --- a/prompts/ask/query_process.jinja +++ b/prompts/ask/query_process.jinja @@ -45,6 +45,10 @@ Please note, "note:iuiodadalknda" and "insight:adadadadadadad" are examples of d - Do not assume or change the type prefix of any document ID. If a document ID is "note:xyz", use it exactly as "note:xyz". Do not change it to "source:xyz" or any other variation. - **Use document IDs exactly as they are returned from the search tool. Do not add any prefixes or modify them in any way.** +## IDs PROVIDED IN THIS QUERY + +You have been given the following content ids to work from: {{ids}} +So, if you are citing some document, it should be one of these. # YOUR ANSWER