diff --git a/tests/test_agent_tracing.py b/tests/test_agent_tracing.py index 489e9b6..bb16cab 100644 --- a/tests/test_agent_tracing.py +++ b/tests/test_agent_tracing.py @@ -9,7 +9,7 @@ from agents import Agent, RunConfig, Runner, trace from .fake_model import FakeModel from .test_responses import get_text_message -from .testing_processor import assert_no_traces, fetch_normalized_spans, fetch_traces +from .testing_processor import assert_no_traces, fetch_normalized_spans @pytest.mark.asyncio @@ -193,16 +193,29 @@ async def test_trace_config_works(): await Runner.run( agent, input="first_test", - run_config=RunConfig(workflow_name="Foo bar", group_id="123", trace_id="456"), + run_config=RunConfig(workflow_name="Foo bar", group_id="123", trace_id="trace_456"), ) - traces = fetch_traces() - assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" - export = traces[0].export() - assert export is not None, "Trace export should not be None" - assert export["workflow_name"] == "Foo bar" - assert export["group_id"] == "123" - assert export["id"] == "456" + assert fetch_normalized_spans(keep_trace_id=True) == snapshot( + [ + { + "id": "trace_456", + "workflow_name": "Foo bar", + "group_id": "123", + "children": [ + { + "type": "agent", + "data": { + "name": "test_agent", + "handoffs": [], + "tools": [], + "output_type": "str", + }, + } + ], + } + ] + ) @pytest.mark.asyncio @@ -259,8 +272,24 @@ async def test_streaming_single_run_is_single_trace(): async for _ in x.stream_events(): pass - traces = fetch_traces() - assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "data": { + "name": "test_agent", + "handoffs": [], + "tools": [], + "output_type": "str", + }, + } + ], + } + ] + ) @pytest.mark.asyncio @@ -285,8 +314,38 @@ async def test_multiple_streamed_runs_are_multiple_traces(): async for _ in x.stream_events(): pass - traces = fetch_traces() - assert len(traces) == 2, f"Expected 2 traces, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "data": { + "name": "test_agent_1", + "handoffs": [], + "tools": [], + "output_type": "str", + }, + } + ], + }, + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "data": { + "name": "test_agent_1", + "handoffs": [], + "tools": [], + "output_type": "str", + }, + } + ], + }, + ] + ) @pytest.mark.asyncio @@ -317,8 +376,42 @@ async def test_wrapped_streaming_trace_is_single_trace(): async for _ in x.stream_events(): pass - traces = fetch_traces() - assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "test_workflow", + "children": [ + { + "type": "agent", + "data": { + "name": "test_agent_1", + "handoffs": [], + "tools": [], + "output_type": "str", + }, + }, + { + "type": "agent", + "data": { + "name": "test_agent_1", + "handoffs": [], + "tools": [], + "output_type": "str", + }, + }, + { + "type": "agent", + "data": { + "name": "test_agent_1", + "handoffs": [], + "tools": [], + "output_type": "str", + }, + }, + ], + } + ] + ) @pytest.mark.asyncio @@ -347,8 +440,42 @@ async def test_wrapped_mixed_trace_is_single_trace(): async for _ in x.stream_events(): pass - traces = fetch_traces() - assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "test_workflow", + "children": [ + { + "type": "agent", + "data": { + "name": "test_agent_1", + "handoffs": [], + "tools": [], + "output_type": "str", + }, + }, + { + "type": "agent", + "data": { + "name": "test_agent_1", + "handoffs": [], + "tools": [], + "output_type": "str", + }, + }, + { + "type": "agent", + "data": { + "name": "test_agent_1", + "handoffs": [], + "tools": [], + "output_type": "str", + }, + }, + ], + } + ] + ) @pytest.mark.asyncio diff --git a/tests/test_responses_tracing.py b/tests/test_responses_tracing.py index 95a960f..40bdfaf 100644 --- a/tests/test_responses_tracing.py +++ b/tests/test_responses_tracing.py @@ -7,7 +7,7 @@ from agents import ModelSettings, ModelTracing, OpenAIResponsesModel, trace from agents.tracing.span_data import ResponseSpanData from tests import fake_model -from .testing_processor import fetch_normalized_spans, fetch_ordered_spans, assert_no_spans +from .testing_processor import assert_no_spans, fetch_normalized_spans, fetch_ordered_spans class DummyTracing: @@ -117,6 +117,7 @@ async def test_disable_tracing_does_not_create_span(monkeypatch): assert_no_spans() + @pytest.mark.allow_call_model_methods @pytest.mark.asyncio async def test_stream_response_creates_trace(monkeypatch): @@ -222,4 +223,4 @@ async def test_stream_disabled_tracing_doesnt_create_span(monkeypatch): assert fetch_normalized_spans() == snapshot([{"workflow_name": "test"}]) - assert_no_spans() \ No newline at end of file + assert_no_spans() diff --git a/tests/test_tracing.py b/tests/test_tracing.py index 773d8b9..8f76350 100644 --- a/tests/test_tracing.py +++ b/tests/test_tracing.py @@ -20,11 +20,9 @@ from agents.tracing.spans import SpanError from .testing_processor import ( SPAN_PROCESSOR_TESTING, + assert_no_traces, fetch_events, fetch_normalized_spans, - fetch_ordered_spans, - fetch_traces, - assert_no_traces, ) ### HELPERS diff --git a/tests/testing_processor.py b/tests/testing_processor.py index a8e0d40..a38c395 100644 --- a/tests/testing_processor.py +++ b/tests/testing_processor.py @@ -93,14 +93,18 @@ def assert_no_traces(): assert_no_spans() -def fetch_normalized_spans(keep_span_id: bool = False): +def fetch_normalized_spans( + keep_span_id: bool = False, keep_trace_id: bool = False +) -> list[dict[str, Any]]: nodes: dict[tuple[str, str | None], dict[str, Any]] = {} traces = [] for trace_obj in fetch_traces(): trace = trace_obj.export() assert trace assert trace.pop("object") == "trace" - assert trace.pop("id").startswith("trace_") + assert trace["id"].startswith("trace_") + if not keep_trace_id: + del trace["id"] trace = {k: v for k, v in trace.items() if v is not None} nodes[(trace_obj.trace_id, None)] = trace traces.append(trace)