Stronger tracing tests with inline-snapshot

This commit is contained in:
Alex Hall 2025-03-11 22:57:14 +02:00
parent c374ad064f
commit c03d314fb8
7 changed files with 875 additions and 5 deletions

View file

@ -47,6 +47,7 @@ dev = [
"mkdocstrings[python]>=0.28.0", "mkdocstrings[python]>=0.28.0",
"coverage>=7.6.12", "coverage>=7.6.12",
"playwright==1.50.0", "playwright==1.50.0",
"inline-snapshot>=0.20.5",
] ]
[tool.uv.workspace] [tool.uv.workspace]
members = ["agents"] members = ["agents"]
@ -116,4 +117,4 @@ filterwarnings = [
] ]
markers = [ markers = [
"allow_call_model_methods: mark test as allowing calls to real model implementations", "allow_call_model_methods: mark test as allowing calls to real model implementations",
] ]

View file

@ -3,12 +3,13 @@ from __future__ import annotations
import asyncio import asyncio
import pytest import pytest
from inline_snapshot import snapshot
from agents import Agent, RunConfig, Runner, trace from agents import Agent, RunConfig, Runner, trace
from .fake_model import FakeModel from .fake_model import FakeModel
from .test_responses import get_text_message from .test_responses import get_text_message
from .testing_processor import fetch_ordered_spans, fetch_traces from .testing_processor import fetch_normalized_spans, fetch_ordered_spans, fetch_traces
@pytest.mark.asyncio @pytest.mark.asyncio
@ -25,6 +26,25 @@ async def test_single_run_is_single_trace():
traces = fetch_traces() traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent",
"handoffs": [],
"tools": [],
"output_type": "str",
},
}
],
}
]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 1, ( assert len(spans) == 1, (
f"Got {len(spans)}, but expected 1: the agent span. data:" f"Got {len(spans)}, but expected 1: the agent span. data:"
@ -52,6 +72,39 @@ async def test_multiple_runs_are_multiple_traces():
traces = fetch_traces() traces = fetch_traces()
assert len(traces) == 2, f"Expected 2 traces, got {len(traces)}" assert len(traces) == 2, f"Expected 2 traces, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
}
],
},
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
}
],
},
]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 2, f"Got {len(spans)}, but expected 2: agent span per run" assert len(spans) == 2, f"Got {len(spans)}, but expected 2: agent span per run"
@ -79,6 +132,43 @@ async def test_wrapped_trace_is_single_trace():
traces = fetch_traces() traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "test_workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
},
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
},
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
},
],
}
]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 3, f"Got {len(spans)}, but expected 3: the agent span per run" assert len(spans) == 3, f"Got {len(spans)}, but expected 3: the agent span per run"
@ -97,6 +187,8 @@ async def test_parent_disabled_trace_disabled_agent_trace():
traces = fetch_traces() traces = fetch_traces()
assert len(traces) == 0, f"Expected 0 traces, got {len(traces)}" assert len(traces) == 0, f"Expected 0 traces, got {len(traces)}"
assert fetch_normalized_spans() == snapshot([])
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 0, ( assert len(spans) == 0, (
f"Expected no spans, got {len(spans)}, with {[x.span_data for x in spans]}" f"Expected no spans, got {len(spans)}, with {[x.span_data for x in spans]}"
@ -116,6 +208,8 @@ async def test_manual_disabling_works():
traces = fetch_traces() traces = fetch_traces()
assert len(traces) == 0, f"Expected 0 traces, got {len(traces)}" assert len(traces) == 0, f"Expected 0 traces, got {len(traces)}"
assert fetch_normalized_spans() == snapshot([])
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 0, f"Got {len(spans)}, but expected no spans" assert len(spans) == 0, f"Got {len(spans)}, but expected no spans"
@ -164,6 +258,25 @@ async def test_not_starting_streaming_creates_trace():
traces = fetch_traces() traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent",
"handoffs": [],
"tools": [],
"output_type": "str",
},
}
],
}
]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 1, f"Got {len(spans)}, but expected 1: the agent span" assert len(spans) == 1, f"Got {len(spans)}, but expected 1: the agent span"

View file

@ -1,4 +1,5 @@
import pytest import pytest
from inline_snapshot import snapshot
from openai import AsyncOpenAI from openai import AsyncOpenAI
from openai.types.responses import ResponseCompletedEvent from openai.types.responses import ResponseCompletedEvent
@ -6,7 +7,7 @@ from agents import ModelSettings, ModelTracing, OpenAIResponsesModel, trace
from agents.tracing.span_data import ResponseSpanData from agents.tracing.span_data import ResponseSpanData
from tests import fake_model from tests import fake_model
from .testing_processor import fetch_ordered_spans from .testing_processor import fetch_normalized_spans, fetch_ordered_spans
class DummyTracing: class DummyTracing:
@ -54,6 +55,15 @@ async def test_get_response_creates_trace(monkeypatch):
"instr", "input", ModelSettings(), [], None, [], ModelTracing.ENABLED "instr", "input", ModelSettings(), [], None, [], ModelTracing.ENABLED
) )
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "test",
"children": [{"type": "response", "data": {"response_id": "dummy-id"}}],
}
]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 1 assert len(spans) == 1
@ -82,6 +92,10 @@ async def test_non_data_tracing_doesnt_set_response_id(monkeypatch):
"instr", "input", ModelSettings(), [], None, [], ModelTracing.ENABLED_WITHOUT_DATA "instr", "input", ModelSettings(), [], None, [], ModelTracing.ENABLED_WITHOUT_DATA
) )
assert fetch_normalized_spans() == snapshot(
[{"workflow_name": "test", "children": [{"type": "response"}]}]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 1 assert len(spans) == 1
assert spans[0].span_data.response is None assert spans[0].span_data.response is None
@ -107,6 +121,8 @@ async def test_disable_tracing_does_not_create_span(monkeypatch):
"instr", "input", ModelSettings(), [], None, [], ModelTracing.DISABLED "instr", "input", ModelSettings(), [], None, [], ModelTracing.DISABLED
) )
assert fetch_normalized_spans() == snapshot([{"workflow_name": "test"}])
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 0 assert len(spans) == 0
@ -139,6 +155,15 @@ async def test_stream_response_creates_trace(monkeypatch):
): ):
pass pass
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "test",
"children": [{"type": "response", "data": {"response_id": "dummy-id-123"}}],
}
]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 1 assert len(spans) == 1
assert isinstance(spans[0].span_data, ResponseSpanData) assert isinstance(spans[0].span_data, ResponseSpanData)
@ -174,6 +199,10 @@ async def test_stream_non_data_tracing_doesnt_set_response_id(monkeypatch):
): ):
pass pass
assert fetch_normalized_spans() == snapshot(
[{"workflow_name": "test", "children": [{"type": "response"}]}]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 1 assert len(spans) == 1
assert isinstance(spans[0].span_data, ResponseSpanData) assert isinstance(spans[0].span_data, ResponseSpanData)
@ -208,5 +237,7 @@ async def test_stream_disabled_tracing_doesnt_create_span(monkeypatch):
): ):
pass pass
assert fetch_normalized_spans() == snapshot([{"workflow_name": "test"}])
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 0 assert len(spans) == 0

View file

@ -4,6 +4,7 @@ import json
from typing import Any from typing import Any
import pytest import pytest
from inline_snapshot import snapshot
from typing_extensions import TypedDict from typing_extensions import TypedDict
from agents import ( from agents import (
@ -27,7 +28,7 @@ from .test_responses import (
get_handoff_tool_call, get_handoff_tool_call,
get_text_message, get_text_message,
) )
from .testing_processor import fetch_ordered_spans, fetch_traces from .testing_processor import fetch_normalized_spans, fetch_ordered_spans, fetch_traces
@pytest.mark.asyncio @pytest.mark.asyncio
@ -45,6 +46,34 @@ async def test_single_turn_model_error():
traces = fetch_traces() traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent",
"handoffs": [],
"tools": [],
"output_type": "str",
},
"children": [
{
"type": "generation",
"error": {
"message": "Error",
"data": {"name": "ValueError", "message": "test error"},
},
}
],
}
],
}
]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 2, f"should have agent and generation spans, got {len(spans)}" assert len(spans) == 2, f"should have agent and generation spans, got {len(spans)}"
@ -80,6 +109,43 @@ async def test_multi_turn_no_handoffs():
traces = fetch_traces() traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent",
"handoffs": [],
"tools": ["foo"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {
"name": "foo",
"input": '{"a": "b"}',
"output": "tool_result",
},
},
{
"type": "generation",
"error": {
"message": "Error",
"data": {"name": "ValueError", "message": "test error"},
},
},
],
}
],
}
]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 4, ( assert len(spans) == 4, (
f"should have agent, generation, tool, generation, got {len(spans)} with data: " f"should have agent, generation, tool, generation, got {len(spans)} with data: "
@ -110,6 +176,39 @@ async def test_tool_call_error():
traces = fetch_traces() traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent",
"handoffs": [],
"tools": ["foo"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"error": {
"message": "Error running tool",
"data": {
"tool_name": "foo",
"error": "Invalid JSON input for tool foo: bad_json",
},
},
"data": {"name": "foo", "input": "bad_json"},
},
],
}
],
}
]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 3, ( assert len(spans) == 3, (
f"should have agent, generation, tool spans, got {len(spans)} with data: " f"should have agent, generation, tool spans, got {len(spans)} with data: "
@ -159,6 +258,43 @@ async def test_multiple_handoff_doesnt_error():
traces = fetch_traces() traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test",
"handoffs": ["test", "test"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {
"name": "some_function",
"input": '{"a": "b"}',
"output": "result",
},
},
{"type": "generation"},
{"type": "handoff", "data": {"from_agent": "test", "to_agent": "test"}},
],
},
{
"type": "agent",
"data": {"name": "test", "handoffs": [], "tools": [], "output_type": "str"},
"children": [{"type": "generation"}],
},
],
}
]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 7, ( assert len(spans) == 7, (
f"should have 2 agent, 1 function, 3 generation, 1 handoff, got {len(spans)} with data: " f"should have 2 agent, 1 function, 3 generation, 1 handoff, got {len(spans)} with data: "
@ -193,6 +329,21 @@ async def test_multiple_final_output_doesnt_error():
traces = fetch_traces() traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {"name": "test", "handoffs": [], "tools": [], "output_type": "Foo"},
"children": [{"type": "generation"}],
}
],
}
]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 2, ( assert len(spans) == 2, (
f"should have 1 agent, 1 generation, got {len(spans)} with data: " f"should have 1 agent, 1 generation, got {len(spans)} with data: "
@ -251,6 +402,76 @@ async def test_handoffs_lead_to_correct_agent_spans():
traces = fetch_traces() traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_3",
"handoffs": ["test_agent_1", "test_agent_2"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {
"name": "some_function",
"input": '{"a": "b"}',
"output": "result",
},
},
{"type": "generation"},
{
"type": "handoff",
"data": {"from_agent": "test_agent_3", "to_agent": "test_agent_1"},
},
],
},
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": ["test_agent_3"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {
"name": "some_function",
"input": '{"a": "b"}',
"output": "result",
},
},
{"type": "generation"},
{
"type": "handoff",
"data": {"from_agent": "test_agent_1", "to_agent": "test_agent_3"},
},
],
},
{
"type": "agent",
"data": {
"name": "test_agent_3",
"handoffs": ["test_agent_1", "test_agent_2"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [{"type": "generation"}],
},
],
}
]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 12, ( assert len(spans) == 12, (
f"should have 3 agents, 2 function, 5 generation, 2 handoff, got {len(spans)} with data: " f"should have 3 agents, 2 function, 5 generation, 2 handoff, got {len(spans)} with data: "
@ -285,6 +506,38 @@ async def test_max_turns_exceeded():
traces = fetch_traces() traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"error": {"message": "Max turns exceeded", "data": {"max_turns": 2}},
"data": {
"name": "test",
"handoffs": [],
"tools": ["foo"],
"output_type": "Foo",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {"name": "foo", "input": "", "output": "result"},
},
{"type": "generation"},
{
"type": "function",
"data": {"name": "foo", "input": "", "output": "result"},
},
],
}
],
}
]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 5, ( assert len(spans) == 5, (
f"should have 1 agent span, 2 generations, 2 function calls, got " f"should have 1 agent span, 2 generations, 2 function calls, got "
@ -318,6 +571,30 @@ async def test_guardrail_error():
traces = fetch_traces() traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"error": {
"message": "Guardrail tripwire triggered",
"data": {"guardrail": "guardrail_function"},
},
"data": {"name": "test", "handoffs": [], "tools": [], "output_type": "str"},
"children": [
{
"type": "guardrail",
"data": {"name": "guardrail_function", "triggered": True},
}
],
}
],
}
]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 2, ( assert len(spans) == 2, (
f"should have 1 agent, 1 guardrail, got {len(spans)} with data: " f"should have 1 agent, 1 guardrail, got {len(spans)} with data: "

View file

@ -5,6 +5,7 @@ import json
from typing import Any from typing import Any
import pytest import pytest
from inline_snapshot import snapshot
from typing_extensions import TypedDict from typing_extensions import TypedDict
from agents import ( from agents import (
@ -32,7 +33,7 @@ from .test_responses import (
get_handoff_tool_call, get_handoff_tool_call,
get_text_message, get_text_message,
) )
from .testing_processor import fetch_ordered_spans, fetch_traces from .testing_processor import fetch_normalized_spans, fetch_ordered_spans, fetch_traces
@pytest.mark.asyncio @pytest.mark.asyncio
@ -52,6 +53,35 @@ async def test_single_turn_model_error():
traces = fetch_traces() traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"error": {"message": "Error in agent run", "data": {"error": "test error"}},
"data": {
"name": "test_agent",
"handoffs": [],
"tools": [],
"output_type": "str",
},
"children": [
{
"type": "generation",
"error": {
"message": "Error",
"data": {"name": "ValueError", "message": "test error"},
},
}
],
}
],
}
]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 2, f"should have agent and generation spans, got {len(spans)}" assert len(spans) == 2, f"should have agent and generation spans, got {len(spans)}"
@ -89,6 +119,44 @@ async def test_multi_turn_no_handoffs():
traces = fetch_traces() traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"error": {"message": "Error in agent run", "data": {"error": "test error"}},
"data": {
"name": "test_agent",
"handoffs": [],
"tools": ["foo"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {
"name": "foo",
"input": '{"a": "b"}',
"output": "tool_result",
},
},
{
"type": "generation",
"error": {
"message": "Error",
"data": {"name": "ValueError", "message": "test error"},
},
},
],
}
],
}
]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 4, ( assert len(spans) == 4, (
f"should have agent, generation, tool, generation, got {len(spans)} with data: " f"should have agent, generation, tool, generation, got {len(spans)} with data: "
@ -121,6 +189,43 @@ async def test_tool_call_error():
traces = fetch_traces() traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"error": {
"message": "Error in agent run",
"data": {"error": "Invalid JSON input for tool foo: bad_json"},
},
"data": {
"name": "test_agent",
"handoffs": [],
"tools": ["foo"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"error": {
"message": "Error running tool",
"data": {
"tool_name": "foo",
"error": "Invalid JSON input for tool foo: bad_json",
},
},
"data": {"name": "foo", "input": "bad_json"},
},
],
}
],
}
]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 3, ( assert len(spans) == 3, (
f"should have agent, generation, tool spans, got {len(spans)} with data: " f"should have agent, generation, tool spans, got {len(spans)} with data: "
@ -173,6 +278,43 @@ async def test_multiple_handoff_doesnt_error():
traces = fetch_traces() traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test",
"handoffs": ["test", "test"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {
"name": "some_function",
"input": '{"a": "b"}',
"output": "result",
},
},
{"type": "generation"},
{"type": "handoff", "data": {"from_agent": "test", "to_agent": "test"}},
],
},
{
"type": "agent",
"data": {"name": "test", "handoffs": [], "tools": [], "output_type": "str"},
"children": [{"type": "generation"}],
},
],
}
]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 7, ( assert len(spans) == 7, (
f"should have 2 agent, 1 function, 3 generation, 1 handoff, got {len(spans)} with data: " f"should have 2 agent, 1 function, 3 generation, 1 handoff, got {len(spans)} with data: "
@ -211,6 +353,21 @@ async def test_multiple_final_output_no_error():
traces = fetch_traces() traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {"name": "test", "handoffs": [], "tools": [], "output_type": "Foo"},
"children": [{"type": "generation"}],
}
],
}
]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 2, ( assert len(spans) == 2, (
f"should have 1 agent, 1 generation, got {len(spans)} with data: " f"should have 1 agent, 1 generation, got {len(spans)} with data: "
@ -271,12 +428,152 @@ async def test_handoffs_lead_to_correct_agent_spans():
traces = fetch_traces() traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_3",
"handoffs": ["test_agent_1", "test_agent_2"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {
"name": "some_function",
"input": '{"a": "b"}',
"output": "result",
},
},
{"type": "generation"},
{
"type": "handoff",
"data": {"from_agent": "test_agent_3", "to_agent": "test_agent_1"},
},
],
},
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": ["test_agent_3"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {
"name": "some_function",
"input": '{"a": "b"}',
"output": "result",
},
},
{"type": "generation"},
{
"type": "handoff",
"data": {"from_agent": "test_agent_1", "to_agent": "test_agent_3"},
},
],
},
{
"type": "agent",
"data": {
"name": "test_agent_3",
"handoffs": ["test_agent_1", "test_agent_2"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [{"type": "generation"}],
},
],
}
]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 12, ( assert len(spans) == 12, (
f"should have 3 agents, 2 function, 5 generation, 2 handoff, got {len(spans)} with data: " f"should have 3 agents, 2 function, 5 generation, 2 handoff, got {len(spans)} with data: "
f"{[x.span_data for x in spans]}" f"{[x.span_data for x in spans]}"
) )
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_3",
"handoffs": ["test_agent_1", "test_agent_2"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {
"name": "some_function",
"input": '{"a": "b"}',
"output": "result",
},
},
{"type": "generation"},
{
"type": "handoff",
"data": {"from_agent": "test_agent_3", "to_agent": "test_agent_1"},
},
],
},
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": ["test_agent_3"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {
"name": "some_function",
"input": '{"a": "b"}',
"output": "result",
},
},
{"type": "generation"},
{
"type": "handoff",
"data": {"from_agent": "test_agent_1", "to_agent": "test_agent_3"},
},
],
},
{
"type": "agent",
"data": {
"name": "test_agent_3",
"handoffs": ["test_agent_1", "test_agent_2"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [{"type": "generation"}],
},
],
}
]
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_max_turns_exceeded(): async def test_max_turns_exceeded():
@ -307,6 +604,38 @@ async def test_max_turns_exceeded():
traces = fetch_traces() traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"error": {"message": "Max turns exceeded", "data": {"max_turns": 2}},
"data": {
"name": "test",
"handoffs": [],
"tools": ["foo"],
"output_type": "Foo",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {"name": "foo", "input": "", "output": "result"},
},
{"type": "generation"},
{
"type": "function",
"data": {"name": "foo", "input": "", "output": "result"},
},
],
}
],
}
]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 5, ( assert len(spans) == 5, (
f"should have 1 agent, 2 generations, 2 function calls, got " f"should have 1 agent, 2 generations, 2 function calls, got "
@ -347,6 +676,33 @@ async def test_input_guardrail_error():
traces = fetch_traces() traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"error": {
"message": "Guardrail tripwire triggered",
"data": {
"guardrail": "input_guardrail_function",
"type": "input_guardrail",
},
},
"data": {"name": "test", "handoffs": [], "tools": [], "output_type": "str"},
"children": [
{
"type": "guardrail",
"data": {"name": "input_guardrail_function", "triggered": True},
}
],
}
],
}
]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 2, ( assert len(spans) == 2, (
f"should have 1 agent, 1 guardrail, got {len(spans)} with data: " f"should have 1 agent, 1 guardrail, got {len(spans)} with data: "
@ -387,6 +743,30 @@ async def test_output_guardrail_error():
traces = fetch_traces() traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"error": {
"message": "Guardrail tripwire triggered",
"data": {"guardrail": "output_guardrail_function"},
},
"data": {"name": "test", "handoffs": [], "tools": [], "output_type": "str"},
"children": [
{
"type": "guardrail",
"data": {"name": "output_guardrail_function", "triggered": True},
}
],
}
],
}
]
)
spans = fetch_ordered_spans() spans = fetch_ordered_spans()
assert len(spans) == 2, ( assert len(spans) == 2, (
f"should have 1 agent, 1 guardrail, got {len(spans)} with data: " f"should have 1 agent, 1 guardrail, got {len(spans)} with data: "

View file

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import threading import threading
from datetime import datetime
from typing import Any, Literal from typing import Any, Literal
from agents.tracing import Span, Trace, TracingProcessor from agents.tracing import Span, Trace, TracingProcessor
@ -77,3 +78,35 @@ def fetch_traces() -> list[Trace]:
def fetch_events() -> list[TestSpanProcessorEvent]: def fetch_events() -> list[TestSpanProcessorEvent]:
return SPAN_PROCESSOR_TESTING._events return SPAN_PROCESSOR_TESTING._events
def fetch_normalized_spans():
nodes: dict[tuple[str, str | None], dict[str, Any]] = {}
traces = []
for trace_obj in fetch_traces():
trace = trace_obj.export()
assert trace.pop("object") == "trace"
assert trace.pop("id").startswith("trace_")
trace = {k: v for k, v in trace.items() if v is not None}
nodes[(trace_obj.trace_id, None)] = trace
traces.append(trace)
if not traces:
assert not fetch_ordered_spans()
for span_obj in fetch_ordered_spans():
span = span_obj.export()
assert span.pop("object") == "trace.span"
assert span.pop("id").startswith("span_")
assert datetime.fromisoformat(span.pop("started_at"))
assert datetime.fromisoformat(span.pop("ended_at"))
parent_id = span.pop("parent_id")
assert "type" not in span
span_data = span.pop("span_data")
span = {"type": span_data.pop("type")} | {k: v for k, v in span.items() if v is not None}
span_data = {k: v for k, v in span_data.items() if v is not None}
if span_data:
span["data"] = span_data
nodes[(span_obj.trace_id, span_obj.span_id)] = span
nodes[(span.pop("trace_id"), parent_id)].setdefault("children", []).append(span)
return traces

35
uv.lock
View file

@ -26,6 +26,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/46/eb/e7f063ad1fec6b3178a3cd82d1a3c4de82cccf283fc42746168188e1cdd5/anyio-4.8.0-py3-none-any.whl", hash = "sha256:b5011f270ab5eb0abf13385f851315585cc37ef330dd88e27ec3d34d651fd47a", size = 96041 }, { url = "https://files.pythonhosted.org/packages/46/eb/e7f063ad1fec6b3178a3cd82d1a3c4de82cccf283fc42746168188e1cdd5/anyio-4.8.0-py3-none-any.whl", hash = "sha256:b5011f270ab5eb0abf13385f851315585cc37ef330dd88e27ec3d34d651fd47a", size = 96041 },
] ]
[[package]]
name = "asttokens"
version = "3.0.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/4a/e7/82da0a03e7ba5141f05cce0d302e6eed121ae055e0456ca228bf693984bc/asttokens-3.0.0.tar.gz", hash = "sha256:0dcd8baa8d62b0c1d118b399b2ddba3c4aff271d0d7a9e0d4c1681c79035bbc7", size = 61978 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/25/8a/c46dcc25341b5bce5472c718902eb3d38600a903b14fa6aeecef3f21a46f/asttokens-3.0.0-py3-none-any.whl", hash = "sha256:e3078351a059199dd5138cb1c706e6430c05eff2ff136af5eb4790f9d28932e2", size = 26918 },
]
[[package]] [[package]]
name = "babel" name = "babel"
version = "2.17.0" version = "2.17.0"
@ -240,6 +249,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/02/cc/b7e31358aac6ed1ef2bb790a9746ac2c69bcb3c8588b41616914eb106eaf/exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b", size = 16453 }, { url = "https://files.pythonhosted.org/packages/02/cc/b7e31358aac6ed1ef2bb790a9746ac2c69bcb3c8588b41616914eb106eaf/exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b", size = 16453 },
] ]
[[package]]
name = "executing"
version = "2.2.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/91/50/a9d80c47ff289c611ff12e63f7c5d13942c65d68125160cefd768c73e6e4/executing-2.2.0.tar.gz", hash = "sha256:5d108c028108fe2551d1a7b2e8b713341e2cb4fc0aa7dcf966fa4327a5226755", size = 978693 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702 },
]
[[package]] [[package]]
name = "ghp-import" name = "ghp-import"
version = "2.1.0" version = "2.1.0"
@ -392,6 +410,21 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 }, { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 },
] ]
[[package]]
name = "inline-snapshot"
version = "0.20.5"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "asttokens" },
{ name = "executing" },
{ name = "rich" },
{ name = "tomli", marker = "python_full_version < '3.11'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3b/95/9b85a63031c168dd1c479f8cfd5cae42d42d6ac41c18dd760a104bc87ddc/inline_snapshot-0.20.5.tar.gz", hash = "sha256:d8b67c6d533c0a3f566e72608144b54da65dc3da5d0dba4169b2c56b75530fb5", size = 92215 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d4/71/34e775bbf0bcf81d588d80a1df93437f937b0df9a841f246606a03fc5eff/inline_snapshot-0.20.5-py3-none-any.whl", hash = "sha256:3aa56acf5985d89f17ebd4df4aef00faacc49f10cdf4e6b42be701ffc9702b5a", size = 48071 },
]
[[package]] [[package]]
name = "jinja2" name = "jinja2"
version = "3.1.6" version = "3.1.6"
@ -797,6 +830,7 @@ dependencies = [
[package.dev-dependencies] [package.dev-dependencies]
dev = [ dev = [
{ name = "coverage" }, { name = "coverage" },
{ name = "inline-snapshot" },
{ name = "mkdocs" }, { name = "mkdocs" },
{ name = "mkdocs-material" }, { name = "mkdocs-material" },
{ name = "mkdocstrings", extra = ["python"] }, { name = "mkdocstrings", extra = ["python"] },
@ -822,6 +856,7 @@ requires-dist = [
[package.metadata.requires-dev] [package.metadata.requires-dev]
dev = [ dev = [
{ name = "coverage", specifier = ">=7.6.12" }, { name = "coverage", specifier = ">=7.6.12" },
{ name = "inline-snapshot", specifier = ">=0.20.5" },
{ name = "mkdocs", specifier = ">=1.6.0" }, { name = "mkdocs", specifier = ">=1.6.0" },
{ name = "mkdocs-material", specifier = ">=9.6.0" }, { name = "mkdocs-material", specifier = ">=9.6.0" },
{ name = "mkdocstrings", extras = ["python"], specifier = ">=0.28.0" }, { name = "mkdocstrings", extras = ["python"], specifier = ">=0.28.0" },