Merge branch 'main' of github.com:openai/openai-agents-python into alex/inline-snapshot

This commit is contained in:
Alex Hall 2025-03-17 23:55:56 +02:00
commit 2d2e8f0e34
33 changed files with 494 additions and 119 deletions

View file

@ -18,6 +18,14 @@ mypy:
tests:
uv run pytest
.PHONY: snapshots-fix
snapshots-fix:
uv run pytest --inline-snapshot=fix
.PHONY: snapshots-create
snapshots-create:
uv run pytest --inline-snapshot=create
.PHONY: old_version_tests
old_version_tests:
UV_PROJECT_ENVIRONMENT=.venv_39 uv run --python 3.9 -m pytest

View file

@ -142,7 +142,7 @@ The Agents SDK is designed to be highly flexible, allowing you to model a wide r
## Tracing
The Agents SDK automatically traces your agent runs, making it easy to track and debug the behavior of your agents. Tracing is extensible by design, supporting custom spans and a wide variety of external destinations, including [Logfire](https://logfire.pydantic.dev/docs/integrations/llms/openai/#openai-agents), [AgentOps](https://docs.agentops.ai/v1/integrations/agentssdk), [Braintrust](https://braintrust.dev/docs/guides/traces/integrations#openai-agents-sdk), [Scorecard](https://docs.scorecard.io/docs/documentation/features/tracing#openai-agents-sdk-integration), and [Keywords AI](https://docs.keywordsai.co/integration/development-frameworks/openai-agent). For more details about how to customize or disable tracing, see [Tracing](http://openai.github.io/openai-agents-python/tracing).
The Agents SDK automatically traces your agent runs, making it easy to track and debug the behavior of your agents. Tracing is extensible by design, supporting custom spans and a wide variety of external destinations, including [Logfire](https://logfire.pydantic.dev/docs/integrations/llms/openai/#openai-agents), [AgentOps](https://docs.agentops.ai/v1/integrations/agentssdk), [Braintrust](https://braintrust.dev/docs/guides/traces/integrations#openai-agents-sdk), [Scorecard](https://docs.scorecard.io/docs/documentation/features/tracing#openai-agents-sdk-integration), and [Keywords AI](https://docs.keywordsai.co/integration/development-frameworks/openai-agent). For more details about how to customize or disable tracing, see [Tracing](http://openai.github.io/openai-agents-python/tracing), which also includes a larger list of [external tracing processors](http://openai.github.io/openai-agents-python/tracing/#external-tracing-processors-list).
## Development (only needed if you need to edit the SDK/examples)

View file

@ -9,6 +9,8 @@ The Agents SDK includes built-in tracing, collecting a comprehensive record of e
1. You can globally disable tracing by setting the env var `OPENAI_AGENTS_DISABLE_TRACING=1`
2. You can disable tracing for a single run by setting [`agents.run.RunConfig.tracing_disabled`][] to `True`
***For organizations operating under a Zero Data Retention (ZDR) policy using OpenAI's APIs, tracing is unavailable.***
## Traces and spans
- **Traces** represent a single end-to-end operation of a "workflow". They're composed of Spans. Traces have the following properties:
@ -88,10 +90,12 @@ To customize this default setup, to send traces to alternative or additional bac
1. [`add_trace_processor()`][agents.tracing.add_trace_processor] lets you add an **additional** trace processor that will receive traces and spans as they are ready. This lets you do your own processing in addition to sending traces to OpenAI's backend.
2. [`set_trace_processors()`][agents.tracing.set_trace_processors] lets you **replace** the default processors with your own trace processors. This means traces will not be sent to the OpenAI backend unless you include a `TracingProcessor` that does so.
External trace processors include:
## External tracing processors list
- [Arize-Phoenix](https://docs.arize.com/phoenix/tracing/integrations-tracing/openai-agents-sdk)
- [Braintrust](https://braintrust.dev/docs/guides/traces/integrations#openai-agents-sdk)
- [Pydantic Logfire](https://logfire.pydantic.dev/docs/integrations/llms/openai/#openai-agents)
- [AgentOps](https://docs.agentops.ai/v1/integrations/agentssdk)
- [Scorecard](https://docs.scorecard.io/docs/documentation/features/tracing#openai-agents-sdk-integration))
- [Scorecard](https://docs.scorecard.io/docs/documentation/features/tracing#openai-agents-sdk-integration)
- [Keywords AI](https://docs.keywordsai.co/integration/development-frameworks/openai-agent)
- [LangSmith](https://docs.smith.langchain.com/observability/how_to_guides/trace_with_openai_agents_sdk)

View file

@ -30,8 +30,8 @@ If the guardrail trips, we'll respond with a refusal message.
### 1. An agent-based guardrail that is triggered if the user is asking to do math homework
class MathHomeworkOutput(BaseModel):
is_math_homework: bool
reasoning: str
is_math_homework: bool
guardrail_agent = Agent(

View file

@ -23,8 +23,8 @@ story_outline_generator = Agent(
@dataclass
class EvaluationFeedback:
score: Literal["pass", "needs_improvement", "fail"]
feedback: str
score: Literal["pass", "needs_improvement", "fail"]
evaluator = Agent[None](

View file

@ -74,7 +74,7 @@ multiply_agent = Agent(
start_agent = Agent(
name="Start Agent",
instructions="Generate a random number. If it's even, stop. If it's odd, hand off to the multipler agent.",
instructions="Generate a random number. If it's even, stop. If it's odd, hand off to the multiply agent.",
tools=[random_number],
output_type=FinalResult,
handoffs=[multiply_agent],

View file

@ -3,7 +3,7 @@ from agents import Agent, Runner
agent = Agent(name="Assistant", instructions="You are a helpful assistant")
# Intended for Jupyter notebooks where there's an existing event loop
result = await Runner.run(agent, "Write a haiku about recursion in programming.") # type: ignore[top-level-await] # noqa: F704
result = await Runner.run(agent, "Write a haiku about recursion in programming.") # type: ignore[top-level-await] # noqa: F704
print(result.final_output)
# Code within code loops,

View file

@ -60,9 +60,9 @@ async def main():
print("Step 1 done")
# 2. Ask it to square a number
# 2. Ask it to generate a number
result = await Runner.run(
second_agent,
first_agent,
input=result.to_input_list()
+ [{"content": "Can you generate a random number between 0 and 100?", "role": "user"}],
)

View file

@ -60,9 +60,9 @@ async def main():
print("Step 1 done")
# 2. Ask it to square a number
# 2. Ask it to generate a number
result = await Runner.run(
second_agent,
first_agent,
input=result.to_input_list()
+ [{"content": "Can you generate a random number between 0 and 100?", "role": "user"}],
)

View file

@ -47,7 +47,7 @@ dev = [
"mkdocstrings[python]>=0.28.0",
"coverage>=7.6.12",
"playwright==1.50.0",
"inline-snapshot>=0.20.5",
"inline-snapshot>=0.20.7",
]
[tool.uv.workspace]
members = ["agents"]
@ -118,3 +118,6 @@ filterwarnings = [
markers = [
"allow_call_model_methods: mark test as allowing calls to real model implementations",
]
[tool.inline-snapshot]
format-command="ruff format --stdin-filename {filename}"

View file

@ -73,6 +73,7 @@ from .tracing import (
SpanData,
SpanError,
Trace,
TracingProcessor,
add_trace_processor,
agent_span,
custom_span,
@ -208,6 +209,7 @@ __all__ = [
"set_tracing_disabled",
"trace",
"Trace",
"TracingProcessor",
"SpanError",
"Span",
"SpanData",

View file

@ -25,7 +25,6 @@ from openai.types.responses.response_computer_tool_call import (
from openai.types.responses.response_input_param import ComputerCallOutput
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from . import _utils
from .agent import Agent
from .agent_output import AgentOutputSchema
from .computer import AsyncComputer, Computer
@ -59,6 +58,7 @@ from .tracing import (
handoff_span,
trace,
)
from .util import _coro, _error_tracing
if TYPE_CHECKING:
from .run import RunConfig
@ -293,7 +293,7 @@ class RunImpl:
elif isinstance(output, ResponseComputerToolCall):
items.append(ToolCallItem(raw_item=output, agent=agent))
if not computer_tool:
_utils.attach_error_to_current_span(
_error_tracing.attach_error_to_current_span(
SpanError(
message="Computer tool not found",
data={},
@ -324,7 +324,7 @@ class RunImpl:
# Regular function tool call
else:
if output.name not in function_map:
_utils.attach_error_to_current_span(
_error_tracing.attach_error_to_current_span(
SpanError(
message="Tool not found",
data={"tool_name": output.name},
@ -368,7 +368,7 @@ class RunImpl:
(
agent.hooks.on_tool_start(context_wrapper, agent, func_tool)
if agent.hooks
else _utils.noop_coroutine()
else _coro.noop_coroutine()
),
func_tool.on_invoke_tool(context_wrapper, tool_call.arguments),
)
@ -378,11 +378,11 @@ class RunImpl:
(
agent.hooks.on_tool_end(context_wrapper, agent, func_tool, result)
if agent.hooks
else _utils.noop_coroutine()
else _coro.noop_coroutine()
),
)
except Exception as e:
_utils.attach_error_to_current_span(
_error_tracing.attach_error_to_current_span(
SpanError(
message="Error running tool",
data={"tool_name": func_tool.name, "error": str(e)},
@ -502,7 +502,7 @@ class RunImpl:
source=agent,
)
if agent.hooks
else _utils.noop_coroutine()
else _coro.noop_coroutine()
),
)
@ -520,7 +520,7 @@ class RunImpl:
new_items=tuple(new_step_items),
)
if not callable(input_filter):
_utils.attach_error_to_span(
_error_tracing.attach_error_to_span(
span_handoff,
SpanError(
message="Invalid input filter",
@ -530,7 +530,7 @@ class RunImpl:
raise UserError(f"Invalid input filter: {input_filter}")
filtered = input_filter(handoff_input_data)
if not isinstance(filtered, HandoffInputData):
_utils.attach_error_to_span(
_error_tracing.attach_error_to_span(
span_handoff,
SpanError(
message="Invalid input filter result",
@ -591,7 +591,7 @@ class RunImpl:
hooks.on_agent_end(context_wrapper, agent, final_output),
agent.hooks.on_end(context_wrapper, agent, final_output)
if agent.hooks
else _utils.noop_coroutine(),
else _coro.noop_coroutine(),
)
@classmethod
@ -706,7 +706,7 @@ class ComputerAction:
(
agent.hooks.on_tool_start(context_wrapper, agent, action.computer_tool)
if agent.hooks
else _utils.noop_coroutine()
else _coro.noop_coroutine()
),
output_func,
)
@ -716,7 +716,7 @@ class ComputerAction:
(
agent.hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output)
if agent.hooks
else _utils.noop_coroutine()
else _coro.noop_coroutine()
),
)

View file

@ -1,61 +0,0 @@
from __future__ import annotations
import re
from collections.abc import Awaitable
from typing import Any, Literal, Union
from pydantic import TypeAdapter, ValidationError
from typing_extensions import TypeVar
from .exceptions import ModelBehaviorError
from .logger import logger
from .tracing import Span, SpanError, get_current_span
T = TypeVar("T")
MaybeAwaitable = Union[Awaitable[T], T]
def transform_string_function_style(name: str) -> str:
# Replace spaces with underscores
name = name.replace(" ", "_")
# Replace non-alphanumeric characters with underscores
name = re.sub(r"[^a-zA-Z0-9]", "_", name)
return name.lower()
def validate_json(json_str: str, type_adapter: TypeAdapter[T], partial: bool) -> T:
partial_setting: bool | Literal["off", "on", "trailing-strings"] = (
"trailing-strings" if partial else False
)
try:
validated = type_adapter.validate_json(json_str, experimental_allow_partial=partial_setting)
return validated
except ValidationError as e:
attach_error_to_current_span(
SpanError(
message="Invalid JSON provided",
data={},
)
)
raise ModelBehaviorError(
f"Invalid JSON when parsing {json_str} for {type_adapter}; {e}"
) from e
def attach_error_to_span(span: Span[Any], error: SpanError) -> None:
span.set_error(error)
def attach_error_to_current_span(error: SpanError) -> None:
span = get_current_span()
if span:
attach_error_to_span(span, error)
else:
logger.warning(f"No span to add error {error} to")
async def noop_coroutine() -> None:
pass

View file

@ -6,8 +6,6 @@ from collections.abc import Awaitable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Generic, cast
from . import _utils
from ._utils import MaybeAwaitable
from .guardrail import InputGuardrail, OutputGuardrail
from .handoffs import Handoff
from .items import ItemHelpers
@ -16,6 +14,8 @@ from .model_settings import ModelSettings
from .models.interface import Model
from .run_context import RunContextWrapper, TContext
from .tool import Tool, function_tool
from .util import _transforms
from .util._types import MaybeAwaitable
if TYPE_CHECKING:
from .lifecycle import AgentHooks
@ -27,8 +27,8 @@ class Agent(Generic[TContext]):
"""An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.
We strongly recommend passing `instructions`, which is the "system prompt" for the agent. In
addition, you can pass `description`, which is a human-readable description of the agent, used
when the agent is used inside tools/handoffs.
addition, you can pass `handoff_description`, which is a human-readable description of the
agent, used when the agent is used inside tools/handoffs.
Agents are generic on the context type. The context is a (mutable) object you create. It is
passed to tool functions, handoffs, guardrails, etc.
@ -126,7 +126,7 @@ class Agent(Generic[TContext]):
"""
@function_tool(
name_override=tool_name or _utils.transform_string_function_style(self.name),
name_override=tool_name or _transforms.transform_string_function_style(self.name),
description_override=tool_description or "",
)
async def run_agent(context: RunContextWrapper, input: str) -> str:

View file

@ -4,10 +4,10 @@ from typing import Any
from pydantic import BaseModel, TypeAdapter
from typing_extensions import TypedDict, get_args, get_origin
from . import _utils
from .exceptions import ModelBehaviorError, UserError
from .strict_schema import ensure_strict_json_schema
from .tracing import SpanError
from .util import _error_tracing, _json
_WRAPPER_DICT_KEY = "response"
@ -87,10 +87,10 @@ class AgentOutputSchema:
"""Validate a JSON string against the output type. Returns the validated object, or raises
a `ModelBehaviorError` if the JSON is invalid.
"""
validated = _utils.validate_json(json_str, self._type_adapter, partial)
validated = _json.validate_json(json_str, self._type_adapter, partial)
if self._is_wrapped:
if not isinstance(validated, dict):
_utils.attach_error_to_current_span(
_error_tracing.attach_error_to_current_span(
SpanError(
message="Invalid JSON",
data={"details": f"Expected a dict, got {type(validated)}"},
@ -101,7 +101,7 @@ class AgentOutputSchema:
)
if _WRAPPER_DICT_KEY not in validated:
_utils.attach_error_to_current_span(
_error_tracing.attach_error_to_current_span(
SpanError(
message="Invalid JSON",
data={"details": f"Could not find key {_WRAPPER_DICT_KEY} in JSON"},

View file

@ -33,6 +33,9 @@ class FuncSchema:
"""The signature of the function."""
takes_context: bool = False
"""Whether the function takes a RunContextWrapper argument (must be the first argument)."""
strict_json_schema: bool = True
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
as it increases the likelihood of correct JSON input."""
def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
"""
@ -337,4 +340,5 @@ def function_schema(
params_json_schema=json_schema,
signature=sig,
takes_context=takes_context,
strict_json_schema=strict_json_schema,
)

View file

@ -7,10 +7,10 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, Union, overload
from typing_extensions import TypeVar
from ._utils import MaybeAwaitable
from .exceptions import UserError
from .items import TResponseInputItem
from .run_context import RunContextWrapper, TContext
from .util._types import MaybeAwaitable
if TYPE_CHECKING:
from .agent import Agent

View file

@ -8,12 +8,12 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, cast, overload
from pydantic import TypeAdapter
from typing_extensions import TypeAlias, TypeVar
from . import _utils
from .exceptions import ModelBehaviorError, UserError
from .items import RunItem, TResponseInputItem
from .run_context import RunContextWrapper, TContext
from .strict_schema import ensure_strict_json_schema
from .tracing.spans import SpanError
from .util import _error_tracing, _json, _transforms
if TYPE_CHECKING:
from .agent import Agent
@ -104,7 +104,7 @@ class Handoff(Generic[TContext]):
@classmethod
def default_tool_name(cls, agent: Agent[Any]) -> str:
return _utils.transform_string_function_style(f"transfer_to_{agent.name}")
return _transforms.transform_string_function_style(f"transfer_to_{agent.name}")
@classmethod
def default_tool_description(cls, agent: Agent[Any]) -> str:
@ -192,7 +192,7 @@ def handoff(
) -> Agent[Any]:
if input_type is not None and type_adapter is not None:
if input_json is None:
_utils.attach_error_to_current_span(
_error_tracing.attach_error_to_current_span(
SpanError(
message="Handoff function expected non-null input, but got None",
data={"details": "input_json is None"},
@ -200,7 +200,7 @@ def handoff(
)
raise ModelBehaviorError("Handoff function expected non-null input, but got None")
validated_input = _utils.validate_json(
validated_input = _json.validate_json(
json_str=input_json,
type_adapter=type_adapter,
partial=False,

View file

@ -17,6 +17,7 @@ from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
from .logger import logger
from .stream_events import StreamEvent
from .tracing import Trace
from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming
if TYPE_CHECKING:
from ._run_impl import QueueCompleteSentinel
@ -89,6 +90,9 @@ class RunResult(RunResultBase):
"""The last agent that was run."""
return self._last_agent
def __str__(self) -> str:
return pretty_print_result(self)
@dataclass
class RunResultStreaming(RunResultBase):
@ -216,3 +220,6 @@ class RunResultStreaming(RunResultBase):
if self._output_guardrails_task and not self._output_guardrails_task.done():
self._output_guardrails_task.cancel()
def __str__(self) -> str:
return pretty_print_run_result_streaming(self)

View file

@ -7,7 +7,6 @@ from typing import Any, cast
from openai.types.responses import ResponseCompletedEvent
from . import Model, _utils
from ._run_impl import (
NextStepFinalOutput,
NextStepHandoff,
@ -33,7 +32,7 @@ from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
from .lifecycle import RunHooks
from .logger import logger
from .model_settings import ModelSettings
from .models.interface import ModelProvider
from .models.interface import Model, ModelProvider
from .models.openai_provider import OpenAIProvider
from .result import RunResult, RunResultStreaming
from .run_context import RunContextWrapper, TContext
@ -41,6 +40,7 @@ from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent
from .tracing import Span, SpanError, agent_span, get_current_trace, trace
from .tracing.span_data import AgentSpanData
from .usage import Usage
from .util import _coro, _error_tracing
DEFAULT_MAX_TURNS = 10
@ -193,7 +193,7 @@ class Runner:
current_turn += 1
if current_turn > max_turns:
_utils.attach_error_to_span(
_error_tracing.attach_error_to_span(
current_span,
SpanError(
message="Max turns exceeded",
@ -447,7 +447,7 @@ class Runner:
for done in asyncio.as_completed(guardrail_tasks):
result = await done
if result.output.tripwire_triggered:
_utils.attach_error_to_span(
_error_tracing.attach_error_to_span(
parent_span,
SpanError(
message="Guardrail tripwire triggered",
@ -511,7 +511,7 @@ class Runner:
streamed_result.current_turn = current_turn
if current_turn > max_turns:
_utils.attach_error_to_span(
_error_tracing.attach_error_to_span(
current_span,
SpanError(
message="Max turns exceeded",
@ -583,7 +583,7 @@ class Runner:
pass
except Exception as e:
if current_span:
_utils.attach_error_to_span(
_error_tracing.attach_error_to_span(
current_span,
SpanError(
message="Error in agent run",
@ -615,7 +615,7 @@ class Runner:
(
agent.hooks.on_start(context_wrapper, agent)
if agent.hooks
else _utils.noop_coroutine()
else _coro.noop_coroutine()
),
)
@ -705,7 +705,7 @@ class Runner:
(
agent.hooks.on_start(context_wrapper, agent)
if agent.hooks
else _utils.noop_coroutine()
else _coro.noop_coroutine()
),
)
@ -796,7 +796,7 @@ class Runner:
# Cancel all guardrail tasks if a tripwire is triggered.
for t in guardrail_tasks:
t.cancel()
_utils.attach_error_to_current_span(
_error_tracing.attach_error_to_current_span(
SpanError(
message="Guardrail tripwire triggered",
data={"guardrail": result.guardrail.get_name()},
@ -834,7 +834,7 @@ class Runner:
# Cancel all guardrail tasks if a tripwire is triggered.
for t in guardrail_tasks:
t.cancel()
_utils.attach_error_to_current_span(
_error_tracing.attach_error_to_current_span(
SpanError(
message="Guardrail tripwire triggered",
data={"guardrail": result.guardrail.get_name()},

View file

@ -11,14 +11,15 @@ from openai.types.responses.web_search_tool_param import UserLocation
from pydantic import ValidationError
from typing_extensions import Concatenate, ParamSpec
from . import _debug, _utils
from ._utils import MaybeAwaitable
from . import _debug
from .computer import AsyncComputer, Computer
from .exceptions import ModelBehaviorError
from .function_schema import DocstringStyle, function_schema
from .logger import logger
from .run_context import RunContextWrapper
from .tracing import SpanError
from .util import _error_tracing
from .util._types import MaybeAwaitable
ToolParams = ParamSpec("ToolParams")
@ -137,6 +138,7 @@ def function_tool(
docstring_style: DocstringStyle | None = None,
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True,
) -> FunctionTool:
"""Overload for usage as @function_tool (no parentheses)."""
...
@ -150,6 +152,7 @@ def function_tool(
docstring_style: DocstringStyle | None = None,
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True,
) -> Callable[[ToolFunction[...]], FunctionTool]:
"""Overload for usage as @function_tool(...)."""
...
@ -163,6 +166,7 @@ def function_tool(
docstring_style: DocstringStyle | None = None,
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
strict_mode: bool = True,
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
"""
Decorator to create a FunctionTool from a function. By default, we will:
@ -186,6 +190,8 @@ def function_tool(
failure_error_function: If provided, use this function to generate an error message when
the tool call fails. The error message is sent to the LLM. If you pass None, then no
error message will be sent and instead an Exception will be raised.
strict_mode: If False, parameters with default values become optional in the
function schema.
"""
def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
@ -195,6 +201,7 @@ def function_tool(
description_override=description_override,
docstring_style=docstring_style,
use_docstring_info=use_docstring_info,
strict_json_schema=strict_mode,
)
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> str:
@ -257,7 +264,7 @@ def function_tool(
if inspect.isawaitable(result):
return await result
_utils.attach_error_to_current_span(
_error_tracing.attach_error_to_current_span(
SpanError(
message="Error running tool (non-fatal)",
data={
@ -273,6 +280,7 @@ def function_tool(
description=schema.description or "",
params_json_schema=schema.params_json_schema,
on_invoke_tool=_on_invoke_tool,
strict_json_schema=strict_mode,
)
# If func is actually a callable, we were used as @function_tool with no parentheses

View file

2
src/agents/util/_coro.py Normal file
View file

@ -0,0 +1,2 @@
async def noop_coroutine() -> None:
pass

View file

@ -0,0 +1,16 @@
from typing import Any
from ..logger import logger
from ..tracing import Span, SpanError, get_current_span
def attach_error_to_span(span: Span[Any], error: SpanError) -> None:
span.set_error(error)
def attach_error_to_current_span(error: SpanError) -> None:
span = get_current_span()
if span:
attach_error_to_span(span, error)
else:
logger.warning(f"No span to add error {error} to")

31
src/agents/util/_json.py Normal file
View file

@ -0,0 +1,31 @@
from __future__ import annotations
from typing import Literal
from pydantic import TypeAdapter, ValidationError
from typing_extensions import TypeVar
from ..exceptions import ModelBehaviorError
from ..tracing import SpanError
from ._error_tracing import attach_error_to_current_span
T = TypeVar("T")
def validate_json(json_str: str, type_adapter: TypeAdapter[T], partial: bool) -> T:
partial_setting: bool | Literal["off", "on", "trailing-strings"] = (
"trailing-strings" if partial else False
)
try:
validated = type_adapter.validate_json(json_str, experimental_allow_partial=partial_setting)
return validated
except ValidationError as e:
attach_error_to_current_span(
SpanError(
message="Invalid JSON provided",
data={},
)
)
raise ModelBehaviorError(
f"Invalid JSON when parsing {json_str} for {type_adapter}; {e}"
) from e

View file

@ -0,0 +1,56 @@
from typing import TYPE_CHECKING
from pydantic import BaseModel
if TYPE_CHECKING:
from ..result import RunResult, RunResultBase, RunResultStreaming
def _indent(text: str, indent_level: int) -> str:
indent_string = " " * indent_level
return "\n".join(f"{indent_string}{line}" for line in text.splitlines())
def _final_output_str(result: "RunResultBase") -> str:
if result.final_output is None:
return "None"
elif isinstance(result.final_output, str):
return result.final_output
elif isinstance(result.final_output, BaseModel):
return result.final_output.model_dump_json(indent=2)
else:
return str(result.final_output)
def pretty_print_result(result: "RunResult") -> str:
output = "RunResult:"
output += f'\n- Last agent: Agent(name="{result.last_agent.name}", ...)'
output += (
f"\n- Final output ({type(result.final_output).__name__}):\n"
f"{_indent(_final_output_str(result), 2)}"
)
output += f"\n- {len(result.new_items)} new item(s)"
output += f"\n- {len(result.raw_responses)} raw response(s)"
output += f"\n- {len(result.input_guardrail_results)} input guardrail result(s)"
output += f"\n- {len(result.output_guardrail_results)} output guardrail result(s)"
output += "\n(See `RunResult` for more details)"
return output
def pretty_print_run_result_streaming(result: "RunResultStreaming") -> str:
output = "RunResultStreaming:"
output += f'\n- Current agent: Agent(name="{result.current_agent.name}", ...)'
output += f"\n- Current turn: {result.current_turn}"
output += f"\n- Max turns: {result.max_turns}"
output += f"\n- Is complete: {result.is_complete}"
output += (
f"\n- Final output ({type(result.final_output).__name__}):\n"
f"{_indent(_final_output_str(result), 2)}"
)
output += f"\n- {len(result.new_items)} new item(s)"
output += f"\n- {len(result.raw_responses)} raw response(s)"
output += f"\n- {len(result.input_guardrail_results)} input guardrail result(s)"
output += f"\n- {len(result.output_guardrail_results)} output guardrail result(s)"
output += "\n(See `RunResultStreaming` for more details)"
return output

View file

@ -0,0 +1,11 @@
import re
def transform_string_function_style(name: str) -> str:
# Replace spaces with underscores
name = name.replace(" ", "_")
# Replace non-alphanumeric characters with underscores
name = re.sub(r"[^a-zA-Z0-9]", "_", name)
return name.lower()

View file

@ -0,0 +1,7 @@
from collections.abc import Awaitable
from typing import Union
from typing_extensions import TypeVar
T = TypeVar("T")
MaybeAwaitable = Union[Awaitable[T], T]

25
tests/README.md Normal file
View file

@ -0,0 +1,25 @@
# Tests
Before running any tests, make sure you have `uv` installed (and ideally run `make sync` after).
## Running tests
```
make tests
```
## Snapshots
We use [inline-snapshots](https://15r10nk.github.io/inline-snapshot/latest/) for some tests. If your code adds new snapshot tests or breaks existing ones, you can fix/create them. After fixing/creating snapshots, run `make tests` again to verify the tests pass.
### Fixing snapshots
```
make snapshots-fix
```
### Creating snapshots
```
make snapshots-update
```

View file

@ -1,6 +1,6 @@
import asyncio
import json
from typing import Any
from typing import Any, Optional
import pytest
@ -142,3 +142,51 @@ async def test_no_error_on_invalid_json_async():
tool = will_not_fail_on_bad_json_async
result = await tool.on_invoke_tool(ctx_wrapper(), "{not valid json}")
assert result == "error_ModelBehaviorError"
@function_tool(strict_mode=False)
def optional_param_function(a: int, b: Optional[int] = None) -> str:
if b is None:
return f"{a}_no_b"
return f"{a}_{b}"
@pytest.mark.asyncio
async def test_optional_param_function():
tool = optional_param_function
input_data = {"a": 5}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "5_no_b"
input_data = {"a": 5, "b": 10}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "5_10"
@function_tool(strict_mode=False)
def multiple_optional_params_function(
x: int = 42,
y: str = "hello",
z: Optional[int] = None,
) -> str:
if z is None:
return f"{x}_{y}_no_z"
return f"{x}_{y}_{z}"
@pytest.mark.asyncio
async def test_multiple_optional_params_function():
tool = multiple_optional_params_function
input_data: dict[str, Any] = {}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "42_hello_no_z"
input_data = {"x": 10, "y": "world"}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "10_world_no_z"
input_data = {"x": 10, "y": "world", "z": 99}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "10_world_99"

View file

@ -4,8 +4,9 @@ import pytest
from pydantic import BaseModel
from typing_extensions import TypedDict
from agents import Agent, AgentOutputSchema, ModelBehaviorError, Runner, UserError, _utils
from agents import Agent, AgentOutputSchema, ModelBehaviorError, Runner, UserError
from agents.agent_output import _WRAPPER_DICT_KEY
from agents.util import _json
def test_plain_text_output():
@ -77,7 +78,7 @@ def test_bad_json_raises_error(mocker):
output_schema = Runner._get_output_schema(agent)
assert output_schema, "Should have an output tool config with a structured output type"
mock_validate_json = mocker.patch.object(_utils, "validate_json")
mock_validate_json = mocker.patch.object(_json, "validate_json")
mock_validate_json.return_value = ["foo"]
with pytest.raises(ModelBehaviorError):
@ -111,3 +112,4 @@ def test_setting_strict_false_works():
output_wrapper = AgentOutputSchema(output_type=Foo, strict_json_schema=False)
assert not output_wrapper.strict_json_schema
assert output_wrapper.json_schema() == Foo.model_json_schema()
assert output_wrapper.json_schema() == Foo.model_json_schema()

201
tests/test_pretty_print.py Normal file
View file

@ -0,0 +1,201 @@
import json
import pytest
from inline_snapshot import snapshot
from pydantic import BaseModel
from agents import Agent, Runner
from agents.agent_output import _WRAPPER_DICT_KEY
from agents.util._pretty_print import pretty_print_result, pretty_print_run_result_streaming
from tests.fake_model import FakeModel
from .test_responses import get_final_output_message, get_text_message
@pytest.mark.asyncio
async def test_pretty_result():
model = FakeModel()
model.set_next_output([get_text_message("Hi there")])
agent = Agent(name="test_agent", model=model)
result = await Runner.run(agent, input="Hello")
assert pretty_print_result(result) == snapshot("""\
RunResult:
- Last agent: Agent(name="test_agent", ...)
- Final output (str):
Hi there
- 1 new item(s)
- 1 raw response(s)
- 0 input guardrail result(s)
- 0 output guardrail result(s)
(See `RunResult` for more details)\
""")
@pytest.mark.asyncio
async def test_pretty_run_result_streaming():
model = FakeModel()
model.set_next_output([get_text_message("Hi there")])
agent = Agent(name="test_agent", model=model)
result = Runner.run_streamed(agent, input="Hello")
async for _ in result.stream_events():
pass
assert pretty_print_run_result_streaming(result) == snapshot("""\
RunResultStreaming:
- Current agent: Agent(name="test_agent", ...)
- Current turn: 1
- Max turns: 10
- Is complete: True
- Final output (str):
Hi there
- 1 new item(s)
- 1 raw response(s)
- 0 input guardrail result(s)
- 0 output guardrail result(s)
(See `RunResultStreaming` for more details)\
""")
class Foo(BaseModel):
bar: str
@pytest.mark.asyncio
async def test_pretty_run_result_structured_output():
model = FakeModel()
model.set_next_output(
[
get_text_message("Test"),
get_final_output_message(Foo(bar="Hi there").model_dump_json()),
]
)
agent = Agent(name="test_agent", model=model, output_type=Foo)
result = await Runner.run(agent, input="Hello")
assert pretty_print_result(result) == snapshot("""\
RunResult:
- Last agent: Agent(name="test_agent", ...)
- Final output (Foo):
{
"bar": "Hi there"
}
- 2 new item(s)
- 1 raw response(s)
- 0 input guardrail result(s)
- 0 output guardrail result(s)
(See `RunResult` for more details)\
""")
@pytest.mark.asyncio
async def test_pretty_run_result_streaming_structured_output():
model = FakeModel()
model.set_next_output(
[
get_text_message("Test"),
get_final_output_message(Foo(bar="Hi there").model_dump_json()),
]
)
agent = Agent(name="test_agent", model=model, output_type=Foo)
result = Runner.run_streamed(agent, input="Hello")
async for _ in result.stream_events():
pass
assert pretty_print_run_result_streaming(result) == snapshot("""\
RunResultStreaming:
- Current agent: Agent(name="test_agent", ...)
- Current turn: 1
- Max turns: 10
- Is complete: True
- Final output (Foo):
{
"bar": "Hi there"
}
- 2 new item(s)
- 1 raw response(s)
- 0 input guardrail result(s)
- 0 output guardrail result(s)
(See `RunResultStreaming` for more details)\
""")
@pytest.mark.asyncio
async def test_pretty_run_result_list_structured_output():
model = FakeModel()
model.set_next_output(
[
get_text_message("Test"),
get_final_output_message(
json.dumps(
{
_WRAPPER_DICT_KEY: [
Foo(bar="Hi there").model_dump(),
Foo(bar="Hi there 2").model_dump(),
]
}
)
),
]
)
agent = Agent(name="test_agent", model=model, output_type=list[Foo])
result = await Runner.run(agent, input="Hello")
assert pretty_print_result(result) == snapshot("""\
RunResult:
- Last agent: Agent(name="test_agent", ...)
- Final output (list):
[Foo(bar='Hi there'), Foo(bar='Hi there 2')]
- 2 new item(s)
- 1 raw response(s)
- 0 input guardrail result(s)
- 0 output guardrail result(s)
(See `RunResult` for more details)\
""")
@pytest.mark.asyncio
async def test_pretty_run_result_streaming_list_structured_output():
model = FakeModel()
model.set_next_output(
[
get_text_message("Test"),
get_final_output_message(
json.dumps(
{
_WRAPPER_DICT_KEY: [
Foo(bar="Test").model_dump(),
Foo(bar="Test 2").model_dump(),
]
}
)
),
]
)
agent = Agent(name="test_agent", model=model, output_type=list[Foo])
result = Runner.run_streamed(agent, input="Hello")
async for _ in result.stream_events():
pass
assert pretty_print_run_result_streaming(result) == snapshot("""\
RunResultStreaming:
- Current agent: Agent(name="test_agent", ...)
- Current turn: 1
- Max turns: 10
- Is complete: True
- Final output (list):
[Foo(bar='Test'), Foo(bar='Test 2')]
- 2 new item(s)
- 1 raw response(s)
- 0 input guardrail result(s)
- 0 output guardrail result(s)
(See `RunResultStreaming` for more details)\
""")

View file

@ -1,4 +1,5 @@
version = 1
revision = 1
requires-python = ">=3.9"
[[package]]
@ -411,7 +412,7 @@ wheels = [
[[package]]
name = "inline-snapshot"
version = "0.20.5"
version = "0.20.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "asttokens" },
@ -419,9 +420,9 @@ dependencies = [
{ name = "rich" },
{ name = "tomli", marker = "python_full_version < '3.11'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3b/95/9b85a63031c168dd1c479f8cfd5cae42d42d6ac41c18dd760a104bc87ddc/inline_snapshot-0.20.5.tar.gz", hash = "sha256:d8b67c6d533c0a3f566e72608144b54da65dc3da5d0dba4169b2c56b75530fb5", size = 92215 }
sdist = { url = "https://files.pythonhosted.org/packages/b0/41/9bd2ecd10ef789e8aff6fb68dcc7677dc31b33b2d27c306c0d40fc982fbc/inline_snapshot-0.20.7.tar.gz", hash = "sha256:d55bbb6254d0727dc304729ca7998cde1c1e984c4bf50281514aa9d727a56cf2", size = 92643 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d4/71/34e775bbf0bcf81d588d80a1df93437f937b0df9a841f246606a03fc5eff/inline_snapshot-0.20.5-py3-none-any.whl", hash = "sha256:3aa56acf5985d89f17ebd4df4aef00faacc49f10cdf4e6b42be701ffc9702b5a", size = 48071 },
{ url = "https://files.pythonhosted.org/packages/01/8f/1bf23da63ad1a0b14ca2d9114700123ef76732e375548f4f9ca94052817e/inline_snapshot-0.20.7-py3-none-any.whl", hash = "sha256:2df6dd8710d1f0def2c1f9d6c25fd03d7beba01f3addf52fc370343d9ee9959f", size = 48108 },
]
[[package]]
@ -855,7 +856,7 @@ requires-dist = [
[package.metadata.requires-dev]
dev = [
{ name = "coverage", specifier = ">=7.6.12" },
{ name = "inline-snapshot", specifier = ">=0.20.5" },
{ name = "inline-snapshot", specifier = ">=0.20.7" },
{ name = "mkdocs", specifier = ">=1.6.0" },
{ name = "mkdocs-material", specifier = ">=9.6.0" },
{ name = "mkdocstrings", extras = ["python"], specifier = ">=0.28.0" },