Replace remaining uses of fetch_ordered_spans and fetch_traces for stronger tests (#288)

Following https://github.com/openai/openai-agents-python/pull/261
This commit is contained in:
Rohan Mehta 2025-03-21 13:25:43 -04:00 committed by GitHub
commit 090e79bdf4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 322 additions and 182 deletions

View file

@ -9,7 +9,7 @@ from agents import Agent, RunConfig, Runner, trace
from .fake_model import FakeModel
from .test_responses import get_text_message
from .testing_processor import fetch_normalized_spans, fetch_traces
from .testing_processor import assert_no_traces, fetch_normalized_spans
@pytest.mark.asyncio
@ -164,7 +164,7 @@ async def test_parent_disabled_trace_disabled_agent_trace():
await Runner.run(agent, input="first_test")
assert fetch_normalized_spans() == snapshot([])
assert_no_traces()
@pytest.mark.asyncio
@ -178,7 +178,7 @@ async def test_manual_disabling_works():
await Runner.run(agent, input="first_test", run_config=RunConfig(tracing_disabled=True))
assert fetch_normalized_spans() == snapshot([])
assert_no_traces()
@pytest.mark.asyncio
@ -193,16 +193,29 @@ async def test_trace_config_works():
await Runner.run(
agent,
input="first_test",
run_config=RunConfig(workflow_name="Foo bar", group_id="123", trace_id="456"),
run_config=RunConfig(workflow_name="Foo bar", group_id="123", trace_id="trace_456"),
)
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
export = traces[0].export()
assert export is not None, "Trace export should not be None"
assert export["workflow_name"] == "Foo bar"
assert export["group_id"] == "123"
assert export["id"] == "456"
assert fetch_normalized_spans(keep_trace_id=True) == snapshot(
[
{
"id": "trace_456",
"workflow_name": "Foo bar",
"group_id": "123",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent",
"handoffs": [],
"tools": [],
"output_type": "str",
},
}
],
}
]
)
@pytest.mark.asyncio
@ -259,8 +272,24 @@ async def test_streaming_single_run_is_single_trace():
async for _ in x.stream_events():
pass
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent",
"handoffs": [],
"tools": [],
"output_type": "str",
},
}
],
}
]
)
@pytest.mark.asyncio
@ -285,8 +314,38 @@ async def test_multiple_streamed_runs_are_multiple_traces():
async for _ in x.stream_events():
pass
traces = fetch_traces()
assert len(traces) == 2, f"Expected 2 traces, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
}
],
},
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
}
],
},
]
)
@pytest.mark.asyncio
@ -317,8 +376,42 @@ async def test_wrapped_streaming_trace_is_single_trace():
async for _ in x.stream_events():
pass
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "test_workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
},
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
},
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
},
],
}
]
)
@pytest.mark.asyncio
@ -347,8 +440,42 @@ async def test_wrapped_mixed_trace_is_single_trace():
async for _ in x.stream_events():
pass
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "test_workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
},
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
},
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
},
],
}
]
)
@pytest.mark.asyncio
@ -370,8 +497,7 @@ async def test_parent_disabled_trace_disables_streaming_agent_trace():
async for _ in x.stream_events():
pass
traces = fetch_traces()
assert len(traces) == 0, f"Expected 0 traces, got {len(traces)}"
assert_no_traces()
@pytest.mark.asyncio
@ -392,5 +518,4 @@ async def test_manual_streaming_disabling_works():
async for _ in x.stream_events():
pass
traces = fetch_traces()
assert len(traces) == 0, f"Expected 0 traces, got {len(traces)}"
assert_no_traces()

View file

@ -7,7 +7,7 @@ from agents import ModelSettings, ModelTracing, OpenAIResponsesModel, trace
from agents.tracing.span_data import ResponseSpanData
from tests import fake_model
from .testing_processor import fetch_normalized_spans, fetch_ordered_spans
from .testing_processor import assert_no_spans, fetch_normalized_spans, fetch_ordered_spans
class DummyTracing:
@ -89,9 +89,8 @@ async def test_non_data_tracing_doesnt_set_response_id(monkeypatch):
[{"workflow_name": "test", "children": [{"type": "response"}]}]
)
spans = fetch_ordered_spans()
assert len(spans) == 1
assert spans[0].span_data.response is None
[span] = fetch_ordered_spans()
assert span.span_data.response is None
@pytest.mark.allow_call_model_methods
@ -116,8 +115,7 @@ async def test_disable_tracing_does_not_create_span(monkeypatch):
assert fetch_normalized_spans() == snapshot([{"workflow_name": "test"}])
spans = fetch_ordered_spans()
assert len(spans) == 0
assert_no_spans()
@pytest.mark.allow_call_model_methods
@ -190,10 +188,9 @@ async def test_stream_non_data_tracing_doesnt_set_response_id(monkeypatch):
[{"workflow_name": "test", "children": [{"type": "response"}]}]
)
spans = fetch_ordered_spans()
assert len(spans) == 1
assert isinstance(spans[0].span_data, ResponseSpanData)
assert spans[0].span_data.response is None
[span] = fetch_ordered_spans()
assert isinstance(span.span_data, ResponseSpanData)
assert span.span_data.response is None
@pytest.mark.allow_call_model_methods
@ -226,5 +223,4 @@ async def test_stream_disabled_tracing_doesnt_create_span(monkeypatch):
assert fetch_normalized_spans() == snapshot([{"workflow_name": "test"}])
spans = fetch_ordered_spans()
assert len(spans) == 0
assert_no_spans()

View file

@ -4,6 +4,7 @@ import asyncio
from typing import Any
import pytest
from inline_snapshot import snapshot
from agents.tracing import (
Span,
@ -17,7 +18,12 @@ from agents.tracing import (
)
from agents.tracing.spans import SpanError
from .testing_processor import fetch_events, fetch_ordered_spans, fetch_traces
from .testing_processor import (
SPAN_PROCESSOR_TESTING,
assert_no_traces,
fetch_events,
fetch_normalized_spans,
)
### HELPERS
@ -47,7 +53,7 @@ def simple_tracing():
x = trace("test")
x.start()
span_1 = agent_span(name="agent_1", parent=x)
span_1 = agent_span(name="agent_1", span_id="span_1", parent=x)
span_1.start()
span_1.finish()
@ -66,33 +72,36 @@ def simple_tracing():
def test_simple_tracing() -> None:
simple_tracing()
spans, traces = fetch_ordered_spans(), fetch_traces()
assert len(spans) == 3
assert len(traces) == 1
trace = traces[0]
standard_trace_checks(trace, name_check="test")
trace_id = trace.trace_id
first_span = spans[0]
standard_span_checks(first_span, trace_id=trace_id, parent_id=None, span_type="agent")
assert first_span.span_data.name == "agent_1"
second_span = spans[1]
standard_span_checks(second_span, trace_id=trace_id, parent_id=None, span_type="custom")
assert second_span.span_id == "span_2"
assert second_span.span_data.name == "custom_1"
third_span = spans[2]
standard_span_checks(
third_span, trace_id=trace_id, parent_id=second_span.span_id, span_type="custom"
assert fetch_normalized_spans(keep_span_id=True) == snapshot(
[
{
"workflow_name": "test",
"children": [
{
"type": "agent",
"id": "span_1",
"data": {"name": "agent_1"},
},
{
"type": "custom",
"id": "span_2",
"data": {"name": "custom_1", "data": {}},
"children": [
{
"type": "custom",
"id": "span_3",
"data": {"name": "custom_2", "data": {}},
}
],
},
],
}
]
)
assert third_span.span_id == "span_3"
assert third_span.span_data.name == "custom_2"
def ctxmanager_spans():
with trace(workflow_name="test", trace_id="123", group_id="456"):
with trace(workflow_name="test", trace_id="trace_123", group_id="456"):
with custom_span(name="custom_1", span_id="span_1"):
with custom_span(name="custom_2", span_id="span_1_inner"):
pass
@ -104,36 +113,38 @@ def ctxmanager_spans():
def test_ctxmanager_spans() -> None:
ctxmanager_spans()
spans, traces = fetch_ordered_spans(), fetch_traces()
assert len(spans) == 3
assert len(traces) == 1
trace = traces[0]
standard_trace_checks(trace, name_check="test")
trace_id = trace.trace_id
first_span = spans[0]
standard_span_checks(first_span, trace_id=trace_id, parent_id=None, span_type="custom")
assert first_span.span_id == "span_1"
first_inner_span = spans[1]
standard_span_checks(
first_inner_span, trace_id=trace_id, parent_id=first_span.span_id, span_type="custom"
assert fetch_normalized_spans(keep_span_id=True) == snapshot(
[
{
"workflow_name": "test",
"group_id": "456",
"children": [
{
"type": "custom",
"id": "span_1",
"data": {"name": "custom_1", "data": {}},
"children": [
{
"type": "custom",
"id": "span_1_inner",
"data": {"name": "custom_2", "data": {}},
}
],
},
{"type": "custom", "id": "span_2", "data": {"name": "custom_2", "data": {}}},
],
}
]
)
assert first_inner_span.span_id == "span_1_inner"
second_span = spans[2]
standard_span_checks(second_span, trace_id=trace_id, parent_id=None, span_type="custom")
assert second_span.span_id == "span_2"
async def run_subtask(span_id: str | None = None) -> None:
with generation_span(span_id=span_id):
await asyncio.sleep(0.01)
await asyncio.sleep(0.0001)
async def simple_async_tracing():
with trace(workflow_name="test", trace_id="123", group_id="456"):
with trace(workflow_name="test", trace_id="trace_123", group_id="group_456"):
await run_subtask(span_id="span_1")
await run_subtask(span_id="span_2")
@ -142,21 +153,18 @@ async def simple_async_tracing():
async def test_async_tracing() -> None:
await simple_async_tracing()
spans, traces = fetch_ordered_spans(), fetch_traces()
assert len(spans) == 2
assert len(traces) == 1
trace = traces[0]
standard_trace_checks(trace, name_check="test")
trace_id = trace.trace_id
# We don't care about ordering here, just that they're there
for s in spans:
standard_span_checks(s, trace_id=trace_id, parent_id=None, span_type="generation")
ids = [span.span_id for span in spans]
assert "span_1" in ids
assert "span_2" in ids
assert fetch_normalized_spans(keep_span_id=True) == snapshot(
[
{
"workflow_name": "test",
"group_id": "group_456",
"children": [
{"type": "generation", "id": "span_1"},
{"type": "generation", "id": "span_2"},
],
}
]
)
async def run_tasks_parallel(span_ids: list[str]) -> None:
@ -171,13 +179,11 @@ async def run_tasks_as_children(first_span_id: str, second_span_id: str) -> None
async def complex_async_tracing():
with trace(workflow_name="test", trace_id="123", group_id="456"):
await asyncio.sleep(0.01)
with trace(workflow_name="test", trace_id="trace_123", group_id="456"):
await asyncio.gather(
run_tasks_parallel(["span_1", "span_2"]),
run_tasks_parallel(["span_3", "span_4"]),
)
await asyncio.sleep(0.01)
await asyncio.gather(
run_tasks_as_children("span_5", "span_6"),
run_tasks_as_children("span_7", "span_8"),
@ -186,39 +192,38 @@ async def complex_async_tracing():
@pytest.mark.asyncio
async def test_complex_async_tracing() -> None:
await complex_async_tracing()
for _ in range(300):
SPAN_PROCESSOR_TESTING.clear()
await complex_async_tracing()
spans, traces = fetch_ordered_spans(), fetch_traces()
assert len(spans) == 8
assert len(traces) == 1
trace = traces[0]
standard_trace_checks(trace, name_check="test")
trace_id = trace.trace_id
# First ensure 1,2,3,4 exist and are in parallel with the trace as parent
for span_id in ["span_1", "span_2", "span_3", "span_4"]:
span = next((s for s in spans if s.span_id == span_id), None)
assert span is not None
standard_span_checks(span, trace_id=trace_id, parent_id=None, span_type="generation")
# Ensure 5 and 7 exist and have the trace as parent
for span_id in ["span_5", "span_7"]:
span = next((s for s in spans if s.span_id == span_id), None)
assert span is not None
standard_span_checks(span, trace_id=trace_id, parent_id=None, span_type="generation")
# Ensure 6 and 8 exist and have 5 and 7 as parents
six = next((s for s in spans if s.span_id == "span_6"), None)
assert six is not None
standard_span_checks(six, trace_id=trace_id, parent_id="span_5", span_type="generation")
eight = next((s for s in spans if s.span_id == "span_8"), None)
assert eight is not None
standard_span_checks(eight, trace_id=trace_id, parent_id="span_7", span_type="generation")
assert fetch_normalized_spans(keep_span_id=True) == (
[
{
"workflow_name": "test",
"group_id": "456",
"children": [
{"type": "generation", "id": "span_1"},
{"type": "generation", "id": "span_2"},
{"type": "generation", "id": "span_3"},
{"type": "generation", "id": "span_4"},
{
"type": "generation",
"id": "span_5",
"children": [{"type": "generation", "id": "span_6"}],
},
{
"type": "generation",
"id": "span_7",
"children": [{"type": "generation", "id": "span_8"}],
},
],
}
]
)
def spans_with_setters():
with trace(workflow_name="test", trace_id="123", group_id="456"):
with trace(workflow_name="test", trace_id="trace_123", group_id="456"):
with agent_span(name="agent_1") as span_a:
span_a.span_data.name = "agent_2"
@ -236,34 +241,33 @@ def spans_with_setters():
def test_spans_with_setters() -> None:
spans_with_setters()
spans, traces = fetch_ordered_spans(), fetch_traces()
assert len(spans) == 4
assert len(traces) == 1
trace = traces[0]
standard_trace_checks(trace, name_check="test")
trace_id = trace.trace_id
# Check the spans
first_span = spans[0]
standard_span_checks(first_span, trace_id=trace_id, parent_id=None, span_type="agent")
assert first_span.span_data.name == "agent_2"
second_span = spans[1]
standard_span_checks(
second_span, trace_id=trace_id, parent_id=first_span.span_id, span_type="function"
)
assert second_span.span_data.input == "i"
assert second_span.span_data.output == "o"
third_span = spans[2]
standard_span_checks(
third_span, trace_id=trace_id, parent_id=first_span.span_id, span_type="generation"
)
fourth_span = spans[3]
standard_span_checks(
fourth_span, trace_id=trace_id, parent_id=first_span.span_id, span_type="handoff"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "test",
"group_id": "456",
"children": [
{
"type": "agent",
"data": {"name": "agent_2"},
"children": [
{
"type": "function",
"data": {"name": "function_1", "input": "i", "output": "o"},
},
{
"type": "generation",
"data": {"input": [{"foo": "bar"}]},
},
{
"type": "handoff",
"data": {"from_agent": "agent_1", "to_agent": "agent_2"},
},
],
}
],
}
]
)
@ -276,14 +280,11 @@ def disabled_tracing():
def test_disabled_tracing():
disabled_tracing()
spans, traces = fetch_ordered_spans(), fetch_traces()
assert len(spans) == 0
assert len(traces) == 0
assert_no_traces()
def enabled_trace_disabled_span():
with trace(workflow_name="test", trace_id="123"):
with trace(workflow_name="test", trace_id="trace_123"):
with agent_span(name="agent_1"):
with function_span(name="function_1", disabled=True):
with generation_span():
@ -293,17 +294,19 @@ def enabled_trace_disabled_span():
def test_enabled_trace_disabled_span():
enabled_trace_disabled_span()
spans, traces = fetch_ordered_spans(), fetch_traces()
assert len(spans) == 1 # Only the agent span is recorded
assert len(traces) == 1 # The trace is recorded
trace = traces[0]
standard_trace_checks(trace, name_check="test")
trace_id = trace.trace_id
first_span = spans[0]
standard_span_checks(first_span, trace_id=trace_id, parent_id=None, span_type="agent")
assert first_span.span_data.name == "agent_1"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "test",
"children": [
{
"type": "agent",
"data": {"name": "agent_1"},
}
],
}
]
)
def test_start_and_end_called_manual():
@ -367,9 +370,7 @@ async def test_noop_span_doesnt_record():
with custom_span(name="span_1") as span:
span.set_error(SpanError(message="test", data={}))
spans, traces = fetch_ordered_spans(), fetch_traces()
assert len(spans) == 0
assert len(traces) == 0
assert_no_traces()
assert t.export() is None
assert span.export() is None

View file

@ -80,26 +80,44 @@ def fetch_events() -> list[TestSpanProcessorEvent]:
return SPAN_PROCESSOR_TESTING._events
def fetch_normalized_spans():
def assert_no_spans():
    """Fail unless the testing span processor has recorded zero spans.

    Raises:
        AssertionError: if ``fetch_ordered_spans()`` returns a non-empty
            result; the message reports how many spans were found.
    """
    recorded = fetch_ordered_spans()
    if not recorded:
        return
    raise AssertionError(f"Expected 0 spans, got {len(recorded)}")
def assert_no_traces():
    """Fail unless neither traces nor spans have been recorded.

    The trace check runs first; only when no traces exist is the span check
    delegated to ``assert_no_spans()``.

    Raises:
        AssertionError: if ``fetch_traces()`` returns a non-empty result, or
            if ``assert_no_spans()`` finds recorded spans.
    """
    recorded = fetch_traces()
    if not recorded:
        # No traces were captured; there must be no stray spans either.
        assert_no_spans()
        return
    raise AssertionError(f"Expected 0 traces, got {len(recorded)}")
def fetch_normalized_spans(
keep_span_id: bool = False, keep_trace_id: bool = False
) -> list[dict[str, Any]]:
nodes: dict[tuple[str, str | None], dict[str, Any]] = {}
traces = []
for trace_obj in fetch_traces():
trace = trace_obj.export()
assert trace
assert trace.pop("object") == "trace"
assert trace.pop("id").startswith("trace_")
assert trace["id"].startswith("trace_")
if not keep_trace_id:
del trace["id"]
trace = {k: v for k, v in trace.items() if v is not None}
nodes[(trace_obj.trace_id, None)] = trace
traces.append(trace)
if not traces:
assert not fetch_ordered_spans()
assert traces, "Use assert_no_traces() to check for empty traces"
for span_obj in fetch_ordered_spans():
span = span_obj.export()
assert span
assert span.pop("object") == "trace.span"
assert span.pop("id").startswith("span_")
assert span["id"].startswith("span_")
if not keep_span_id:
del span["id"]
assert datetime.fromisoformat(span.pop("started_at"))
assert datetime.fromisoformat(span.pop("ended_at"))
parent_id = span.pop("parent_id")

View file

@ -1050,7 +1050,7 @@ wheels = [
[[package]]
name = "openai-agents"
version = "0.0.5"
version = "0.0.6"
source = { editable = "." }
dependencies = [
{ name = "griffe" },