open-notebook/open_notebook/graphs/ask.py
Luis Novo 20e18fdd0d feat: improve error clarity for LLM provider failures (#506)
Replace generic "An unexpected error occurred" messages with descriptive,
user-friendly error messages when LLM operations fail. Errors like invalid
API keys, wrong model names, and rate limits now surface clearly in the UI.

Adds error classification utility, global FastAPI exception handlers, and
frontend getApiErrorMessage() helper. Bumps version to 1.7.2.
2026-02-16 16:15:46 -03:00

166 lines
5.5 KiB
Python

import operator
from typing import Annotated, List
from ai_prompter import Prompter
from langchain_core.output_parsers.pydantic import PydanticOutputParser
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, StateGraph
from langgraph.types import Send
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from open_notebook.ai.provision import provision_langchain_model
from open_notebook.domain.notebook import vector_search
from open_notebook.exceptions import OpenNotebookError
from open_notebook.utils import clean_thinking_content
from open_notebook.utils.error_classifier import classify_error
class SubGraphState(TypedDict):
    """State seen by a single ``provide_answer`` run.

    ``trigger_queries`` supplies ``question``, ``term`` and ``instructions``.
    ``results`` and ``ids`` are computed inside ``provide_answer`` and handed
    to its prompt template.
    """

    question: str  # question being answered
    term: str  # search term for this sub-query
    instructions: str  # what should be extracted from the search hits
    # Search hits. provide_answer takes len() of this value and reads
    # r["id"] from every element, so it is a sequence of records, not a mapping.
    results: list
    answer: str
    ids: list  # ids of the search hits, derived in provide_answer
class Search(BaseModel):
    # A single search requested by the strategy model: the term to look up
    # plus instructions for the model that will read the search hits.
    #
    # NOTE(review): documented with comments rather than a docstring on
    # purpose. Pydantic copies a model docstring into the generated JSON
    # schema, and this model is handed to PydanticOutputParser, so a docstring
    # would presumably alter the format instructions sent to the LLM.
    term: str
    instructions: str = Field(
        description="Tell the answering LLM what information you need extracted from this search"
    )
class Strategy(BaseModel):
    # Search plan produced by the strategy model: free-text reasoning plus the
    # list of searches to run (empty when none are supplied).
    #
    # NOTE(review): comments are used instead of a docstring because pydantic
    # copies a model docstring into the JSON schema; this model is the
    # pydantic_object of a PydanticOutputParser, so a docstring would
    # presumably change the format instructions.
    reasoning: str
    searches: List[Search] = Field(
        description="You can add up to five searches to this strategy",
        default_factory=list,
    )
class ThreadState(TypedDict):
    """Top-level state of the ask graph."""

    # Question the graph is answering.
    question: str
    # Search plan written by call_model_with_messages.
    strategy: Strategy
    # Per-search answers; operator.add is attached as the reducer so lists
    # returned by separate provide_answer runs are concatenated.
    answers: Annotated[list, operator.add]
    # Text produced by write_final_answer.
    final_answer: str
async def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict:
    """Ask the strategy model for a search plan and parse it into a ``Strategy``.

    Args:
        state: Current thread state; passed as data to the ``ask/entry`` prompt.
        config: Runnable config; ``configurable["strategy_model"]`` (if present)
            is forwarded to ``provision_langchain_model``.

    Returns:
        ``{"strategy": Strategy}`` parsed from the model's cleaned response.

    Raises:
        OpenNotebookError: Propagated unchanged when raised inside the body.
        Exception: Any other failure is re-raised as the error class and
            user-facing message chosen by ``classify_error``, chained to the
            original exception.
    """
    try:
        strategy_parser = PydanticOutputParser(pydantic_object=Strategy)
        prompt_builder = Prompter(prompt_template="ask/entry", parser=strategy_parser)  # type: ignore[arg-type]
        system_prompt = prompt_builder.render(data=state)  # type: ignore[arg-type]

        model_id = config.get("configurable", {}).get("strategy_model")
        model = await provision_langchain_model(
            system_prompt,
            model_id,
            "tools",
            max_tokens=2000,
            structured={"type": "json"},
        )

        response = await model.ainvoke(system_prompt)

        # The message content is not guaranteed to be a string; coerce it
        # before stripping thinking content.
        raw_text = response.content
        if not isinstance(raw_text, str):
            raw_text = str(raw_text)

        # Parse the cleaned JSON into the Strategy model.
        plan = strategy_parser.parse(clean_thinking_content(raw_text))
        return {"strategy": plan}
    except OpenNotebookError:
        raise
    except Exception as exc:
        error_class, user_message = classify_error(exc)
        raise error_class(user_message) from exc
async def trigger_queries(state: ThreadState, config: RunnableConfig):
    """Build one ``Send`` to ``provide_answer`` for each search in the strategy.

    Each payload carries the thread's question together with that search's
    instructions and term. ``config`` is accepted but not used.

    Returns:
        A list of ``Send`` objects; empty when the strategy has no searches.
    """
    dispatches = []
    for search in state["strategy"].searches:
        sub_state = {
            "question": state["question"],
            "instructions": search.instructions,
            "term": search.term,
        }
        dispatches.append(Send("provide_answer", sub_state))
    return dispatches
async def provide_answer(state: SubGraphState, config: RunnableConfig) -> dict:
    """Run one search and have the answer model respond from its hits.

    Args:
        state: Sub-query state; ``state["term"]`` is the search term and the
            whole mapping (plus ``results`` and ``ids``) is passed as data to
            the ``ask/query_process`` prompt. The mapping is not modified.
        config: Runnable config; ``configurable["answer_model"]`` (if present)
            is forwarded to ``provision_langchain_model``.

    Returns:
        ``{"answers": [text]}`` with the cleaned model response, or
        ``{"answers": []}`` when the search yields nothing.

    Raises:
        OpenNotebookError: Propagated unchanged when raised inside the body.
        Exception: Any other failure is re-raised as the error class and
            user-facing message chosen by ``classify_error``, chained to the
            original exception.
    """
    try:
        # NOTE(review): (10, True, True) are passed positionally; presumably a
        # result limit and two search-scope flags - confirm against
        # vector_search's signature.
        results = await vector_search(state["term"], 10, True, True)
        if not results:
            return {"answers": []}

        # Build a fresh payload instead of aliasing ``state`` so the caller's
        # state mapping is never mutated by this node.
        payload = {
            **state,
            "results": results,
            "ids": [r["id"] for r in results],
        }
        system_prompt = Prompter(prompt_template="ask/query_process").render(data=payload)  # type: ignore[arg-type]

        model = await provision_langchain_model(
            system_prompt,
            config.get("configurable", {}).get("answer_model"),
            "tools",
            max_tokens=2000,
        )
        ai_message = await model.ainvoke(system_prompt)

        # The message content is not guaranteed to be a string; coerce it
        # before stripping thinking content.
        ai_content = (
            ai_message.content
            if isinstance(ai_message.content, str)
            else str(ai_message.content)
        )
        return {"answers": [clean_thinking_content(ai_content)]}
    except OpenNotebookError:
        raise
    except Exception as e:
        error_class, user_message = classify_error(e)
        raise error_class(user_message) from e
async def write_final_answer(state: ThreadState, config: RunnableConfig) -> dict:
    """Produce the final answer from the thread state.

    Args:
        state: Thread state; passed as data to the ``ask/final_answer`` prompt.
        config: Runnable config; ``configurable["final_answer_model"]`` (if
            present) is forwarded to ``provision_langchain_model``.

    Returns:
        ``{"final_answer": text}`` with the cleaned model response.

    Raises:
        OpenNotebookError: Propagated unchanged when raised inside the body.
        Exception: Any other failure is re-raised as the error class and
            user-facing message chosen by ``classify_error``, chained to the
            original exception.
    """
    try:
        prompt_builder = Prompter(prompt_template="ask/final_answer")
        system_prompt = prompt_builder.render(data=state)  # type: ignore[arg-type]

        model_id = config.get("configurable", {}).get("final_answer_model")
        model = await provision_langchain_model(
            system_prompt,
            model_id,
            "tools",
            max_tokens=2000,
        )

        response = await model.ainvoke(system_prompt)

        # The message content is not guaranteed to be a string; coerce it
        # before stripping thinking content.
        answer_text = response.content
        if not isinstance(answer_text, str):
            answer_text = str(answer_text)

        return {"final_answer": clean_thinking_content(answer_text)}
    except OpenNotebookError:
        raise
    except Exception as exc:
        error_class, user_message = classify_error(exc)
        raise error_class(user_message) from exc
# Assemble the ask workflow over ThreadState.
agent_state = StateGraph(ThreadState)
# Nodes: strategy planning, per-search answering, final synthesis.
agent_state.add_node("agent", call_model_with_messages)
agent_state.add_node("provide_answer", provide_answer)
agent_state.add_node("write_final_answer", write_final_answer)
# Entry point: the strategy node runs first.
agent_state.add_edge(START, "agent")
# Fan-out: trigger_queries returns one Send per planned search, each
# targeting "provide_answer".
agent_state.add_conditional_edges("agent", trigger_queries, ["provide_answer"])
# Answers feed the final-answer node, which then ends the graph.
agent_state.add_edge("provide_answer", "write_final_answer")
agent_state.add_edge("write_final_answer", END)
# Compiled graph exposed at module level.
graph = agent_state.compile()