Langgraph update (#131)
This commit is contained in:
parent
a66cffbcc4
commit
65a3b74fea
1 changed files with 27 additions and 36 deletions
|
|
@ -3,9 +3,9 @@ from datetime import datetime
|
|||
|
||||
from configuration import AgentConfigurable
|
||||
from langchain_arcade import ArcadeToolManager
|
||||
from langchain_core.messages import HumanMessage, ToolMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.errors import NodeInterrupt
|
||||
from langgraph.graph import END, START, MessagesState, StateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
|
|
@ -34,6 +34,10 @@ model = ChatOpenAI(model="gpt-4o", api_key=openai_api_key).bind_tools(tools)
|
|||
prompted_model = prompt | model
|
||||
|
||||
|
||||
class AgentState(MessagesState):
|
||||
auth_url: str | None = None
|
||||
|
||||
|
||||
def call_agent(state):
|
||||
"""Define the agent function that invokes the model"""
|
||||
messages = state["messages"]
|
||||
|
|
@ -42,66 +46,53 @@ def call_agent(state):
|
|||
return {"messages": [response]}
|
||||
|
||||
|
||||
def should_continue(state: MessagesState, config: dict):
|
||||
def should_continue(state: AgentState, config: dict):
|
||||
"""Function to determine the next step based on the model's response"""
|
||||
last_message = state["messages"][-1]
|
||||
if last_message.tool_calls:
|
||||
user_id = config["configurable"].get("user_id")
|
||||
tool_name = state["messages"][-1].tool_calls[0]["name"]
|
||||
auth_response = toolkit.authorize(tool_name, user_id)
|
||||
if auth_response.status == "completed":
|
||||
return "tools"
|
||||
else:
|
||||
# If the tool requires authorization, proceed to the authorization step
|
||||
return "authorization"
|
||||
return "check_auth"
|
||||
# If no tool calls are present, end the workflow
|
||||
return END
|
||||
|
||||
|
||||
def wait_for_auth(state: MessagesState):
|
||||
last_message = state["messages"][-1]
|
||||
if isinstance(last_message, HumanMessage):
|
||||
return "agent"
|
||||
return "tools"
|
||||
def check_auth(state: AgentState, config: dict):
|
||||
user_id = config["configurable"].get("user_id")
|
||||
tool_name = state["messages"][-1].tool_calls[0]["name"]
|
||||
auth_response = toolkit.authorize(tool_name, user_id)
|
||||
if auth_response.status != "completed":
|
||||
return {"auth_url": auth_response.authorization_url}
|
||||
else:
|
||||
return {"auth_url": None}
|
||||
|
||||
|
||||
def authorize(state: MessagesState, config: dict):
|
||||
def authorize(state: AgentState, config: dict):
|
||||
"""Function to handle tool authorization"""
|
||||
user_id = config["configurable"].get("user_id")
|
||||
tool_name = state["messages"][-1].tool_calls[0]["name"]
|
||||
auth_response = toolkit.authorize(tool_name, user_id)
|
||||
|
||||
auth_message = (
|
||||
f"Please authorize the application in your browser:\n\n {auth_response.authorization_url}"
|
||||
)
|
||||
tool_call_id = state["messages"][-1].tool_calls[0]["id"]
|
||||
response = ToolMessage(
|
||||
content=auth_message,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
# Add the new message to the message history and add a new human message
|
||||
# saying that the agent should try again
|
||||
try_message = HumanMessage(
|
||||
content="Please try the previous tool call again now that you are authorized."
|
||||
)
|
||||
return {"messages": [response, try_message]}
|
||||
if auth_response.status != "completed":
|
||||
auth_message = (
|
||||
f"Please authorize the application in your browser:\n\n {state.get('auth_url')}"
|
||||
)
|
||||
raise NodeInterrupt(auth_message)
|
||||
|
||||
|
||||
# Build the workflow graph
|
||||
workflow = StateGraph(MessagesState, AgentConfigurable)
|
||||
workflow = StateGraph(AgentState, AgentConfigurable)
|
||||
|
||||
# Add nodes to the graph
|
||||
workflow.add_node("agent", call_agent)
|
||||
workflow.add_node("tools", tool_node)
|
||||
workflow.add_node("authorization", authorize)
|
||||
# workflow.add_node("wait_for_auth", wait_for_auth)
|
||||
workflow.add_node("check_auth", check_auth)
|
||||
|
||||
# Define the edges and control flow
|
||||
workflow.add_edge(START, "agent")
|
||||
workflow.add_conditional_edges("agent", should_continue, ["authorization", "tools", END])
|
||||
workflow.add_edge("authorization", "agent")
|
||||
workflow.add_conditional_edges("agent", should_continue, ["check_auth", END])
|
||||
workflow.add_edge("check_auth", "authorization")
|
||||
workflow.add_edge("authorization", "tools")
|
||||
workflow.add_edge("tools", "agent")
|
||||
|
||||
# Compile the graph with an interrupt after the authorization node
|
||||
# so that we can prompt the user to authorize the application
|
||||
graph = workflow.compile(interrupt_after=["authorization"])
|
||||
graph = workflow.compile()
|
||||
|
|
|
|||
Loading…
Reference in a new issue