Langgraph update (#131)

This commit is contained in:
Sam Partee 2024-10-28 17:08:24 -07:00 committed by GitHub
parent a66cffbcc4
commit 65a3b74fea
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -3,9 +3,9 @@ from datetime import datetime
from configuration import AgentConfigurable
from langchain_arcade import ArcadeToolManager
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langgraph.errors import NodeInterrupt
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode
@ -34,6 +34,10 @@ model = ChatOpenAI(model="gpt-4o", api_key=openai_api_key).bind_tools(tools)
prompted_model = prompt | model
class AgentState(MessagesState):
auth_url: str | None = None
def call_agent(state):
"""Define the agent function that invokes the model"""
messages = state["messages"]
@ -42,66 +46,53 @@ def call_agent(state):
return {"messages": [response]}
def should_continue(state: MessagesState, config: dict):
def should_continue(state: AgentState, config: dict):
"""Function to determine the next step based on the model's response"""
last_message = state["messages"][-1]
if last_message.tool_calls:
user_id = config["configurable"].get("user_id")
tool_name = state["messages"][-1].tool_calls[0]["name"]
auth_response = toolkit.authorize(tool_name, user_id)
if auth_response.status == "completed":
return "tools"
else:
# If the tool requires authorization, proceed to the authorization step
return "authorization"
return "check_auth"
# If no tool calls are present, end the workflow
return END
def wait_for_auth(state: MessagesState):
last_message = state["messages"][-1]
if isinstance(last_message, HumanMessage):
return "agent"
return "tools"
def check_auth(state: AgentState, config: dict):
user_id = config["configurable"].get("user_id")
tool_name = state["messages"][-1].tool_calls[0]["name"]
auth_response = toolkit.authorize(tool_name, user_id)
if auth_response.status != "completed":
return {"auth_url": auth_response.authorization_url}
else:
return {"auth_url": None}
def authorize(state: MessagesState, config: dict):
def authorize(state: AgentState, config: dict):
"""Function to handle tool authorization"""
user_id = config["configurable"].get("user_id")
tool_name = state["messages"][-1].tool_calls[0]["name"]
auth_response = toolkit.authorize(tool_name, user_id)
auth_message = (
f"Please authorize the application in your browser:\n\n {auth_response.authorization_url}"
)
tool_call_id = state["messages"][-1].tool_calls[0]["id"]
response = ToolMessage(
content=auth_message,
tool_call_id=tool_call_id,
)
# Add the new message to the message history and add a new human message
# saying that the agent should try again
try_message = HumanMessage(
content="Please try the previous tool call again now that you are authorized."
)
return {"messages": [response, try_message]}
if auth_response.status != "completed":
auth_message = (
f"Please authorize the application in your browser:\n\n {state.get('auth_url')}"
)
raise NodeInterrupt(auth_message)
# Build the workflow graph
workflow = StateGraph(MessagesState, AgentConfigurable)
workflow = StateGraph(AgentState, AgentConfigurable)
# Add nodes to the graph
workflow.add_node("agent", call_agent)
workflow.add_node("tools", tool_node)
workflow.add_node("authorization", authorize)
# workflow.add_node("wait_for_auth", wait_for_auth)
workflow.add_node("check_auth", check_auth)
# Define the edges and control flow
workflow.add_edge(START, "agent")
workflow.add_conditional_edges("agent", should_continue, ["authorization", "tools", END])
workflow.add_edge("authorization", "agent")
workflow.add_conditional_edges("agent", should_continue, ["check_auth", END])
workflow.add_edge("check_auth", "authorization")
workflow.add_edge("authorization", "tools")
workflow.add_edge("tools", "agent")
# Compile the graph with an interrupt after the authorization node
# so that we can prompt the user to authorize the application
graph = workflow.compile(interrupt_after=["authorization"])
graph = workflow.compile()