openai-agents-python/tests/test_handoff_tool.py

286 lines
8.3 KiB
Python

import json
from typing import Any
import pytest
from openai.types.responses import ResponseOutputMessage, ResponseOutputText
from pydantic import BaseModel
from agents import (
Agent,
Handoff,
HandoffInputData,
MessageOutputItem,
ModelBehaviorError,
RunContextWrapper,
UserError,
handoff,
)
from agents.run import AgentRunner
def message_item(content: str, agent: Agent[Any]) -> MessageOutputItem:
    """Wrap ``content`` in a completed assistant ``MessageOutputItem`` for ``agent``.

    The raw item is a single-part ``output_text`` message with the fixed id
    ``"123"`` and no annotations. In this module it is only used to populate
    item tuples whose lengths are counted.
    """
    text_part = ResponseOutputText(text=content, type="output_text", annotations=[])
    raw_message = ResponseOutputMessage(
        id="123",
        status="completed",
        role="assistant",
        type="message",
        content=[text_part],
    )
    return MessageOutputItem(agent=agent, raw_item=raw_message)
def get_len(data: HandoffInputData) -> int:
    """Return the total number of items carried by ``data``.

    A tuple ``input_history`` contributes one count per element; any other
    value (e.g. a plain string) counts as exactly one item. The lengths of
    ``pre_handoff_items`` and ``new_items`` are added on top.
    """
    history = data.input_history
    if isinstance(history, tuple):
        history_count = len(history)
    else:
        # A non-tuple history (such as a str prompt) is a single item.
        history_count = 1
    return sum((history_count, len(data.pre_handoff_items), len(data.new_items)))
def test_single_handoff_setup():
    """A single agent listed in ``handoffs`` resolves to one default ``Handoff``."""
    target = Agent(name="test_1")
    source = Agent(name="test_2", handoffs=[target])

    # Only the source agent declares a handoff.
    assert not target.handoffs
    assert source.handoffs == [target]
    assert not AgentRunner._get_handoffs(target)

    resolved = AgentRunner._get_handoffs(source)
    assert len(resolved) == 1
    [only] = resolved
    assert only.tool_name == Handoff.default_tool_name(target)
    assert only.tool_description == Handoff.default_tool_description(target)
    assert only.agent_name == target.name
def test_multiple_handoffs_setup():
    """Each target agent yields one default ``Handoff``, in declaration order."""
    targets = [Agent(name="test_1"), Agent(name="test_2")]
    source = Agent(name="test_3", handoffs=list(targets))

    assert source.handoffs == targets
    for target in targets:
        assert not target.handoffs

    resolved = AgentRunner._get_handoffs(source)
    assert len(resolved) == 2
    for handoff_obj, target in zip(resolved, targets):
        assert handoff_obj.tool_name == Handoff.default_tool_name(target)
        assert handoff_obj.tool_description == Handoff.default_tool_description(target)
        assert handoff_obj.agent_name == target.name
def test_custom_handoff_setup():
    """Plain agents and ``handoff(...)`` objects can be mixed in ``handoffs``.

    The plain agent gets the default tool name and description; the explicit
    ``handoff`` keeps its overrides.
    """
    plain_target = Agent(name="test_1")
    custom_target = Agent(name="test_2")
    custom = handoff(
        custom_target,
        tool_name_override="custom_tool_name",
        tool_description_override="custom tool description",
    )
    source = Agent(name="test_3", handoffs=[plain_target, custom])

    assert len(source.handoffs) == 2
    assert not plain_target.handoffs
    assert not custom_target.handoffs

    resolved = AgentRunner._get_handoffs(source)
    assert len(resolved) == 2
    default_obj, custom_obj = resolved

    assert isinstance(default_obj, Handoff)
    assert default_obj.tool_name == Handoff.default_tool_name(plain_target)
    assert default_obj.tool_description == Handoff.default_tool_description(plain_target)
    assert default_obj.agent_name == plain_target.name

    assert isinstance(custom_obj, Handoff)
    assert custom_obj.tool_name == "custom_tool_name"
    assert custom_obj.tool_description == "custom tool description"
    assert custom_obj.agent_name == custom_target.name
# Minimal pydantic model used as the handoff ``input_type`` in the tests below.
# NOTE: intentionally no class docstring. Pydantic copies a model docstring
# into the generated JSON schema as ``description``, which would change the
# schema these tests compare against.
class Foo(BaseModel):
    bar: str  # single required string field
@pytest.mark.asyncio
async def test_handoff_input_type():
    """``handoff(input_type=Foo)`` exposes Foo's schema and validates JSON input.

    Malformed or empty JSON must raise ``ModelBehaviorError`` without reaching
    the callback. Well-formed JSON must invoke the callback with the parsed
    ``Foo`` and resolve to the target agent.
    """
    # Records every payload the callback receives, so the test can check both
    # that it was called and what it was called with.
    received: list[Foo] = []

    async def _on_handoff(ctx: RunContextWrapper[Any], payload: Foo):
        received.append(payload)

    agent = Agent(name="test")
    obj = handoff(agent, input_type=Foo, on_handoff=_on_handoff)

    # The handoff schema must contain every key of Foo's own schema.
    for key, value in Foo.model_json_schema().items():
        assert obj.input_json_schema[key] == value

    # Invalid JSON should raise an error.
    with pytest.raises(ModelBehaviorError):
        await obj.on_invoke_handoff(RunContextWrapper(agent), "not json")

    # Empty JSON should raise an error.
    with pytest.raises(ModelBehaviorError):
        await obj.on_invoke_handoff(RunContextWrapper(agent), "")

    # Neither rejected input may have reached the callback.
    assert not received

    # Valid JSON should call the on_handoff function with the parsed model.
    invoked = await obj.on_invoke_handoff(
        RunContextWrapper(agent), Foo(bar="baz").model_dump_json()
    )
    assert invoked == agent
    assert received == [Foo(bar="baz")]
@pytest.mark.asyncio
async def test_on_handoff_called():
    """An async ``on_handoff`` with typed input runs when the handoff is invoked."""
    invocations: list[Foo] = []

    async def _on_handoff(ctx: RunContextWrapper[Any], payload: Foo):
        invocations.append(payload)

    target = Agent(name="test")
    obj = handoff(target, input_type=Foo, on_handoff=_on_handoff)

    schema = obj.input_json_schema
    for key, expected in Foo.model_json_schema().items():
        assert schema[key] == expected

    result = await obj.on_invoke_handoff(
        RunContextWrapper(target), Foo(bar="baz").model_dump_json()
    )
    assert result == target
    assert invocations, "on_handoff should have been called"
@pytest.mark.asyncio
async def test_on_handoff_without_input_called():
    """A synchronous, input-less ``on_handoff`` runs and the target agent is returned."""
    calls: list[RunContextWrapper[Any]] = []

    # Deliberately a plain (non-async) function: this test covers sync callbacks.
    def _on_handoff(ctx: RunContextWrapper[Any]):
        calls.append(ctx)

    target = Agent(name="test")
    obj = handoff(target, on_handoff=_on_handoff)
    result = await obj.on_invoke_handoff(RunContextWrapper(target), "")
    assert result == target
    assert calls, "on_handoff should have been called"
@pytest.mark.asyncio
async def test_async_on_handoff_without_input_called():
    """An async, input-less ``on_handoff`` is awaited and the target agent is returned."""
    calls: list[RunContextWrapper[Any]] = []

    async def _on_handoff(ctx: RunContextWrapper[Any]):
        calls.append(ctx)

    target = Agent(name="test")
    obj = handoff(target, on_handoff=_on_handoff)
    result = await obj.on_invoke_handoff(RunContextWrapper(target), "")
    assert result == target
    assert calls, "on_handoff should have been called"
@pytest.mark.asyncio
async def test_invalid_on_handoff_raises_error():
    """A two-argument ``on_handoff`` without ``input_type`` is rejected with ``UserError``."""

    # Takes an extra argument even though no input_type is given; handoff()
    # must refuse it before it could ever run.
    async def _on_handoff(ctx: RunContextWrapper[Any], unexpected: str):
        pass  # pragma: no cover

    target = Agent(name="test")
    with pytest.raises(UserError):
        # Purposely ignoring the type error here to simulate invalid input
        handoff(target, on_handoff=_on_handoff)  # type: ignore
def test_handoff_input_data():
    """``get_len`` sums history, pre-handoff items and new items.

    Each case is ``(input_history, pre_handoff_items, new_items, expected)``.
    A string history counts as one item; a tuple history counts per element.
    """
    agent = Agent(name="test")
    user_turn = {"role": "user", "content": "foo"}
    assistant_turn = {"role": "assistant", "content": "bar"}

    cases = [
        ("", (), (), 1),
        ((user_turn,), (), (), 1),
        ((user_turn, assistant_turn), (), (), 2),
        (
            (user_turn,),
            (message_item("foo", agent), message_item("foo2", agent)),
            (message_item("bar", agent), message_item("baz", agent)),
            5,
        ),
        (
            (user_turn, assistant_turn),
            (message_item("baz", agent),),
            (message_item("baz", agent), message_item("qux", agent)),
            5,
        ),
    ]

    for history, pre_items, new_items, expected in cases:
        data = HandoffInputData(
            input_history=history,
            pre_handoff_items=pre_items,
            new_items=new_items,
        )
        assert get_len(data) == expected
def test_handoff_input_schema_is_strict():
    """Handoff input schemas include Foo's keys and are strict (no extra properties)."""
    target = Agent(name="test")
    obj = handoff(target, input_type=Foo, on_handoff=lambda _ctx, _payload: None)

    schema = obj.input_json_schema
    for key, expected in Foo.model_json_schema().items():
        assert schema[key] == expected

    assert obj.strict_json_schema, "Input schema should be strict"
    has_flag = "additionalProperties" in schema
    assert has_flag and not schema["additionalProperties"], (
        "Input schema should be strict and have additionalProperties=False"
    )
def test_get_transfer_message_is_valid_json() -> None:
    """The transfer message is JSON naming the agent under the ``"assistant"`` key."""
    target = Agent(name="foo")
    message = handoff(target).get_transfer_message(target)
    payload = json.loads(message)
    assert payload == {"assistant": target.name}