From 809ecb45e1f72628bd1604441bbb3c4b6e044b1d Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Tue, 22 Oct 2024 16:36:21 -0300 Subject: [PATCH] rename doc_query tool --- .../graphs/{ask_content.py => doc_query.py} | 20 ++++++++++--------- open_notebook/graphs/tools.py | 12 ++++++----- 2 files changed, 18 insertions(+), 14 deletions(-) rename open_notebook/graphs/{ask_content.py => doc_query.py} (68%) diff --git a/open_notebook/graphs/ask_content.py b/open_notebook/graphs/doc_query.py similarity index 68% rename from open_notebook/graphs/ask_content.py rename to open_notebook/graphs/doc_query.py index 41d4a40..8b22e6f 100644 --- a/open_notebook/graphs/ask_content.py +++ b/open_notebook/graphs/doc_query.py @@ -3,16 +3,16 @@ import os from langchain_core.runnables import ( RunnableConfig, ) -from langchain_openai import ChatOpenAI from langgraph.graph import END, START, StateGraph from loguru import logger from typing_extensions import TypedDict from open_notebook.domain import Note, Notebook, Source +from open_notebook.model_configs import get_langchain_model from open_notebook.prompter import Prompter -class AskState(TypedDict): +class DocQueryState(TypedDict): doc_id: str doc_content: str question: str @@ -20,11 +20,13 @@ class AskState(TypedDict): notebook: Notebook -def call_model_with_messages(state: AskState, config: RunnableConfig) -> dict: - model = ChatOpenAI( - model=os.environ.get("RETRIEVAL_MODEL", os.environ["DEFAULT_MODEL"]), - temperature=0, - ) +def call_model_with_messages(state: DocQueryState, config: RunnableConfig) -> dict: + if config.get("configurable", {}).get("model_name", None): + model_name = config.get("configurable", {}).get("model_name", None) + else: + model_name = os.environ.get("RETRIEVAL_MODEL", os.environ["DEFAULT_MODEL"]) + + model = get_langchain_model(model_name) system_prompt = Prompter(prompt_template="ask_content").render(data=state) logger.debug(f"System prompt: {system_prompt}") ai_message = model.invoke(system_prompt) @@ -32,7 +34,7 @@ def call_model_with_messages(state: AskState, config: RunnableConfig) -> dict: # todo: there is probably a better way to do this and avoid repetition -def get_content(state: AskState) -> dict: +def get_content(state: DocQueryState) -> dict: doc_id = state["doc_id"] if "note:" in doc_id: doc: Note = Note.get(id=doc_id) @@ -42,7 +44,7 @@ def get_content(state: AskState) -> dict: return {"doc_content": doc_content} -agent_state = StateGraph(AskState) +agent_state = StateGraph(DocQueryState) agent_state.add_node("get_content", get_content) agent_state.add_node("agent", call_model_with_messages) agent_state.add_edge(START, "get_content") diff --git a/open_notebook/graphs/tools.py b/open_notebook/graphs/tools.py index 2c35c6f..636e25b 100644 --- a/open_notebook/graphs/tools.py +++ b/open_notebook/graphs/tools.py @@ -6,19 +6,21 @@ from langchain.tools import tool @tool def get_current_timestamp() -> str: """ + name: get_current_timestamp Returns the current timestamp in the format YYYYMMDDHHmmss. """ return datetime.now().strftime("%Y%m%d%H%M%S") @tool -def ask_the_document(doc_id: str, question: str): +def doc_query(doc_id: str, question: str): """ - Use this tool to ask a question to the document. - Another LLM will ready the document and answer the question. - Be specific and complete in your query given the LLM that will process it is very capable. + name: doc_query + Use this tool if you need to investigate into a particular document. + Another LLM will read the document and answer the question that you might have. + Use this when the user question cannot be answered with the content you have in context. """ - from open_notebook.graphs.ask_content import graph + from open_notebook.graphs.doc_query import graph result = graph.invoke({"doc_id": doc_id, "question": question}) return result["answer"]