Merge branch 'main' into patch-1

This commit is contained in:
Vincent Koc 2025-03-20 02:57:34 +11:00 committed by GitHub
commit e7c2c19564
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
47 changed files with 1430 additions and 144 deletions

View file

@ -0,0 +1,26 @@
---
name: Custom model providers
about: Questions or bugs about using non-OpenAI models
title: ''
labels: bug
assignees: ''
---
### Please read this first
- **Have you read the custom model provider docs, including the 'Common issues' section?** [Model provider docs](https://openai.github.io/openai-agents-python/models/#using-other-llm-providers)
- **Have you searched for related issues?** Others may have faced similar issues.
### Describe the question
A clear and concise description of what the question or bug is.
### Debug information
- Agents SDK version: (e.g. `v0.0.3`)
- Python version (e.g. Python 3.10)
### Repro steps
Ideally provide a minimal python script that can be run to reproduce the issue.
### Expected behavior
A clear and concise description of what you expected to happen.

View file

@ -50,8 +50,8 @@ jobs:
enable-cache: true
- name: Install dependencies
run: make sync
- name: Run tests
run: make tests
- name: Run tests with coverage
run: make coverage
build-docs:
runs-on: ubuntu-latest

View file

@ -18,6 +18,21 @@ mypy:
tests:
uv run pytest
.PHONY: coverage
coverage:
uv run coverage run -m pytest
uv run coverage xml -o coverage.xml
uv run coverage report -m --fail-under=95
.PHONY: snapshots-fix
snapshots-fix:
uv run pytest --inline-snapshot=fix
.PHONY: snapshots-create
snapshots-create:
uv run pytest --inline-snapshot=create
.PHONY: old_version_tests
old_version_tests:
UV_PROJECT_ENVIRONMENT=.venv_39 uv run --python 3.9 -m pytest
@ -35,3 +50,5 @@ serve-docs:
deploy-docs:
uv run mkdocs gh-deploy --force --verbose

View file

@ -7,7 +7,7 @@ The OpenAI Agents SDK is a lightweight yet powerful framework for building multi
### Core concepts:
1. [**Agents**](https://openai.github.io/openai-agents-python/agents): LLMs configured with instructions, tools, guardrails, and handoffs
2. [**Handoffs**](https://openai.github.io/openai-agents-python/handoffs/): Allow agents to transfer control to other agents for specific tasks
2. [**Handoffs**](https://openai.github.io/openai-agents-python/handoffs/): A specialized tool call used by the Agents SDK for transferring control between agents
3. [**Guardrails**](https://openai.github.io/openai-agents-python/guardrails/): Configurable safety checks for input and output validation
4. [**Tracing**](https://openai.github.io/openai-agents-python/tracing/): Built-in tracking of agent runs, allowing you to view, debug and optimize your workflows
@ -142,15 +142,7 @@ The Agents SDK is designed to be highly flexible, allowing you to model a wide r
## Tracing
The Agents SDK automatically traces your agent runs, making it easy to track and debug the behavior of your agents. Tracing is extensible by design, supporting custom spans and a wide variety of external destinations, including:
- [AgentOps](https://docs.agentops.ai/v1/integrations/agentssdk)
- [Braintrust](https://braintrust.dev/docs/guides/traces/integrations#openai-agents-sdk)
- [Comet Opik](https://www.comet.com/docs/opik/tracing/integrations/openai_agents)
- [Keywords AI](https://docs.keywordsai.co/integration/development-frameworks/openai-agent)
- [Logfire](https://logfire.pydantic.dev/docs/integrations/llms/openai/#openai-agents)
- [Scorecard](https://docs.scorecard.io/docs/documentation/features/tracing#openai-agents-sdk-integration)
For more details about how to customize or disable tracing, see [Tracing](http://openai.github.io/openai-agents-python/tracing).
The Agents SDK automatically traces your agent runs, making it easy to track and debug the behavior of your agents. Tracing is extensible by design, supporting custom spans and a wide variety of external destinations, including [Logfire](https://logfire.pydantic.dev/docs/integrations/llms/openai/#openai-agents), [AgentOps](https://docs.agentops.ai/v1/integrations/agentssdk), [Braintrust](https://braintrust.dev/docs/guides/traces/integrations#openai-agents-sdk), [Scorecard](https://docs.scorecard.io/docs/documentation/features/tracing#openai-agents-sdk-integration), and [Keywords AI](https://docs.keywordsai.co/integration/development-frameworks/openai-agent). For more details about how to customize or disable tracing, see [Tracing](http://openai.github.io/openai-agents-python/tracing), which also includes a larger list of [external tracing processors](http://openai.github.io/openai-agents-python/tracing/#external-tracing-processors-list).
## Development (only needed if you need to edit the SDK/examples)

View file

@ -111,8 +111,8 @@ class MessageOutput(BaseModel): # (1)!
response: str
class MathOutput(BaseModel): # (2)!
is_math: bool
reasoning: str
is_math: bool
guardrail_agent = Agent(
name="Guardrail check",

View file

@ -9,6 +9,8 @@ The Agents SDK includes built-in tracing, collecting a comprehensive record of e
1. You can globally disable tracing by setting the env var `OPENAI_AGENTS_DISABLE_TRACING=1`
2. You can disable tracing for a single run by setting [`agents.run.RunConfig.tracing_disabled`][] to `True`
***For organizations operating under a Zero Data Retention (ZDR) policy using OpenAI's APIs, tracing is unavailable.***
## Traces and spans
- **Traces** represent a single end-to-end operation of a "workflow". They're composed of Spans. Traces have the following properties:
@ -88,11 +90,15 @@ To customize this default setup, to send traces to alternative or additional bac
1. [`add_trace_processor()`][agents.tracing.add_trace_processor] lets you add an **additional** trace processor that will receive traces and spans as they are ready. This lets you do your own processing in addition to sending traces to OpenAI's backend.
2. [`set_trace_processors()`][agents.tracing.set_trace_processors] lets you **replace** the default processors with your own trace processors. This means traces will not be sent to the OpenAI backend unless you include a `TracingProcessor` that does so.
External trace processors include:
## External tracing processors list
- [AgentOps](https://docs.agentops.ai/v1/integrations/agentssdk)
- [Arize-Phoenix](https://docs.arize.com/phoenix/tracing/integrations-tracing/openai-agents-sdk)
- [MLflow](https://mlflow.org/docs/latest/tracing/integrations/openai-agent)
- [Braintrust](https://braintrust.dev/docs/guides/traces/integrations#openai-agents-sdk)
- [Comet Opik](https://www.comet.com/docs/opik/tracing/integrations/openai_agents)
- [Scorecard](https://docs.scorecard.io/docs/documentation/features/tracing#openai-agents-sdk-integration))
- [Keywords AI](https://docs.keywordsai.co/integration/development-frameworks/openai-agent)
- [Pydantic Logfire](https://logfire.pydantic.dev/docs/integrations/llms/openai/#openai-agents)
- [AgentOps](https://docs.agentops.ai/v1/integrations/agentssdk)
- [Scorecard](https://docs.scorecard.io/docs/documentation/features/tracing#openai-agents-sdk-integration)
- [Keywords AI](https://docs.keywordsai.co/integration/development-frameworks/openai-agent)
- [LangSmith](https://docs.smith.langchain.com/observability/how_to_guides/trace_with_openai_agents_sdk)
- [Maxim AI](https://www.getmaxim.ai/docs/observe/integrations/openai-agents-sdk)
- [Comet Opik](https://www.comet.com/docs/opik/tracing/integrations/openai_agents)

View file

@ -30,8 +30,8 @@ If the guardrail trips, we'll respond with a refusal message.
### 1. An agent-based guardrail that is triggered if the user is asking to do math homework
class MathHomeworkOutput(BaseModel):
is_math_homework: bool
reasoning: str
is_math_homework: bool
guardrail_agent = Agent(

View file

@ -23,8 +23,8 @@ story_outline_generator = Agent(
@dataclass
class EvaluationFeedback:
score: Literal["pass", "needs_improvement", "fail"]
feedback: str
score: Literal["pass", "needs_improvement", "fail"]
evaluator = Agent[None](

View file

@ -74,7 +74,7 @@ multiply_agent = Agent(
start_agent = Agent(
name="Start Agent",
instructions="Generate a random number. If it's even, stop. If it's odd, hand off to the multipler agent.",
instructions="Generate a random number. If it's even, stop. If it's odd, hand off to the multiply agent.",
tools=[random_number],
output_type=FinalResult,
handoffs=[multiply_agent],

View file

@ -3,7 +3,7 @@ from agents import Agent, Runner
agent = Agent(name="Assistant", instructions="You are a helpful assistant")
# Intended for Jupyter notebooks where there's an existing event loop
result = await Runner.run(agent, "Write a haiku about recursion in programming.") # type: ignore[top-level-await] # noqa: F704
result = await Runner.run(agent, "Write a haiku about recursion in programming.") # type: ignore[top-level-await] # noqa: F704
print(result.final_output)
# Code within code loops,

View file

@ -60,9 +60,9 @@ async def main():
print("Step 1 done")
# 2. Ask it to square a number
# 2. Ask it to generate a number
result = await Runner.run(
second_agent,
first_agent,
input=result.to_input_list()
+ [{"content": "Can you generate a random number between 0 and 100?", "role": "user"}],
)

View file

@ -60,9 +60,9 @@ async def main():
print("Step 1 done")
# 2. Ask it to square a number
# 2. Ask it to generate a number
result = await Runner.run(
second_agent,
first_agent,
input=result.to_input_list()
+ [{"content": "Can you generate a random number between 0 and 100?", "role": "user"}],
)

View file

@ -47,6 +47,7 @@ dev = [
"mkdocstrings[python]>=0.28.0",
"coverage>=7.6.12",
"playwright==1.50.0",
"inline-snapshot>=0.20.7",
]
[tool.uv.workspace]
members = ["agents"]
@ -117,3 +118,6 @@ filterwarnings = [
markers = [
"allow_call_model_methods: mark test as allowing calls to real model implementations",
]
[tool.inline-snapshot]
format-command="ruff format --stdin-filename {filename}"

View file

@ -73,6 +73,7 @@ from .tracing import (
SpanData,
SpanError,
Trace,
TracingProcessor,
add_trace_processor,
agent_span,
custom_span,
@ -129,10 +130,9 @@ def set_default_openai_api(api: Literal["chat_completions", "responses"]) -> Non
def enable_verbose_stdout_logging():
"""Enables verbose logging to stdout. This is useful for debugging."""
for name in ["openai.agents", "openai.agents.tracing"]:
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))
logger = logging.getLogger("openai.agents")
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))
__all__ = [
@ -209,6 +209,7 @@ __all__ = [
"set_tracing_disabled",
"trace",
"Trace",
"TracingProcessor",
"SpanError",
"Span",
"SpanData",

View file

@ -25,7 +25,6 @@ from openai.types.responses.response_computer_tool_call import (
from openai.types.responses.response_input_param import ComputerCallOutput
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from . import _utils
from .agent import Agent
from .agent_output import AgentOutputSchema
from .computer import AsyncComputer, Computer
@ -59,6 +58,7 @@ from .tracing import (
handoff_span,
trace,
)
from .util import _coro, _error_tracing
if TYPE_CHECKING:
from .run import RunConfig
@ -293,7 +293,7 @@ class RunImpl:
elif isinstance(output, ResponseComputerToolCall):
items.append(ToolCallItem(raw_item=output, agent=agent))
if not computer_tool:
_utils.attach_error_to_current_span(
_error_tracing.attach_error_to_current_span(
SpanError(
message="Computer tool not found",
data={},
@ -324,7 +324,7 @@ class RunImpl:
# Regular function tool call
else:
if output.name not in function_map:
_utils.attach_error_to_current_span(
_error_tracing.attach_error_to_current_span(
SpanError(
message="Tool not found",
data={"tool_name": output.name},
@ -368,7 +368,7 @@ class RunImpl:
(
agent.hooks.on_tool_start(context_wrapper, agent, func_tool)
if agent.hooks
else _utils.noop_coroutine()
else _coro.noop_coroutine()
),
func_tool.on_invoke_tool(context_wrapper, tool_call.arguments),
)
@ -378,11 +378,11 @@ class RunImpl:
(
agent.hooks.on_tool_end(context_wrapper, agent, func_tool, result)
if agent.hooks
else _utils.noop_coroutine()
else _coro.noop_coroutine()
),
)
except Exception as e:
_utils.attach_error_to_current_span(
_error_tracing.attach_error_to_current_span(
SpanError(
message="Error running tool",
data={"tool_name": func_tool.name, "error": str(e)},
@ -502,7 +502,7 @@ class RunImpl:
source=agent,
)
if agent.hooks
else _utils.noop_coroutine()
else _coro.noop_coroutine()
),
)
@ -520,7 +520,7 @@ class RunImpl:
new_items=tuple(new_step_items),
)
if not callable(input_filter):
_utils.attach_error_to_span(
_error_tracing.attach_error_to_span(
span_handoff,
SpanError(
message="Invalid input filter",
@ -530,7 +530,7 @@ class RunImpl:
raise UserError(f"Invalid input filter: {input_filter}")
filtered = input_filter(handoff_input_data)
if not isinstance(filtered, HandoffInputData):
_utils.attach_error_to_span(
_error_tracing.attach_error_to_span(
span_handoff,
SpanError(
message="Invalid input filter result",
@ -591,7 +591,7 @@ class RunImpl:
hooks.on_agent_end(context_wrapper, agent, final_output),
agent.hooks.on_end(context_wrapper, agent, final_output)
if agent.hooks
else _utils.noop_coroutine(),
else _coro.noop_coroutine(),
)
@classmethod
@ -706,7 +706,7 @@ class ComputerAction:
(
agent.hooks.on_tool_start(context_wrapper, agent, action.computer_tool)
if agent.hooks
else _utils.noop_coroutine()
else _coro.noop_coroutine()
),
output_func,
)
@ -716,7 +716,7 @@ class ComputerAction:
(
agent.hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output)
if agent.hooks
else _utils.noop_coroutine()
else _coro.noop_coroutine()
),
)

View file

@ -1,61 +0,0 @@
from __future__ import annotations
import re
from collections.abc import Awaitable
from typing import Any, Literal, Union
from pydantic import TypeAdapter, ValidationError
from typing_extensions import TypeVar
from .exceptions import ModelBehaviorError
from .logger import logger
from .tracing import Span, SpanError, get_current_span
T = TypeVar("T")
MaybeAwaitable = Union[Awaitable[T], T]
def transform_string_function_style(name: str) -> str:
# Replace spaces with underscores
name = name.replace(" ", "_")
# Replace non-alphanumeric characters with underscores
name = re.sub(r"[^a-zA-Z0-9]", "_", name)
return name.lower()
def validate_json(json_str: str, type_adapter: TypeAdapter[T], partial: bool) -> T:
partial_setting: bool | Literal["off", "on", "trailing-strings"] = (
"trailing-strings" if partial else False
)
try:
validated = type_adapter.validate_json(json_str, experimental_allow_partial=partial_setting)
return validated
except ValidationError as e:
attach_error_to_current_span(
SpanError(
message="Invalid JSON provided",
data={},
)
)
raise ModelBehaviorError(
f"Invalid JSON when parsing {json_str} for {type_adapter}; {e}"
) from e
def attach_error_to_span(span: Span[Any], error: SpanError) -> None:
span.set_error(error)
def attach_error_to_current_span(error: SpanError) -> None:
span = get_current_span()
if span:
attach_error_to_span(span, error)
else:
logger.warning(f"No span to add error {error} to")
async def noop_coroutine() -> None:
pass

View file

@ -6,8 +6,6 @@ from collections.abc import Awaitable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Generic, cast
from . import _utils
from ._utils import MaybeAwaitable
from .guardrail import InputGuardrail, OutputGuardrail
from .handoffs import Handoff
from .items import ItemHelpers
@ -16,6 +14,8 @@ from .model_settings import ModelSettings
from .models.interface import Model
from .run_context import RunContextWrapper, TContext
from .tool import Tool, function_tool
from .util import _transforms
from .util._types import MaybeAwaitable
if TYPE_CHECKING:
from .lifecycle import AgentHooks
@ -27,8 +27,8 @@ class Agent(Generic[TContext]):
"""An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.
We strongly recommend passing `instructions`, which is the "system prompt" for the agent. In
addition, you can pass `description`, which is a human-readable description of the agent, used
when the agent is used inside tools/handoffs.
addition, you can pass `handoff_description`, which is a human-readable description of the
agent, used when the agent is used inside tools/handoffs.
Agents are generic on the context type. The context is a (mutable) object you create. It is
passed to tool functions, handoffs, guardrails, etc.
@ -126,7 +126,7 @@ class Agent(Generic[TContext]):
"""
@function_tool(
name_override=tool_name or _utils.transform_string_function_style(self.name),
name_override=tool_name or _transforms.transform_string_function_style(self.name),
description_override=tool_description or "",
)
async def run_agent(context: RunContextWrapper, input: str) -> str:

View file

@ -4,10 +4,10 @@ from typing import Any
from pydantic import BaseModel, TypeAdapter
from typing_extensions import TypedDict, get_args, get_origin
from . import _utils
from .exceptions import ModelBehaviorError, UserError
from .strict_schema import ensure_strict_json_schema
from .tracing import SpanError
from .util import _error_tracing, _json
_WRAPPER_DICT_KEY = "response"
@ -87,10 +87,10 @@ class AgentOutputSchema:
"""Validate a JSON string against the output type. Returns the validated object, or raises
a `ModelBehaviorError` if the JSON is invalid.
"""
validated = _utils.validate_json(json_str, self._type_adapter, partial)
validated = _json.validate_json(json_str, self._type_adapter, partial)
if self._is_wrapped:
if not isinstance(validated, dict):
_utils.attach_error_to_current_span(
_error_tracing.attach_error_to_current_span(
SpanError(
message="Invalid JSON",
data={"details": f"Expected a dict, got {type(validated)}"},
@ -101,7 +101,7 @@ class AgentOutputSchema:
)
if _WRAPPER_DICT_KEY not in validated:
_utils.attach_error_to_current_span(
_error_tracing.attach_error_to_current_span(
SpanError(
message="Invalid JSON",
data={"details": f"Could not find key {_WRAPPER_DICT_KEY} in JSON"},

View file

@ -33,6 +33,9 @@ class FuncSchema:
"""The signature of the function."""
takes_context: bool = False
"""Whether the function takes a RunContextWrapper argument (must be the first argument)."""
strict_json_schema: bool = True
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
as it increases the likelihood of correct JSON input."""
def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
"""
@ -337,4 +340,5 @@ def function_schema(
params_json_schema=json_schema,
signature=sig,
takes_context=takes_context,
strict_json_schema=strict_json_schema,
)

View file

@ -7,10 +7,10 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, Union, overload
from typing_extensions import TypeVar
from ._utils import MaybeAwaitable
from .exceptions import UserError
from .items import TResponseInputItem
from .run_context import RunContextWrapper, TContext
from .util._types import MaybeAwaitable
if TYPE_CHECKING:
from .agent import Agent

View file

@ -8,12 +8,12 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, cast, overload
from pydantic import TypeAdapter
from typing_extensions import TypeAlias, TypeVar
from . import _utils
from .exceptions import ModelBehaviorError, UserError
from .items import RunItem, TResponseInputItem
from .run_context import RunContextWrapper, TContext
from .strict_schema import ensure_strict_json_schema
from .tracing.spans import SpanError
from .util import _error_tracing, _json, _transforms
if TYPE_CHECKING:
from .agent import Agent
@ -104,7 +104,7 @@ class Handoff(Generic[TContext]):
@classmethod
def default_tool_name(cls, agent: Agent[Any]) -> str:
return _utils.transform_string_function_style(f"transfer_to_{agent.name}")
return _transforms.transform_string_function_style(f"transfer_to_{agent.name}")
@classmethod
def default_tool_description(cls, agent: Agent[Any]) -> str:
@ -192,7 +192,7 @@ def handoff(
) -> Agent[Any]:
if input_type is not None and type_adapter is not None:
if input_json is None:
_utils.attach_error_to_current_span(
_error_tracing.attach_error_to_current_span(
SpanError(
message="Handoff function expected non-null input, but got None",
data={"details": "input_json is None"},
@ -200,7 +200,7 @@ def handoff(
)
raise ModelBehaviorError("Handoff function expected non-null input, but got None")
validated_input = _utils.validate_json(
validated_input = _json.validate_json(
json_str=input_json,
type_adapter=type_adapter,
partial=False,

View file

@ -17,6 +17,7 @@ from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
from .logger import logger
from .stream_events import StreamEvent
from .tracing import Trace
from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming
if TYPE_CHECKING:
from ._run_impl import QueueCompleteSentinel
@ -89,6 +90,9 @@ class RunResult(RunResultBase):
"""The last agent that was run."""
return self._last_agent
def __str__(self) -> str:
return pretty_print_result(self)
@dataclass
class RunResultStreaming(RunResultBase):
@ -216,3 +220,6 @@ class RunResultStreaming(RunResultBase):
if self._output_guardrails_task and not self._output_guardrails_task.done():
self._output_guardrails_task.cancel()
def __str__(self) -> str:
return pretty_print_run_result_streaming(self)

View file

@ -7,7 +7,6 @@ from typing import Any, cast
from openai.types.responses import ResponseCompletedEvent
from . import Model, _utils
from ._run_impl import (
NextStepFinalOutput,
NextStepHandoff,
@ -33,7 +32,7 @@ from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
from .lifecycle import RunHooks
from .logger import logger
from .model_settings import ModelSettings
from .models.interface import ModelProvider
from .models.interface import Model, ModelProvider
from .models.openai_provider import OpenAIProvider
from .result import RunResult, RunResultStreaming
from .run_context import RunContextWrapper, TContext
@ -41,6 +40,7 @@ from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent
from .tracing import Span, SpanError, agent_span, get_current_trace, trace
from .tracing.span_data import AgentSpanData
from .usage import Usage
from .util import _coro, _error_tracing
DEFAULT_MAX_TURNS = 10
@ -193,7 +193,7 @@ class Runner:
current_turn += 1
if current_turn > max_turns:
_utils.attach_error_to_span(
_error_tracing.attach_error_to_span(
current_span,
SpanError(
message="Max turns exceeded",
@ -447,7 +447,7 @@ class Runner:
for done in asyncio.as_completed(guardrail_tasks):
result = await done
if result.output.tripwire_triggered:
_utils.attach_error_to_span(
_error_tracing.attach_error_to_span(
parent_span,
SpanError(
message="Guardrail tripwire triggered",
@ -511,7 +511,7 @@ class Runner:
streamed_result.current_turn = current_turn
if current_turn > max_turns:
_utils.attach_error_to_span(
_error_tracing.attach_error_to_span(
current_span,
SpanError(
message="Max turns exceeded",
@ -583,7 +583,7 @@ class Runner:
pass
except Exception as e:
if current_span:
_utils.attach_error_to_span(
_error_tracing.attach_error_to_span(
current_span,
SpanError(
message="Error in agent run",
@ -615,7 +615,7 @@ class Runner:
(
agent.hooks.on_start(context_wrapper, agent)
if agent.hooks
else _utils.noop_coroutine()
else _coro.noop_coroutine()
),
)
@ -705,7 +705,7 @@ class Runner:
(
agent.hooks.on_start(context_wrapper, agent)
if agent.hooks
else _utils.noop_coroutine()
else _coro.noop_coroutine()
),
)
@ -796,7 +796,7 @@ class Runner:
# Cancel all guardrail tasks if a tripwire is triggered.
for t in guardrail_tasks:
t.cancel()
_utils.attach_error_to_current_span(
_error_tracing.attach_error_to_current_span(
SpanError(
message="Guardrail tripwire triggered",
data={"guardrail": result.guardrail.get_name()},
@ -834,7 +834,7 @@ class Runner:
# Cancel all guardrail tasks if a tripwire is triggered.
for t in guardrail_tasks:
t.cancel()
_utils.attach_error_to_current_span(
_error_tracing.attach_error_to_current_span(
SpanError(
message="Guardrail tripwire triggered",
data={"guardrail": result.guardrail.get_name()},

View file

@ -11,14 +11,15 @@ from openai.types.responses.web_search_tool_param import UserLocation
from pydantic import ValidationError
from typing_extensions import Concatenate, ParamSpec
from . import _debug, _utils
from ._utils import MaybeAwaitable
from . import _debug
from .computer import AsyncComputer, Computer
from .exceptions import ModelBehaviorError
from .function_schema import DocstringStyle, function_schema
from .logger import logger
from .run_context import RunContextWrapper
from .tracing import SpanError
from .util import _error_tracing
from .util._types import MaybeAwaitable
ToolParams = ParamSpec("ToolParams")
@ -137,6 +138,7 @@ def function_tool(
docstring_style: DocstringStyle | None = None,
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True,
) -> FunctionTool:
"""Overload for usage as @function_tool (no parentheses)."""
...
@ -150,6 +152,7 @@ def function_tool(
docstring_style: DocstringStyle | None = None,
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True,
) -> Callable[[ToolFunction[...]], FunctionTool]:
"""Overload for usage as @function_tool(...)."""
...
@ -163,6 +166,7 @@ def function_tool(
docstring_style: DocstringStyle | None = None,
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
strict_mode: bool = True,
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
"""
Decorator to create a FunctionTool from a function. By default, we will:
@ -186,6 +190,11 @@ def function_tool(
failure_error_function: If provided, use this function to generate an error message when
the tool call fails. The error message is sent to the LLM. If you pass None, then no
error message will be sent and instead an Exception will be raised.
strict_mode: Whether to enable strict mode for the tool's JSON schema. We *strongly*
recommend setting this to True, as it increases the likelihood of correct JSON input.
If False, it allows non-strict JSON schemas. For example, if a parameter has a default
value, it will be optional, additional properties are allowed, etc. See here for more:
https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas
"""
def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
@ -195,6 +204,7 @@ def function_tool(
description_override=description_override,
docstring_style=docstring_style,
use_docstring_info=use_docstring_info,
strict_json_schema=strict_mode,
)
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> str:
@ -257,7 +267,7 @@ def function_tool(
if inspect.isawaitable(result):
return await result
_utils.attach_error_to_current_span(
_error_tracing.attach_error_to_current_span(
SpanError(
message="Error running tool (non-fatal)",
data={
@ -273,6 +283,7 @@ def function_tool(
description=schema.description or "",
params_json_schema=schema.params_json_schema,
on_invoke_tool=_on_invoke_tool,
strict_json_schema=strict_mode,
)
# If func is actually a callable, we were used as @function_tool with no parentheses

View file

@ -3,7 +3,7 @@ from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from .logger import logger
from ..logger import logger
from .setup import GLOBAL_TRACE_PROVIDER
from .span_data import (
AgentSpanData,

View file

@ -9,7 +9,7 @@ from typing import Any
import httpx
from .logger import logger
from ..logger import logger
from .processor_interface import TracingExporter, TracingProcessor
from .spans import Span
from .traces import Trace

View file

@ -2,7 +2,7 @@
import contextvars
from typing import TYPE_CHECKING, Any
from .logger import logger
from ..logger import logger
if TYPE_CHECKING:
from .spans import Span

View file

@ -4,8 +4,8 @@ import os
import threading
from typing import Any
from ..logger import logger
from . import util
from .logger import logger
from .processor_interface import TracingProcessor
from .scope import Scope
from .spans import NoOpSpan, Span, SpanImpl, TSpanData

View file

@ -6,8 +6,8 @@ from typing import Any, Generic, TypeVar
from typing_extensions import TypedDict
from ..logger import logger
from . import util
from .logger import logger
from .processor_interface import TracingProcessor
from .scope import Scope
from .span_data import SpanData

View file

@ -4,8 +4,8 @@ import abc
import contextvars
from typing import Any
from ..logger import logger
from . import util
from .logger import logger
from .processor_interface import TracingProcessor
from .scope import Scope

View file

2
src/agents/util/_coro.py Normal file
View file

@ -0,0 +1,2 @@
async def noop_coroutine() -> None:
pass

View file

@ -0,0 +1,16 @@
from typing import Any
from ..logger import logger
from ..tracing import Span, SpanError, get_current_span
def attach_error_to_span(span: Span[Any], error: SpanError) -> None:
span.set_error(error)
def attach_error_to_current_span(error: SpanError) -> None:
span = get_current_span()
if span:
attach_error_to_span(span, error)
else:
logger.warning(f"No span to add error {error} to")

31
src/agents/util/_json.py Normal file
View file

@ -0,0 +1,31 @@
from __future__ import annotations
from typing import Literal
from pydantic import TypeAdapter, ValidationError
from typing_extensions import TypeVar
from ..exceptions import ModelBehaviorError
from ..tracing import SpanError
from ._error_tracing import attach_error_to_current_span
T = TypeVar("T")
def validate_json(json_str: str, type_adapter: TypeAdapter[T], partial: bool) -> T:
partial_setting: bool | Literal["off", "on", "trailing-strings"] = (
"trailing-strings" if partial else False
)
try:
validated = type_adapter.validate_json(json_str, experimental_allow_partial=partial_setting)
return validated
except ValidationError as e:
attach_error_to_current_span(
SpanError(
message="Invalid JSON provided",
data={},
)
)
raise ModelBehaviorError(
f"Invalid JSON when parsing {json_str} for {type_adapter}; {e}"
) from e

View file

@ -0,0 +1,56 @@
from typing import TYPE_CHECKING
from pydantic import BaseModel
if TYPE_CHECKING:
from ..result import RunResult, RunResultBase, RunResultStreaming
def _indent(text: str, indent_level: int) -> str:
indent_string = " " * indent_level
return "\n".join(f"{indent_string}{line}" for line in text.splitlines())
def _final_output_str(result: "RunResultBase") -> str:
if result.final_output is None:
return "None"
elif isinstance(result.final_output, str):
return result.final_output
elif isinstance(result.final_output, BaseModel):
return result.final_output.model_dump_json(indent=2)
else:
return str(result.final_output)
def pretty_print_result(result: "RunResult") -> str:
output = "RunResult:"
output += f'\n- Last agent: Agent(name="{result.last_agent.name}", ...)'
output += (
f"\n- Final output ({type(result.final_output).__name__}):\n"
f"{_indent(_final_output_str(result), 2)}"
)
output += f"\n- {len(result.new_items)} new item(s)"
output += f"\n- {len(result.raw_responses)} raw response(s)"
output += f"\n- {len(result.input_guardrail_results)} input guardrail result(s)"
output += f"\n- {len(result.output_guardrail_results)} output guardrail result(s)"
output += "\n(See `RunResult` for more details)"
return output
def pretty_print_run_result_streaming(result: "RunResultStreaming") -> str:
output = "RunResultStreaming:"
output += f'\n- Current agent: Agent(name="{result.current_agent.name}", ...)'
output += f"\n- Current turn: {result.current_turn}"
output += f"\n- Max turns: {result.max_turns}"
output += f"\n- Is complete: {result.is_complete}"
output += (
f"\n- Final output ({type(result.final_output).__name__}):\n"
f"{_indent(_final_output_str(result), 2)}"
)
output += f"\n- {len(result.new_items)} new item(s)"
output += f"\n- {len(result.raw_responses)} raw response(s)"
output += f"\n- {len(result.input_guardrail_results)} input guardrail result(s)"
output += f"\n- {len(result.output_guardrail_results)} output guardrail result(s)"
output += "\n(See `RunResultStreaming` for more details)"
return output

View file

@ -0,0 +1,11 @@
import re
def transform_string_function_style(name: str) -> str:
# Replace spaces with underscores
name = name.replace(" ", "_")
# Replace non-alphanumeric characters with underscores
name = re.sub(r"[^a-zA-Z0-9]", "_", name)
return name.lower()

View file

@ -0,0 +1,7 @@
from collections.abc import Awaitable
from typing import Union
from typing_extensions import TypeVar
T = TypeVar("T")
MaybeAwaitable = Union[Awaitable[T], T]

25
tests/README.md Normal file
View file

@ -0,0 +1,25 @@
# Tests
Before running any tests, make sure you have `uv` installed (and ideally run `make sync` after).
## Running tests
```
make tests
```
## Snapshots
We use [inline-snapshots](https://15r10nk.github.io/inline-snapshot/latest/) for some tests. If your code adds new snapshot tests or breaks existing ones, you can fix/create them. After fixing/creating snapshots, run `make tests` again to verify the tests pass.
### Fixing snapshots
```
make snapshots-fix
```
### Creating snapshots
```
make snapshots-update
```

View file

@ -3,12 +3,13 @@ from __future__ import annotations
import asyncio
import pytest
from inline_snapshot import snapshot
from agents import Agent, RunConfig, Runner, trace
from .fake_model import FakeModel
from .test_responses import get_text_message
from .testing_processor import fetch_ordered_spans, fetch_traces
from .testing_processor import fetch_normalized_spans, fetch_ordered_spans, fetch_traces
@pytest.mark.asyncio
@ -25,6 +26,25 @@ async def test_single_run_is_single_trace():
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent",
"handoffs": [],
"tools": [],
"output_type": "str",
},
}
],
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 1, (
f"Got {len(spans)}, but expected 1: the agent span. data:"
@ -52,6 +72,39 @@ async def test_multiple_runs_are_multiple_traces():
traces = fetch_traces()
assert len(traces) == 2, f"Expected 2 traces, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
}
],
},
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
}
],
},
]
)
spans = fetch_ordered_spans()
assert len(spans) == 2, f"Got {len(spans)}, but expected 2: agent span per run"
@ -79,6 +132,43 @@ async def test_wrapped_trace_is_single_trace():
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "test_workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
},
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
},
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": [],
"tools": [],
"output_type": "str",
},
},
],
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 3, f"Got {len(spans)}, but expected 3: the agent span per run"
@ -97,6 +187,8 @@ async def test_parent_disabled_trace_disabled_agent_trace():
traces = fetch_traces()
assert len(traces) == 0, f"Expected 0 traces, got {len(traces)}"
assert fetch_normalized_spans() == snapshot([])
spans = fetch_ordered_spans()
assert len(spans) == 0, (
f"Expected no spans, got {len(spans)}, with {[x.span_data for x in spans]}"
@ -116,6 +208,8 @@ async def test_manual_disabling_works():
traces = fetch_traces()
assert len(traces) == 0, f"Expected 0 traces, got {len(traces)}"
assert fetch_normalized_spans() == snapshot([])
spans = fetch_ordered_spans()
assert len(spans) == 0, f"Got {len(spans)}, but expected no spans"
@ -164,6 +258,25 @@ async def test_not_starting_streaming_creates_trace():
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent",
"handoffs": [],
"tools": [],
"output_type": "str",
},
}
],
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 1, f"Got {len(spans)}, but expected 1: the agent span"

View file

@ -1,6 +1,6 @@
import asyncio
import json
from typing import Any
from typing import Any, Optional
import pytest
@ -142,3 +142,59 @@ async def test_no_error_on_invalid_json_async():
tool = will_not_fail_on_bad_json_async
result = await tool.on_invoke_tool(ctx_wrapper(), "{not valid json}")
assert result == "error_ModelBehaviorError"
@function_tool(strict_mode=False)
def optional_param_function(a: int, b: Optional[int] = None) -> str:
if b is None:
return f"{a}_no_b"
return f"{a}_{b}"
@pytest.mark.asyncio
async def test_non_strict_mode_function():
tool = optional_param_function
assert tool.strict_json_schema is False, "strict_json_schema should be False"
assert tool.params_json_schema.get("required") == ["a"], "required should only be a"
input_data = {"a": 5}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "5_no_b"
input_data = {"a": 5, "b": 10}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "5_10"
@function_tool(strict_mode=False)
def all_optional_params_function(
x: int = 42,
y: str = "hello",
z: Optional[int] = None,
) -> str:
if z is None:
return f"{x}_{y}_no_z"
return f"{x}_{y}_{z}"
@pytest.mark.asyncio
async def test_all_optional_params_function():
tool = all_optional_params_function
assert tool.strict_json_schema is False, "strict_json_schema should be False"
assert tool.params_json_schema.get("required") is None, "required should be empty"
input_data: dict[str, Any] = {}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "42_hello_no_z"
input_data = {"x": 10, "y": "world"}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "10_world_no_z"
input_data = {"x": 10, "y": "world", "z": 99}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "10_world_99"

View file

@ -4,8 +4,9 @@ import pytest
from pydantic import BaseModel
from typing_extensions import TypedDict
from agents import Agent, AgentOutputSchema, ModelBehaviorError, Runner, UserError, _utils
from agents import Agent, AgentOutputSchema, ModelBehaviorError, Runner, UserError
from agents.agent_output import _WRAPPER_DICT_KEY
from agents.util import _json
def test_plain_text_output():
@ -77,7 +78,7 @@ def test_bad_json_raises_error(mocker):
output_schema = Runner._get_output_schema(agent)
assert output_schema, "Should have an output tool config with a structured output type"
mock_validate_json = mocker.patch.object(_utils, "validate_json")
mock_validate_json = mocker.patch.object(_json, "validate_json")
mock_validate_json.return_value = ["foo"]
with pytest.raises(ModelBehaviorError):
@ -111,3 +112,4 @@ def test_setting_strict_false_works():
output_wrapper = AgentOutputSchema(output_type=Foo, strict_json_schema=False)
assert not output_wrapper.strict_json_schema
assert output_wrapper.json_schema() == Foo.model_json_schema()
assert output_wrapper.json_schema() == Foo.model_json_schema()

201
tests/test_pretty_print.py Normal file
View file

@ -0,0 +1,201 @@
import json
import pytest
from inline_snapshot import snapshot
from pydantic import BaseModel
from agents import Agent, Runner
from agents.agent_output import _WRAPPER_DICT_KEY
from agents.util._pretty_print import pretty_print_result, pretty_print_run_result_streaming
from tests.fake_model import FakeModel
from .test_responses import get_final_output_message, get_text_message
@pytest.mark.asyncio
async def test_pretty_result():
model = FakeModel()
model.set_next_output([get_text_message("Hi there")])
agent = Agent(name="test_agent", model=model)
result = await Runner.run(agent, input="Hello")
assert pretty_print_result(result) == snapshot("""\
RunResult:
- Last agent: Agent(name="test_agent", ...)
- Final output (str):
Hi there
- 1 new item(s)
- 1 raw response(s)
- 0 input guardrail result(s)
- 0 output guardrail result(s)
(See `RunResult` for more details)\
""")
@pytest.mark.asyncio
async def test_pretty_run_result_streaming():
model = FakeModel()
model.set_next_output([get_text_message("Hi there")])
agent = Agent(name="test_agent", model=model)
result = Runner.run_streamed(agent, input="Hello")
async for _ in result.stream_events():
pass
assert pretty_print_run_result_streaming(result) == snapshot("""\
RunResultStreaming:
- Current agent: Agent(name="test_agent", ...)
- Current turn: 1
- Max turns: 10
- Is complete: True
- Final output (str):
Hi there
- 1 new item(s)
- 1 raw response(s)
- 0 input guardrail result(s)
- 0 output guardrail result(s)
(See `RunResultStreaming` for more details)\
""")
class Foo(BaseModel):
bar: str
@pytest.mark.asyncio
async def test_pretty_run_result_structured_output():
model = FakeModel()
model.set_next_output(
[
get_text_message("Test"),
get_final_output_message(Foo(bar="Hi there").model_dump_json()),
]
)
agent = Agent(name="test_agent", model=model, output_type=Foo)
result = await Runner.run(agent, input="Hello")
assert pretty_print_result(result) == snapshot("""\
RunResult:
- Last agent: Agent(name="test_agent", ...)
- Final output (Foo):
{
"bar": "Hi there"
}
- 2 new item(s)
- 1 raw response(s)
- 0 input guardrail result(s)
- 0 output guardrail result(s)
(See `RunResult` for more details)\
""")
@pytest.mark.asyncio
async def test_pretty_run_result_streaming_structured_output():
model = FakeModel()
model.set_next_output(
[
get_text_message("Test"),
get_final_output_message(Foo(bar="Hi there").model_dump_json()),
]
)
agent = Agent(name="test_agent", model=model, output_type=Foo)
result = Runner.run_streamed(agent, input="Hello")
async for _ in result.stream_events():
pass
assert pretty_print_run_result_streaming(result) == snapshot("""\
RunResultStreaming:
- Current agent: Agent(name="test_agent", ...)
- Current turn: 1
- Max turns: 10
- Is complete: True
- Final output (Foo):
{
"bar": "Hi there"
}
- 2 new item(s)
- 1 raw response(s)
- 0 input guardrail result(s)
- 0 output guardrail result(s)
(See `RunResultStreaming` for more details)\
""")
@pytest.mark.asyncio
async def test_pretty_run_result_list_structured_output():
model = FakeModel()
model.set_next_output(
[
get_text_message("Test"),
get_final_output_message(
json.dumps(
{
_WRAPPER_DICT_KEY: [
Foo(bar="Hi there").model_dump(),
Foo(bar="Hi there 2").model_dump(),
]
}
)
),
]
)
agent = Agent(name="test_agent", model=model, output_type=list[Foo])
result = await Runner.run(agent, input="Hello")
assert pretty_print_result(result) == snapshot("""\
RunResult:
- Last agent: Agent(name="test_agent", ...)
- Final output (list):
[Foo(bar='Hi there'), Foo(bar='Hi there 2')]
- 2 new item(s)
- 1 raw response(s)
- 0 input guardrail result(s)
- 0 output guardrail result(s)
(See `RunResult` for more details)\
""")
@pytest.mark.asyncio
async def test_pretty_run_result_streaming_list_structured_output():
model = FakeModel()
model.set_next_output(
[
get_text_message("Test"),
get_final_output_message(
json.dumps(
{
_WRAPPER_DICT_KEY: [
Foo(bar="Test").model_dump(),
Foo(bar="Test 2").model_dump(),
]
}
)
),
]
)
agent = Agent(name="test_agent", model=model, output_type=list[Foo])
result = Runner.run_streamed(agent, input="Hello")
async for _ in result.stream_events():
pass
assert pretty_print_run_result_streaming(result) == snapshot("""\
RunResultStreaming:
- Current agent: Agent(name="test_agent", ...)
- Current turn: 1
- Max turns: 10
- Is complete: True
- Final output (list):
[Foo(bar='Test'), Foo(bar='Test 2')]
- 2 new item(s)
- 1 raw response(s)
- 0 input guardrail result(s)
- 0 output guardrail result(s)
(See `RunResultStreaming` for more details)\
""")

View file

@ -1,4 +1,5 @@
import pytest
from inline_snapshot import snapshot
from openai import AsyncOpenAI
from openai.types.responses import ResponseCompletedEvent
@ -6,7 +7,7 @@ from agents import ModelSettings, ModelTracing, OpenAIResponsesModel, trace
from agents.tracing.span_data import ResponseSpanData
from tests import fake_model
from .testing_processor import fetch_ordered_spans
from .testing_processor import fetch_normalized_spans, fetch_ordered_spans
class DummyTracing:
@ -54,6 +55,15 @@ async def test_get_response_creates_trace(monkeypatch):
"instr", "input", ModelSettings(), [], None, [], ModelTracing.ENABLED
)
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "test",
"children": [{"type": "response", "data": {"response_id": "dummy-id"}}],
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 1
@ -82,6 +92,10 @@ async def test_non_data_tracing_doesnt_set_response_id(monkeypatch):
"instr", "input", ModelSettings(), [], None, [], ModelTracing.ENABLED_WITHOUT_DATA
)
assert fetch_normalized_spans() == snapshot(
[{"workflow_name": "test", "children": [{"type": "response"}]}]
)
spans = fetch_ordered_spans()
assert len(spans) == 1
assert spans[0].span_data.response is None
@ -107,6 +121,8 @@ async def test_disable_tracing_does_not_create_span(monkeypatch):
"instr", "input", ModelSettings(), [], None, [], ModelTracing.DISABLED
)
assert fetch_normalized_spans() == snapshot([{"workflow_name": "test"}])
spans = fetch_ordered_spans()
assert len(spans) == 0
@ -139,6 +155,15 @@ async def test_stream_response_creates_trace(monkeypatch):
):
pass
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "test",
"children": [{"type": "response", "data": {"response_id": "dummy-id-123"}}],
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 1
assert isinstance(spans[0].span_data, ResponseSpanData)
@ -174,6 +199,10 @@ async def test_stream_non_data_tracing_doesnt_set_response_id(monkeypatch):
):
pass
assert fetch_normalized_spans() == snapshot(
[{"workflow_name": "test", "children": [{"type": "response"}]}]
)
spans = fetch_ordered_spans()
assert len(spans) == 1
assert isinstance(spans[0].span_data, ResponseSpanData)
@ -208,5 +237,7 @@ async def test_stream_disabled_tracing_doesnt_create_span(monkeypatch):
):
pass
assert fetch_normalized_spans() == snapshot([{"workflow_name": "test"}])
spans = fetch_ordered_spans()
assert len(spans) == 0

View file

@ -4,6 +4,7 @@ import json
from typing import Any
import pytest
from inline_snapshot import snapshot
from typing_extensions import TypedDict
from agents import (
@ -27,7 +28,7 @@ from .test_responses import (
get_handoff_tool_call,
get_text_message,
)
from .testing_processor import fetch_ordered_spans, fetch_traces
from .testing_processor import fetch_normalized_spans, fetch_ordered_spans, fetch_traces
@pytest.mark.asyncio
@ -45,6 +46,34 @@ async def test_single_turn_model_error():
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent",
"handoffs": [],
"tools": [],
"output_type": "str",
},
"children": [
{
"type": "generation",
"error": {
"message": "Error",
"data": {"name": "ValueError", "message": "test error"},
},
}
],
}
],
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 2, f"should have agent and generation spans, got {len(spans)}"
@ -80,6 +109,43 @@ async def test_multi_turn_no_handoffs():
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent",
"handoffs": [],
"tools": ["foo"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {
"name": "foo",
"input": '{"a": "b"}',
"output": "tool_result",
},
},
{
"type": "generation",
"error": {
"message": "Error",
"data": {"name": "ValueError", "message": "test error"},
},
},
],
}
],
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 4, (
f"should have agent, generation, tool, generation, got {len(spans)} with data: "
@ -110,6 +176,39 @@ async def test_tool_call_error():
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent",
"handoffs": [],
"tools": ["foo"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"error": {
"message": "Error running tool",
"data": {
"tool_name": "foo",
"error": "Invalid JSON input for tool foo: bad_json",
},
},
"data": {"name": "foo", "input": "bad_json"},
},
],
}
],
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 3, (
f"should have agent, generation, tool spans, got {len(spans)} with data: "
@ -159,6 +258,43 @@ async def test_multiple_handoff_doesnt_error():
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test",
"handoffs": ["test", "test"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {
"name": "some_function",
"input": '{"a": "b"}',
"output": "result",
},
},
{"type": "generation"},
{"type": "handoff", "data": {"from_agent": "test", "to_agent": "test"}},
],
},
{
"type": "agent",
"data": {"name": "test", "handoffs": [], "tools": [], "output_type": "str"},
"children": [{"type": "generation"}],
},
],
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 7, (
f"should have 2 agent, 1 function, 3 generation, 1 handoff, got {len(spans)} with data: "
@ -193,6 +329,21 @@ async def test_multiple_final_output_doesnt_error():
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {"name": "test", "handoffs": [], "tools": [], "output_type": "Foo"},
"children": [{"type": "generation"}],
}
],
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 2, (
f"should have 1 agent, 1 generation, got {len(spans)} with data: "
@ -251,6 +402,76 @@ async def test_handoffs_lead_to_correct_agent_spans():
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_3",
"handoffs": ["test_agent_1", "test_agent_2"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {
"name": "some_function",
"input": '{"a": "b"}',
"output": "result",
},
},
{"type": "generation"},
{
"type": "handoff",
"data": {"from_agent": "test_agent_3", "to_agent": "test_agent_1"},
},
],
},
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": ["test_agent_3"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {
"name": "some_function",
"input": '{"a": "b"}',
"output": "result",
},
},
{"type": "generation"},
{
"type": "handoff",
"data": {"from_agent": "test_agent_1", "to_agent": "test_agent_3"},
},
],
},
{
"type": "agent",
"data": {
"name": "test_agent_3",
"handoffs": ["test_agent_1", "test_agent_2"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [{"type": "generation"}],
},
],
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 12, (
f"should have 3 agents, 2 function, 5 generation, 2 handoff, got {len(spans)} with data: "
@ -285,6 +506,38 @@ async def test_max_turns_exceeded():
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"error": {"message": "Max turns exceeded", "data": {"max_turns": 2}},
"data": {
"name": "test",
"handoffs": [],
"tools": ["foo"],
"output_type": "Foo",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {"name": "foo", "input": "", "output": "result"},
},
{"type": "generation"},
{
"type": "function",
"data": {"name": "foo", "input": "", "output": "result"},
},
],
}
],
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 5, (
f"should have 1 agent span, 2 generations, 2 function calls, got "
@ -318,6 +571,30 @@ async def test_guardrail_error():
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"error": {
"message": "Guardrail tripwire triggered",
"data": {"guardrail": "guardrail_function"},
},
"data": {"name": "test", "handoffs": [], "tools": [], "output_type": "str"},
"children": [
{
"type": "guardrail",
"data": {"name": "guardrail_function", "triggered": True},
}
],
}
],
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 2, (
f"should have 1 agent, 1 guardrail, got {len(spans)} with data: "

View file

@ -5,6 +5,7 @@ import json
from typing import Any
import pytest
from inline_snapshot import snapshot
from typing_extensions import TypedDict
from agents import (
@ -32,7 +33,7 @@ from .test_responses import (
get_handoff_tool_call,
get_text_message,
)
from .testing_processor import fetch_ordered_spans, fetch_traces
from .testing_processor import fetch_normalized_spans, fetch_ordered_spans, fetch_traces
@pytest.mark.asyncio
@ -52,6 +53,35 @@ async def test_single_turn_model_error():
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"error": {"message": "Error in agent run", "data": {"error": "test error"}},
"data": {
"name": "test_agent",
"handoffs": [],
"tools": [],
"output_type": "str",
},
"children": [
{
"type": "generation",
"error": {
"message": "Error",
"data": {"name": "ValueError", "message": "test error"},
},
}
],
}
],
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 2, f"should have agent and generation spans, got {len(spans)}"
@ -89,6 +119,44 @@ async def test_multi_turn_no_handoffs():
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"error": {"message": "Error in agent run", "data": {"error": "test error"}},
"data": {
"name": "test_agent",
"handoffs": [],
"tools": ["foo"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {
"name": "foo",
"input": '{"a": "b"}',
"output": "tool_result",
},
},
{
"type": "generation",
"error": {
"message": "Error",
"data": {"name": "ValueError", "message": "test error"},
},
},
],
}
],
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 4, (
f"should have agent, generation, tool, generation, got {len(spans)} with data: "
@ -121,6 +189,43 @@ async def test_tool_call_error():
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"error": {
"message": "Error in agent run",
"data": {"error": "Invalid JSON input for tool foo: bad_json"},
},
"data": {
"name": "test_agent",
"handoffs": [],
"tools": ["foo"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"error": {
"message": "Error running tool",
"data": {
"tool_name": "foo",
"error": "Invalid JSON input for tool foo: bad_json",
},
},
"data": {"name": "foo", "input": "bad_json"},
},
],
}
],
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 3, (
f"should have agent, generation, tool spans, got {len(spans)} with data: "
@ -173,6 +278,43 @@ async def test_multiple_handoff_doesnt_error():
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test",
"handoffs": ["test", "test"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {
"name": "some_function",
"input": '{"a": "b"}',
"output": "result",
},
},
{"type": "generation"},
{"type": "handoff", "data": {"from_agent": "test", "to_agent": "test"}},
],
},
{
"type": "agent",
"data": {"name": "test", "handoffs": [], "tools": [], "output_type": "str"},
"children": [{"type": "generation"}],
},
],
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 7, (
f"should have 2 agent, 1 function, 3 generation, 1 handoff, got {len(spans)} with data: "
@ -211,6 +353,21 @@ async def test_multiple_final_output_no_error():
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {"name": "test", "handoffs": [], "tools": [], "output_type": "Foo"},
"children": [{"type": "generation"}],
}
],
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 2, (
f"should have 1 agent, 1 generation, got {len(spans)} with data: "
@ -271,12 +428,152 @@ async def test_handoffs_lead_to_correct_agent_spans():
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_3",
"handoffs": ["test_agent_1", "test_agent_2"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {
"name": "some_function",
"input": '{"a": "b"}',
"output": "result",
},
},
{"type": "generation"},
{
"type": "handoff",
"data": {"from_agent": "test_agent_3", "to_agent": "test_agent_1"},
},
],
},
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": ["test_agent_3"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {
"name": "some_function",
"input": '{"a": "b"}',
"output": "result",
},
},
{"type": "generation"},
{
"type": "handoff",
"data": {"from_agent": "test_agent_1", "to_agent": "test_agent_3"},
},
],
},
{
"type": "agent",
"data": {
"name": "test_agent_3",
"handoffs": ["test_agent_1", "test_agent_2"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [{"type": "generation"}],
},
],
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 12, (
f"should have 3 agents, 2 function, 5 generation, 2 handoff, got {len(spans)} with data: "
f"{[x.span_data for x in spans]}"
)
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"data": {
"name": "test_agent_3",
"handoffs": ["test_agent_1", "test_agent_2"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {
"name": "some_function",
"input": '{"a": "b"}',
"output": "result",
},
},
{"type": "generation"},
{
"type": "handoff",
"data": {"from_agent": "test_agent_3", "to_agent": "test_agent_1"},
},
],
},
{
"type": "agent",
"data": {
"name": "test_agent_1",
"handoffs": ["test_agent_3"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {
"name": "some_function",
"input": '{"a": "b"}',
"output": "result",
},
},
{"type": "generation"},
{
"type": "handoff",
"data": {"from_agent": "test_agent_1", "to_agent": "test_agent_3"},
},
],
},
{
"type": "agent",
"data": {
"name": "test_agent_3",
"handoffs": ["test_agent_1", "test_agent_2"],
"tools": ["some_function"],
"output_type": "str",
},
"children": [{"type": "generation"}],
},
],
}
]
)
@pytest.mark.asyncio
async def test_max_turns_exceeded():
@ -307,6 +604,38 @@ async def test_max_turns_exceeded():
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"error": {"message": "Max turns exceeded", "data": {"max_turns": 2}},
"data": {
"name": "test",
"handoffs": [],
"tools": ["foo"],
"output_type": "Foo",
},
"children": [
{"type": "generation"},
{
"type": "function",
"data": {"name": "foo", "input": "", "output": "result"},
},
{"type": "generation"},
{
"type": "function",
"data": {"name": "foo", "input": "", "output": "result"},
},
],
}
],
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 5, (
f"should have 1 agent, 2 generations, 2 function calls, got "
@ -347,6 +676,33 @@ async def test_input_guardrail_error():
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"error": {
"message": "Guardrail tripwire triggered",
"data": {
"guardrail": "input_guardrail_function",
"type": "input_guardrail",
},
},
"data": {"name": "test", "handoffs": [], "tools": [], "output_type": "str"},
"children": [
{
"type": "guardrail",
"data": {"name": "input_guardrail_function", "triggered": True},
}
],
}
],
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 2, (
f"should have 1 agent, 1 guardrail, got {len(spans)} with data: "
@ -387,6 +743,30 @@ async def test_output_guardrail_error():
traces = fetch_traces()
assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}"
assert fetch_normalized_spans() == snapshot(
[
{
"workflow_name": "Agent workflow",
"children": [
{
"type": "agent",
"error": {
"message": "Guardrail tripwire triggered",
"data": {"guardrail": "output_guardrail_function"},
},
"data": {"name": "test", "handoffs": [], "tools": [], "output_type": "str"},
"children": [
{
"type": "guardrail",
"data": {"name": "output_guardrail_function", "triggered": True},
}
],
}
],
}
]
)
spans = fetch_ordered_spans()
assert len(spans) == 2, (
f"should have 1 agent, 1 guardrail, got {len(spans)} with data: "

View file

@ -1,6 +1,7 @@
from __future__ import annotations
import threading
from datetime import datetime
from typing import Any, Literal
from agents.tracing import Span, Trace, TracingProcessor
@ -77,3 +78,37 @@ def fetch_traces() -> list[Trace]:
def fetch_events() -> list[TestSpanProcessorEvent]:
return SPAN_PROCESSOR_TESTING._events
def fetch_normalized_spans():
nodes: dict[tuple[str, str | None], dict[str, Any]] = {}
traces = []
for trace_obj in fetch_traces():
trace = trace_obj.export()
assert trace
assert trace.pop("object") == "trace"
assert trace.pop("id").startswith("trace_")
trace = {k: v for k, v in trace.items() if v is not None}
nodes[(trace_obj.trace_id, None)] = trace
traces.append(trace)
if not traces:
assert not fetch_ordered_spans()
for span_obj in fetch_ordered_spans():
span = span_obj.export()
assert span
assert span.pop("object") == "trace.span"
assert span.pop("id").startswith("span_")
assert datetime.fromisoformat(span.pop("started_at"))
assert datetime.fromisoformat(span.pop("ended_at"))
parent_id = span.pop("parent_id")
assert "type" not in span
span_data = span.pop("span_data")
span = {"type": span_data.pop("type")} | {k: v for k, v in span.items() if v is not None}
span_data = {k: v for k, v in span_data.items() if v is not None}
if span_data:
span["data"] = span_data
nodes[(span_obj.trace_id, span_obj.span_id)] = span
nodes[(span.pop("trace_id"), parent_id)].setdefault("children", []).append(span)
return traces

36
uv.lock
View file

@ -1,4 +1,5 @@
version = 1
revision = 1
requires-python = ">=3.9"
[[package]]
@ -25,6 +26,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/46/eb/e7f063ad1fec6b3178a3cd82d1a3c4de82cccf283fc42746168188e1cdd5/anyio-4.8.0-py3-none-any.whl", hash = "sha256:b5011f270ab5eb0abf13385f851315585cc37ef330dd88e27ec3d34d651fd47a", size = 96041 },
]
[[package]]
name = "asttokens"
version = "3.0.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/4a/e7/82da0a03e7ba5141f05cce0d302e6eed121ae055e0456ca228bf693984bc/asttokens-3.0.0.tar.gz", hash = "sha256:0dcd8baa8d62b0c1d118b399b2ddba3c4aff271d0d7a9e0d4c1681c79035bbc7", size = 61978 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/25/8a/c46dcc25341b5bce5472c718902eb3d38600a903b14fa6aeecef3f21a46f/asttokens-3.0.0-py3-none-any.whl", hash = "sha256:e3078351a059199dd5138cb1c706e6430c05eff2ff136af5eb4790f9d28932e2", size = 26918 },
]
[[package]]
name = "babel"
version = "2.17.0"
@ -239,6 +249,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/02/cc/b7e31358aac6ed1ef2bb790a9746ac2c69bcb3c8588b41616914eb106eaf/exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b", size = 16453 },
]
[[package]]
name = "executing"
version = "2.2.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/91/50/a9d80c47ff289c611ff12e63f7c5d13942c65d68125160cefd768c73e6e4/executing-2.2.0.tar.gz", hash = "sha256:5d108c028108fe2551d1a7b2e8b713341e2cb4fc0aa7dcf966fa4327a5226755", size = 978693 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702 },
]
[[package]]
name = "ghp-import"
version = "2.1.0"
@ -391,6 +410,21 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 },
]
[[package]]
name = "inline-snapshot"
version = "0.20.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "asttokens" },
{ name = "executing" },
{ name = "rich" },
{ name = "tomli", marker = "python_full_version < '3.11'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b0/41/9bd2ecd10ef789e8aff6fb68dcc7677dc31b33b2d27c306c0d40fc982fbc/inline_snapshot-0.20.7.tar.gz", hash = "sha256:d55bbb6254d0727dc304729ca7998cde1c1e984c4bf50281514aa9d727a56cf2", size = 92643 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/01/8f/1bf23da63ad1a0b14ca2d9114700123ef76732e375548f4f9ca94052817e/inline_snapshot-0.20.7-py3-none-any.whl", hash = "sha256:2df6dd8710d1f0def2c1f9d6c25fd03d7beba01f3addf52fc370343d9ee9959f", size = 48108 },
]
[[package]]
name = "jinja2"
version = "3.1.6"
@ -796,6 +830,7 @@ dependencies = [
[package.dev-dependencies]
dev = [
{ name = "coverage" },
{ name = "inline-snapshot" },
{ name = "mkdocs" },
{ name = "mkdocs-material" },
{ name = "mkdocstrings", extra = ["python"] },
@ -821,6 +856,7 @@ requires-dist = [
[package.metadata.requires-dev]
dev = [
{ name = "coverage", specifier = ">=7.6.12" },
{ name = "inline-snapshot", specifier = ">=0.20.7" },
{ name = "mkdocs", specifier = ">=1.6.0" },
{ name = "mkdocs-material", specifier = ">=9.6.0" },
{ name = "mkdocstrings", extras = ["python"], specifier = ">=0.28.0" },