From 65a3b74fea14e21cbacf449441eef7db3558877b Mon Sep 17 00:00:00 2001 From: Sam Partee Date: Mon, 28 Oct 2024 17:08:24 -0700 Subject: [PATCH] Langgraph update (#131) --- examples/langchain/studio/graph.py | 63 +++++++++++++----------------- 1 file changed, 27 insertions(+), 36 deletions(-) diff --git a/examples/langchain/studio/graph.py b/examples/langchain/studio/graph.py index b7615178..f81513a5 100644 --- a/examples/langchain/studio/graph.py +++ b/examples/langchain/studio/graph.py @@ -3,9 +3,9 @@ from datetime import datetime from configuration import AgentConfigurable from langchain_arcade import ArcadeToolManager -from langchain_core.messages import HumanMessage, ToolMessage from langchain_core.prompts import ChatPromptTemplate from langchain_openai import ChatOpenAI +from langgraph.errors import NodeInterrupt from langgraph.graph import END, START, MessagesState, StateGraph from langgraph.prebuilt import ToolNode @@ -34,6 +34,10 @@ model = ChatOpenAI(model="gpt-4o", api_key=openai_api_key).bind_tools(tools) prompted_model = prompt | model +class AgentState(MessagesState): + auth_url: str | None = None + + def call_agent(state): """Define the agent function that invokes the model""" messages = state["messages"] @@ -42,66 +46,53 @@ def call_agent(state): return {"messages": [response]} -def should_continue(state: MessagesState, config: dict): +def should_continue(state: AgentState, config: dict): """Function to determine the next step based on the model's response""" last_message = state["messages"][-1] if last_message.tool_calls: - user_id = config["configurable"].get("user_id") - tool_name = state["messages"][-1].tool_calls[0]["name"] - auth_response = toolkit.authorize(tool_name, user_id) - if auth_response.status == "completed": - return "tools" - else: - # If the tool requires authorization, proceed to the authorization step - return "authorization" + return "check_auth" # If no tool calls are present, end the workflow return END -def wait_for_auth(state: MessagesState): - last_message = state["messages"][-1] - if isinstance(last_message, HumanMessage): - return "agent" - return "tools" +def check_auth(state: AgentState, config: dict): + user_id = config["configurable"].get("user_id") + tool_name = state["messages"][-1].tool_calls[0]["name"] + auth_response = toolkit.authorize(tool_name, user_id) + if auth_response.status != "completed": + return {"auth_url": auth_response.authorization_url} + else: + return {"auth_url": None} -def authorize(state: MessagesState, config: dict): +def authorize(state: AgentState, config: dict): """Function to handle tool authorization""" user_id = config["configurable"].get("user_id") tool_name = state["messages"][-1].tool_calls[0]["name"] auth_response = toolkit.authorize(tool_name, user_id) - - auth_message = ( - f"Please authorize the application in your browser:\n\n {auth_response.authorization_url}" - ) - tool_call_id = state["messages"][-1].tool_calls[0]["id"] - response = ToolMessage( - content=auth_message, - tool_call_id=tool_call_id, - ) - # Add the new message to the message history and add a new human message - # saying that the agent should try again - try_message = HumanMessage( - content="Please try the previous tool call again now that you are authorized." - ) - return {"messages": [response, try_message]} + if auth_response.status != "completed": + auth_message = ( + f"Please authorize the application in your browser:\n\n {state.get('auth_url')}" + ) + raise NodeInterrupt(auth_message) # Build the workflow graph -workflow = StateGraph(MessagesState, AgentConfigurable) +workflow = StateGraph(AgentState, AgentConfigurable) # Add nodes to the graph workflow.add_node("agent", call_agent) workflow.add_node("tools", tool_node) workflow.add_node("authorization", authorize) -# workflow.add_node("wait_for_auth", wait_for_auth) +workflow.add_node("check_auth", check_auth) # Define the edges and control flow workflow.add_edge(START, "agent") -workflow.add_conditional_edges("agent", should_continue, ["authorization", "tools", END]) -workflow.add_edge("authorization", "agent") +workflow.add_conditional_edges("agent", should_continue, ["check_auth", END]) +workflow.add_edge("check_auth", "authorization") +workflow.add_edge("authorization", "tools") workflow.add_edge("tools", "agent") # Compile the graph with an interrupt after the authorization node # so that we can prompt the user to authorize the application -graph = workflow.compile(interrupt_after=["authorization"]) +graph = workflow.compile()