from __future__ import annotations import json from typing import Any import pytest from typing_extensions import TypedDict from agents import ( Agent, GuardrailFunctionOutput, Handoff, HandoffInputData, InputGuardrail, InputGuardrailTripwireTriggered, OutputGuardrail, OutputGuardrailTripwireTriggered, RunContextWrapper, Runner, UserError, handoff, ) from agents.items import RunItem from agents.run import RunConfig from agents.stream_events import AgentUpdatedStreamEvent from .fake_model import FakeModel from .test_responses import ( get_final_output_message, get_function_tool, get_function_tool_call, get_handoff_tool_call, get_text_input_item, get_text_message, ) @pytest.mark.asyncio async def test_simple_first_run(): model = FakeModel() agent = Agent( name="test", model=model, ) model.set_next_output([get_text_message("first")]) result = Runner.run_streamed(agent, input="test") async for _ in result.stream_events(): pass assert result.input == "test" assert len(result.new_items) == 1, "exactly one item should be generated" assert result.final_output == "first" assert len(result.raw_responses) == 1, "exactly one model response should be generated" assert result.raw_responses[0].output == [get_text_message("first")] assert result.last_agent == agent assert len(result.to_input_list()) == 2, "should have original input and generated item" model.set_next_output([get_text_message("second")]) result = Runner.run_streamed( agent, input=[get_text_input_item("message"), get_text_input_item("another_message")] ) async for _ in result.stream_events(): pass assert len(result.new_items) == 1, "exactly one item should be generated" assert result.final_output == "second" assert len(result.raw_responses) == 1, "exactly one model response should be generated" assert len(result.to_input_list()) == 3, "should have original input and generated item" @pytest.mark.asyncio async def test_subsequent_runs(): model = FakeModel() agent = Agent( name="test", model=model, ) model.set_next_output([get_text_message("third")]) result = Runner.run_streamed(agent, input="test") async for _ in result.stream_events(): pass assert result.input == "test" assert len(result.new_items) == 1, "exactly one item should be generated" assert len(result.to_input_list()) == 2, "should have original input and generated item" model.set_next_output([get_text_message("fourth")]) result = Runner.run_streamed(agent, input=result.to_input_list()) async for _ in result.stream_events(): pass assert len(result.input) == 2, f"should have previous input but got {result.input}" assert len(result.new_items) == 1, "exactly one item should be generated" assert result.final_output == "fourth" assert len(result.raw_responses) == 1, "exactly one model response should be generated" assert result.raw_responses[0].output == [get_text_message("fourth")] assert result.last_agent == agent assert len(result.to_input_list()) == 3, "should have original input and generated items" @pytest.mark.asyncio async def test_tool_call_runs(): model = FakeModel() agent = Agent( name="test", model=model, tools=[get_function_tool("foo", "tool_result")], ) model.add_multiple_turn_outputs( [ # First turn: a message and tool call [get_text_message("a_message"), get_function_tool_call("foo", json.dumps({"a": "b"}))], # Second turn: text message [get_text_message("done")], ] ) result = Runner.run_streamed(agent, input="user_message") async for _ in result.stream_events(): pass assert result.final_output == "done" assert len(result.raw_responses) == 2, ( "should have two responses: the first which produces a tool call, and the second which" "handles the tool result" ) assert len(result.to_input_list()) == 5, ( "should have five inputs: the original input, the message, the tool call, the tool result " "and the done message" ) @pytest.mark.asyncio async def test_handoffs(): model = FakeModel() agent_1 = Agent( name="test", model=model, ) agent_2 = Agent( name="test", model=model, ) agent_3 = Agent( name="test", model=model, handoffs=[agent_1, agent_2], tools=[get_function_tool("some_function", "result")], ) model.add_multiple_turn_outputs( [ # First turn: a tool call [get_function_tool_call("some_function", json.dumps({"a": "b"}))], # Second turn: a message and a handoff [get_text_message("a_message"), get_handoff_tool_call(agent_1)], # Third turn: text message [get_text_message("done")], ] ) result = Runner.run_streamed(agent_3, input="user_message") async for _ in result.stream_events(): pass assert result.final_output == "done" assert len(result.raw_responses) == 3, "should have three model responses" assert len(result.to_input_list()) == 7, ( "should have 7 inputs: orig input, tool call, tool result, message, handoff, handoff" "result, and done message" ) assert result.last_agent == agent_1, "should have handed off to agent_1" class Foo(TypedDict): bar: str @pytest.mark.asyncio async def test_structured_output(): model = FakeModel() agent_1 = Agent( name="test", model=model, tools=[get_function_tool("bar", "bar_result")], output_type=Foo, ) agent_2 = Agent( name="test", model=model, tools=[get_function_tool("foo", "foo_result")], handoffs=[agent_1], ) model.add_multiple_turn_outputs( [ # First turn: a tool call [get_function_tool_call("foo", json.dumps({"bar": "baz"}))], # Second turn: a message and a handoff [get_text_message("a_message"), get_handoff_tool_call(agent_1)], # Third turn: tool call and structured output [ get_function_tool_call("bar", json.dumps({"bar": "baz"})), get_final_output_message(json.dumps(Foo(bar="baz"))), ], ] ) result = Runner.run_streamed( agent_2, input=[ get_text_input_item("user_message"), get_text_input_item("another_message"), ], ) async for _ in result.stream_events(): pass assert result.final_output == Foo(bar="baz") assert len(result.raw_responses) == 3, "should have three model responses" assert len(result.to_input_list()) == 10, ( "should have input: 2 orig inputs, function call, function call result, message, handoff, " "handoff output, tool call, tool call result, final output" ) assert result.last_agent == agent_1, "should have handed off to agent_1" assert result.final_output == Foo(bar="baz"), "should have structured output" def remove_new_items(handoff_input_data: HandoffInputData) -> HandoffInputData: return HandoffInputData( input_history=handoff_input_data.input_history, pre_handoff_items=(), new_items=(), ) @pytest.mark.asyncio async def test_handoff_filters(): model = FakeModel() agent_1 = Agent( name="test", model=model, ) agent_2 = Agent( name="test", model=model, handoffs=[ handoff( agent=agent_1, input_filter=remove_new_items, ) ], ) model.add_multiple_turn_outputs( [ [get_text_message("1"), get_text_message("2"), get_handoff_tool_call(agent_1)], [get_text_message("last")], ] ) result = Runner.run_streamed(agent_2, input="user_message") async for _ in result.stream_events(): pass assert result.final_output == "last" assert len(result.raw_responses) == 2, "should have two model responses" assert len(result.to_input_list()) == 2, ( "should only have 2 inputs: orig input and last message" ) @pytest.mark.asyncio async def test_async_input_filter_fails(): # DO NOT rename this without updating pyproject.toml model = FakeModel() agent_1 = Agent( name="test", model=model, ) async def on_invoke_handoff(_ctx: RunContextWrapper[Any], _input: str) -> Agent[Any]: return agent_1 async def invalid_input_filter(data: HandoffInputData) -> HandoffInputData: return data # pragma: no cover agent_2 = Agent[None]( name="test", model=model, handoffs=[ Handoff( tool_name=Handoff.default_tool_name(agent_1), tool_description=Handoff.default_tool_description(agent_1), input_json_schema={}, on_invoke_handoff=on_invoke_handoff, agent_name=agent_1.name, # Purposely ignoring the type error here to simulate invalid input input_filter=invalid_input_filter, # type: ignore ) ], ) model.add_multiple_turn_outputs( [ [get_text_message("1"), get_text_message("2"), get_handoff_tool_call(agent_1)], [get_text_message("last")], ] ) with pytest.raises(UserError): result = Runner.run_streamed(agent_2, input="user_message") async for _ in result.stream_events(): pass @pytest.mark.asyncio async def test_invalid_input_filter_fails(): model = FakeModel() agent_1 = Agent( name="test", model=model, ) async def on_invoke_handoff(_ctx: RunContextWrapper[Any], _input: str) -> Agent[Any]: return agent_1 def invalid_input_filter(data: HandoffInputData) -> HandoffInputData: # Purposely returning a string to simulate invalid output return "foo" # type: ignore agent_2 = Agent[None]( name="test", model=model, handoffs=[ Handoff( tool_name=Handoff.default_tool_name(agent_1), tool_description=Handoff.default_tool_description(agent_1), input_json_schema={}, on_invoke_handoff=on_invoke_handoff, agent_name=agent_1.name, input_filter=invalid_input_filter, ) ], ) model.add_multiple_turn_outputs( [ [get_text_message("1"), get_text_message("2"), get_handoff_tool_call(agent_1)], [get_text_message("last")], ] ) with pytest.raises(UserError): result = Runner.run_streamed(agent_2, input="user_message") async for _ in result.stream_events(): pass @pytest.mark.asyncio async def test_non_callable_input_filter_causes_error(): model = FakeModel() agent_1 = Agent( name="test", model=model, ) async def on_invoke_handoff(_ctx: RunContextWrapper[Any], _input: str) -> Agent[Any]: return agent_1 agent_2 = Agent[None]( name="test", model=model, handoffs=[ Handoff( tool_name=Handoff.default_tool_name(agent_1), tool_description=Handoff.default_tool_description(agent_1), input_json_schema={}, on_invoke_handoff=on_invoke_handoff, agent_name=agent_1.name, # Purposely ignoring the type error here to simulate invalid input input_filter="foo", # type: ignore ) ], ) model.add_multiple_turn_outputs( [ [get_text_message("1"), get_text_message("2"), get_handoff_tool_call(agent_1)], [get_text_message("last")], ] ) with pytest.raises(UserError): result = Runner.run_streamed(agent_2, input="user_message") async for _ in result.stream_events(): pass @pytest.mark.asyncio async def test_handoff_on_input(): call_output: str | None = None def on_input(_ctx: RunContextWrapper[Any], data: Foo) -> None: nonlocal call_output call_output = data["bar"] model = FakeModel() agent_1 = Agent( name="test", model=model, ) agent_2 = Agent( name="test", model=model, handoffs=[ handoff( agent=agent_1, on_handoff=on_input, input_type=Foo, ) ], ) model.add_multiple_turn_outputs( [ [ get_text_message("1"), get_text_message("2"), get_handoff_tool_call(agent_1, args=json.dumps(Foo(bar="test_input"))), ], [get_text_message("last")], ] ) result = Runner.run_streamed(agent_2, input="user_message") async for _ in result.stream_events(): pass assert result.final_output == "last" assert call_output == "test_input", "should have called the handoff with the correct input" @pytest.mark.asyncio async def test_async_handoff_on_input(): call_output: str | None = None async def on_input(_ctx: RunContextWrapper[Any], data: Foo) -> None: nonlocal call_output call_output = data["bar"] model = FakeModel() agent_1 = Agent( name="test", model=model, ) agent_2 = Agent( name="test", model=model, handoffs=[ handoff( agent=agent_1, on_handoff=on_input, input_type=Foo, ) ], ) model.add_multiple_turn_outputs( [ [ get_text_message("1"), get_text_message("2"), get_handoff_tool_call(agent_1, args=json.dumps(Foo(bar="test_input"))), ], [get_text_message("last")], ] ) result = Runner.run_streamed(agent_2, input="user_message") async for _ in result.stream_events(): pass assert result.final_output == "last" assert call_output == "test_input", "should have called the handoff with the correct input" @pytest.mark.asyncio async def test_input_guardrail_tripwire_triggered_causes_exception_streamed(): def guardrail_function( context: RunContextWrapper[Any], agent: Agent[Any], input: Any ) -> GuardrailFunctionOutput: return GuardrailFunctionOutput( output_info=None, tripwire_triggered=True, ) agent = Agent( name="test", input_guardrails=[InputGuardrail(guardrail_function=guardrail_function)], model=FakeModel(), ) with pytest.raises(InputGuardrailTripwireTriggered): result = Runner.run_streamed(agent, input="user_message") async for _ in result.stream_events(): pass @pytest.mark.asyncio async def test_output_guardrail_tripwire_triggered_causes_exception_streamed(): def guardrail_function( context: RunContextWrapper[Any], agent: Agent[Any], agent_output: Any ) -> GuardrailFunctionOutput: return GuardrailFunctionOutput( output_info=None, tripwire_triggered=True, ) model = FakeModel(initial_output=[get_text_message("first_test")]) agent = Agent( name="test", output_guardrails=[OutputGuardrail(guardrail_function=guardrail_function)], model=model, ) with pytest.raises(OutputGuardrailTripwireTriggered): result = Runner.run_streamed(agent, input="user_message") async for _ in result.stream_events(): pass @pytest.mark.asyncio async def test_run_input_guardrail_tripwire_triggered_causes_exception_streamed(): def guardrail_function( context: RunContextWrapper[Any], agent: Agent[Any], input: Any ) -> GuardrailFunctionOutput: return GuardrailFunctionOutput( output_info=None, tripwire_triggered=True, ) agent = Agent( name="test", model=FakeModel(), ) with pytest.raises(InputGuardrailTripwireTriggered): result = Runner.run_streamed( agent, input="user_message", run_config=RunConfig( input_guardrails=[InputGuardrail(guardrail_function=guardrail_function)] ), ) async for _ in result.stream_events(): pass @pytest.mark.asyncio async def test_run_output_guardrail_tripwire_triggered_causes_exception_streamed(): def guardrail_function( context: RunContextWrapper[Any], agent: Agent[Any], agent_output: Any ) -> GuardrailFunctionOutput: return GuardrailFunctionOutput( output_info=None, tripwire_triggered=True, ) model = FakeModel(initial_output=[get_text_message("first_test")]) agent = Agent( name="test", model=model, ) with pytest.raises(OutputGuardrailTripwireTriggered): result = Runner.run_streamed( agent, input="user_message", run_config=RunConfig( output_guardrails=[OutputGuardrail(guardrail_function=guardrail_function)] ), ) async for _ in result.stream_events(): pass @pytest.mark.asyncio async def test_streaming_events(): model = FakeModel() agent_1 = Agent( name="test", model=model, tools=[get_function_tool("bar", "bar_result")], output_type=Foo, ) agent_2 = Agent( name="test", model=model, tools=[get_function_tool("foo", "foo_result")], handoffs=[agent_1], ) model.add_multiple_turn_outputs( [ # First turn: a tool call [get_function_tool_call("foo", json.dumps({"bar": "baz"}))], # Second turn: a message and a handoff [get_text_message("a_message"), get_handoff_tool_call(agent_1)], # Third turn: tool call and structured output [ get_function_tool_call("bar", json.dumps({"bar": "baz"})), get_final_output_message(json.dumps(Foo(bar="baz"))), ], ] ) # event_type: (count, event) event_counts: dict[str, int] = {} item_data: list[RunItem] = [] agent_data: list[AgentUpdatedStreamEvent] = [] result = Runner.run_streamed( agent_2, input=[ get_text_input_item("user_message"), get_text_input_item("another_message"), ], ) async for event in result.stream_events(): event_counts[event.type] = event_counts.get(event.type, 0) + 1 if event.type == "run_item_stream_event": item_data.append(event.item) elif event.type == "agent_updated_stream_event": agent_data.append(event) assert result.final_output == Foo(bar="baz") assert len(result.raw_responses) == 3, "should have three model responses" assert len(result.to_input_list()) == 10, ( "should have input: 2 orig inputs, function call, function call result, message, handoff, " "handoff output, tool call, tool call result, final output" ) assert result.last_agent == agent_1, "should have handed off to agent_1" assert result.final_output == Foo(bar="baz"), "should have structured output" # Now lets check the events expected_item_type_map = { "tool_call": 2, "tool_call_output": 2, "message": 2, "handoff": 1, "handoff_output": 1, } total_expected_item_count = sum(expected_item_type_map.values()) assert event_counts["run_item_stream_event"] == total_expected_item_count, ( f"Expected {total_expected_item_count} events, got {event_counts['run_item_stream_event']}" f"Expected events were: {expected_item_type_map}, got {event_counts}" ) assert len(item_data) == total_expected_item_count, ( f"should have {total_expected_item_count} run items" ) assert len(agent_data) == 2, "should have 2 agent updated events" assert agent_data[0].new_agent == agent_2, "should have started with agent_2" assert agent_data[1].new_agent == agent_1, "should have handed off to agent_1"