diff --git a/Makefile b/Makefile index 7dd9bbd..39899d8 100644 --- a/Makefile +++ b/Makefile @@ -18,6 +18,14 @@ mypy: tests: uv run pytest +.PHONY: snapshots-fix +snapshots-fix: + uv run pytest --inline-snapshot=fix + +.PHONY: snapshots-create +snapshots-create: + uv run pytest --inline-snapshot=create + .PHONY: old_version_tests old_version_tests: UV_PROJECT_ENVIRONMENT=.venv_39 uv run --python 3.9 -m pytest diff --git a/README.md b/README.md index 210f6f4..51ca3c6 100644 --- a/README.md +++ b/README.md @@ -142,7 +142,7 @@ The Agents SDK is designed to be highly flexible, allowing you to model a wide r ## Tracing -The Agents SDK automatically traces your agent runs, making it easy to track and debug the behavior of your agents. Tracing is extensible by design, supporting custom spans and a wide variety of external destinations, including [Logfire](https://logfire.pydantic.dev/docs/integrations/llms/openai/#openai-agents), [AgentOps](https://docs.agentops.ai/v1/integrations/agentssdk), [Braintrust](https://braintrust.dev/docs/guides/traces/integrations#openai-agents-sdk), [Scorecard](https://docs.scorecard.io/docs/documentation/features/tracing#openai-agents-sdk-integration), and [Keywords AI](https://docs.keywordsai.co/integration/development-frameworks/openai-agent). For more details about how to customize or disable tracing, see [Tracing](http://openai.github.io/openai-agents-python/tracing). +The Agents SDK automatically traces your agent runs, making it easy to track and debug the behavior of your agents. Tracing is extensible by design, supporting custom spans and a wide variety of external destinations, including [Logfire](https://logfire.pydantic.dev/docs/integrations/llms/openai/#openai-agents), [AgentOps](https://docs.agentops.ai/v1/integrations/agentssdk), [Braintrust](https://braintrust.dev/docs/guides/traces/integrations#openai-agents-sdk), [Scorecard](https://docs.scorecard.io/docs/documentation/features/tracing#openai-agents-sdk-integration), and [Keywords AI](https://docs.keywordsai.co/integration/development-frameworks/openai-agent). For more details about how to customize or disable tracing, see [Tracing](http://openai.github.io/openai-agents-python/tracing), which also includes a larger list of [external tracing processors](http://openai.github.io/openai-agents-python/tracing/#external-tracing-processors-list). ## Development (only needed if you need to edit the SDK/examples) diff --git a/docs/tracing.md b/docs/tracing.md index d7d0a65..7b1ab7a 100644 --- a/docs/tracing.md +++ b/docs/tracing.md @@ -9,6 +9,8 @@ The Agents SDK includes built-in tracing, collecting a comprehensive record of e 1. You can globally disable tracing by setting the env var `OPENAI_AGENTS_DISABLE_TRACING=1` 2. You can disable tracing for a single run by setting [`agents.run.RunConfig.tracing_disabled`][] to `True` +***For organizations operating under a Zero Data Retention (ZDR) policy using OpenAI's APIs, tracing is unavailable.*** + ## Traces and spans - **Traces** represent a single end-to-end operation of a "workflow". They're composed of Spans. Traces have the following properties: @@ -88,10 +90,12 @@ To customize this default setup, to send traces to alternative or additional bac 1. [`add_trace_processor()`][agents.tracing.add_trace_processor] lets you add an **additional** trace processor that will receive traces and spans as they are ready. This lets you do your own processing in addition to sending traces to OpenAI's backend. 2. [`set_trace_processors()`][agents.tracing.set_trace_processors] lets you **replace** the default processors with your own trace processors. This means traces will not be sent to the OpenAI backend unless you include a `TracingProcessor` that does so. -External trace processors include: +## External tracing processors list +- [Arize-Phoenix](https://docs.arize.com/phoenix/tracing/integrations-tracing/openai-agents-sdk) - [Braintrust](https://braintrust.dev/docs/guides/traces/integrations#openai-agents-sdk) - [Pydantic Logfire](https://logfire.pydantic.dev/docs/integrations/llms/openai/#openai-agents) - [AgentOps](https://docs.agentops.ai/v1/integrations/agentssdk) -- [Scorecard](https://docs.scorecard.io/docs/documentation/features/tracing#openai-agents-sdk-integration)) +- [Scorecard](https://docs.scorecard.io/docs/documentation/features/tracing#openai-agents-sdk-integration) - [Keywords AI](https://docs.keywordsai.co/integration/development-frameworks/openai-agent) +- [LangSmith](https://docs.smith.langchain.com/observability/how_to_guides/trace_with_openai_agents_sdk) diff --git a/examples/agent_patterns/input_guardrails.py b/examples/agent_patterns/input_guardrails.py index 8c8e182..1545355 100644 --- a/examples/agent_patterns/input_guardrails.py +++ b/examples/agent_patterns/input_guardrails.py @@ -30,8 +30,8 @@ If the guardrail trips, we'll respond with a refusal message. ### 1. An agent-based guardrail that is triggered if the user is asking to do math homework class MathHomeworkOutput(BaseModel): - is_math_homework: bool reasoning: str + is_math_homework: bool guardrail_agent = Agent( diff --git a/examples/agent_patterns/llm_as_a_judge.py b/examples/agent_patterns/llm_as_a_judge.py index d13a67c..5a46cc3 100644 --- a/examples/agent_patterns/llm_as_a_judge.py +++ b/examples/agent_patterns/llm_as_a_judge.py @@ -23,8 +23,8 @@ story_outline_generator = Agent( @dataclass class EvaluationFeedback: - score: Literal["pass", "needs_improvement", "fail"] feedback: str + score: Literal["pass", "needs_improvement", "fail"] evaluator = Agent[None]( diff --git a/examples/basic/agent_lifecycle_example.py b/examples/basic/agent_lifecycle_example.py index bc0bbe4..29bb18c 100644 --- a/examples/basic/agent_lifecycle_example.py +++ b/examples/basic/agent_lifecycle_example.py @@ -74,7 +74,7 @@ multiply_agent = Agent( start_agent = Agent( name="Start Agent", - instructions="Generate a random number. If it's even, stop. If it's odd, hand off to the multipler agent.", + instructions="Generate a random number. If it's even, stop. If it's odd, hand off to the multiply agent.", tools=[random_number], output_type=FinalResult, handoffs=[multiply_agent], diff --git a/examples/basic/hello_world_jupyter.py b/examples/basic/hello_world_jupyter.py index bb8f14c..c929a7c 100644 --- a/examples/basic/hello_world_jupyter.py +++ b/examples/basic/hello_world_jupyter.py @@ -3,7 +3,7 @@ from agents import Agent, Runner agent = Agent(name="Assistant", instructions="You are a helpful assistant") # Intended for Jupyter notebooks where there's an existing event loop -result = await Runner.run(agent, "Write a haiku about recursion in programming.") # type: ignore[top-level-await] # noqa: F704 +result = await Runner.run(agent, "Write a haiku about recursion in programming.") # type: ignore[top-level-await] # noqa: F704 print(result.final_output) # Code within code loops, diff --git a/examples/handoffs/message_filter.py b/examples/handoffs/message_filter.py index 9dd56ef..b7fed6c 100644 --- a/examples/handoffs/message_filter.py +++ b/examples/handoffs/message_filter.py @@ -60,9 +60,9 @@ async def main(): print("Step 1 done") - # 2. Ask it to square a number + # 2. Ask it to generate a number result = await Runner.run( - second_agent, + first_agent, input=result.to_input_list() + [{"content": "Can you generate a random number between 0 and 100?", "role": "user"}], ) diff --git a/examples/handoffs/message_filter_streaming.py b/examples/handoffs/message_filter_streaming.py index 8d1b420..63cb1de 100644 --- a/examples/handoffs/message_filter_streaming.py +++ b/examples/handoffs/message_filter_streaming.py @@ -60,9 +60,9 @@ async def main(): print("Step 1 done") - # 2. Ask it to square a number + # 2. Ask it to generate a number result = await Runner.run( - second_agent, + first_agent, input=result.to_input_list() + [{"content": "Can you generate a random number between 0 and 100?", "role": "user"}], ) diff --git a/pyproject.toml b/pyproject.toml index ff3d01f..3ad1d37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ dev = [ "mkdocstrings[python]>=0.28.0", "coverage>=7.6.12", "playwright==1.50.0", - "inline-snapshot>=0.20.5", + "inline-snapshot>=0.20.7", ] [tool.uv.workspace] members = ["agents"] @@ -118,3 +118,6 @@ filterwarnings = [ markers = [ "allow_call_model_methods: mark test as allowing calls to real model implementations", ] + +[tool.inline-snapshot] +format-command="ruff format --stdin-filename {filename}" \ No newline at end of file diff --git a/src/agents/__init__.py b/src/agents/__init__.py index a2d7f24..21a2f2a 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -73,6 +73,7 @@ from .tracing import ( SpanData, SpanError, Trace, + TracingProcessor, add_trace_processor, agent_span, custom_span, @@ -208,6 +209,7 @@ __all__ = [ "set_tracing_disabled", "trace", "Trace", + "TracingProcessor", "SpanError", "Span", "SpanData", diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 2c84950..c0c0ebd 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -25,7 +25,6 @@ from openai.types.responses.response_computer_tool_call import ( from openai.types.responses.response_input_param import ComputerCallOutput from openai.types.responses.response_reasoning_item import ResponseReasoningItem -from . import _utils from .agent import Agent from .agent_output import AgentOutputSchema from .computer import AsyncComputer, Computer @@ -59,6 +58,7 @@ from .tracing import ( handoff_span, trace, ) +from .util import _coro, _error_tracing if TYPE_CHECKING: from .run import RunConfig @@ -293,7 +293,7 @@ class RunImpl: elif isinstance(output, ResponseComputerToolCall): items.append(ToolCallItem(raw_item=output, agent=agent)) if not computer_tool: - _utils.attach_error_to_current_span( + _error_tracing.attach_error_to_current_span( SpanError( message="Computer tool not found", data={}, @@ -324,7 +324,7 @@ class RunImpl: # Regular function tool call else: if output.name not in function_map: - _utils.attach_error_to_current_span( + _error_tracing.attach_error_to_current_span( SpanError( message="Tool not found", data={"tool_name": output.name}, @@ -368,7 +368,7 @@ class RunImpl: ( agent.hooks.on_tool_start(context_wrapper, agent, func_tool) if agent.hooks - else _utils.noop_coroutine() + else _coro.noop_coroutine() ), func_tool.on_invoke_tool(context_wrapper, tool_call.arguments), ) @@ -378,11 +378,11 @@ class RunImpl: ( agent.hooks.on_tool_end(context_wrapper, agent, func_tool, result) if agent.hooks - else _utils.noop_coroutine() + else _coro.noop_coroutine() ), ) except Exception as e: - _utils.attach_error_to_current_span( + _error_tracing.attach_error_to_current_span( SpanError( message="Error running tool", data={"tool_name": func_tool.name, "error": str(e)}, @@ -502,7 +502,7 @@ class RunImpl: source=agent, ) if agent.hooks - else _utils.noop_coroutine() + else _coro.noop_coroutine() ), ) @@ -520,7 +520,7 @@ class RunImpl: new_items=tuple(new_step_items), ) if not callable(input_filter): - _utils.attach_error_to_span( + _error_tracing.attach_error_to_span( span_handoff, SpanError( message="Invalid input filter", @@ -530,7 +530,7 @@ class RunImpl: raise UserError(f"Invalid input filter: {input_filter}") filtered = input_filter(handoff_input_data) if not isinstance(filtered, HandoffInputData): - _utils.attach_error_to_span( + _error_tracing.attach_error_to_span( span_handoff, SpanError( message="Invalid input filter result", @@ -591,7 +591,7 @@ class RunImpl: hooks.on_agent_end(context_wrapper, agent, final_output), agent.hooks.on_end(context_wrapper, agent, final_output) if agent.hooks - else _utils.noop_coroutine(), + else _coro.noop_coroutine(), ) @classmethod @@ -706,7 +706,7 @@ class ComputerAction: ( agent.hooks.on_tool_start(context_wrapper, agent, action.computer_tool) if agent.hooks - else _utils.noop_coroutine() + else _coro.noop_coroutine() ), output_func, ) @@ -716,7 +716,7 @@ class ComputerAction: ( agent.hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output) if agent.hooks - else _utils.noop_coroutine() + else _coro.noop_coroutine() ), ) diff --git a/src/agents/_utils.py b/src/agents/_utils.py deleted file mode 100644 index 2a0293a..0000000 --- a/src/agents/_utils.py +++ /dev/null @@ -1,61 +0,0 @@ -from __future__ import annotations - -import re -from collections.abc import Awaitable -from typing import Any, Literal, Union - -from pydantic import TypeAdapter, ValidationError -from typing_extensions import TypeVar - -from .exceptions import ModelBehaviorError -from .logger import logger -from .tracing import Span, SpanError, get_current_span - -T = TypeVar("T") - -MaybeAwaitable = Union[Awaitable[T], T] - - -def transform_string_function_style(name: str) -> str: - # Replace spaces with underscores - name = name.replace(" ", "_") - - # Replace non-alphanumeric characters with underscores - name = re.sub(r"[^a-zA-Z0-9]", "_", name) - - return name.lower() - - -def validate_json(json_str: str, type_adapter: TypeAdapter[T], partial: bool) -> T: - partial_setting: bool | Literal["off", "on", "trailing-strings"] = ( - "trailing-strings" if partial else False - ) - try: - validated = type_adapter.validate_json(json_str, experimental_allow_partial=partial_setting) - return validated - except ValidationError as e: - attach_error_to_current_span( - SpanError( - message="Invalid JSON provided", - data={}, - ) - ) - raise ModelBehaviorError( - f"Invalid JSON when parsing {json_str} for {type_adapter}; {e}" - ) from e - - -def attach_error_to_span(span: Span[Any], error: SpanError) -> None: - span.set_error(error) - - -def attach_error_to_current_span(error: SpanError) -> None: - span = get_current_span() - if span: - attach_error_to_span(span, error) - else: - logger.warning(f"No span to add error {error} to") - - -async def noop_coroutine() -> None: - pass diff --git a/src/agents/agent.py b/src/agents/agent.py index 61c0a89..3c4588e 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -6,8 +6,6 @@ from collections.abc import Awaitable from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Callable, Generic, cast -from . import _utils -from ._utils import MaybeAwaitable from .guardrail import InputGuardrail, OutputGuardrail from .handoffs import Handoff from .items import ItemHelpers @@ -16,6 +14,8 @@ from .model_settings import ModelSettings from .models.interface import Model from .run_context import RunContextWrapper, TContext from .tool import Tool, function_tool +from .util import _transforms +from .util._types import MaybeAwaitable if TYPE_CHECKING: from .lifecycle import AgentHooks @@ -27,8 +27,8 @@ class Agent(Generic[TContext]): """An agent is an AI model configured with instructions, tools, guardrails, handoffs and more. We strongly recommend passing `instructions`, which is the "system prompt" for the agent. In - addition, you can pass `description`, which is a human-readable description of the agent, used - when the agent is used inside tools/handoffs. + addition, you can pass `handoff_description`, which is a human-readable description of the + agent, used when the agent is used inside tools/handoffs. Agents are generic on the context type. The context is a (mutable) object you create. It is passed to tool functions, handoffs, guardrails, etc. @@ -126,7 +126,7 @@ class Agent(Generic[TContext]): """ @function_tool( - name_override=tool_name or _utils.transform_string_function_style(self.name), + name_override=tool_name or _transforms.transform_string_function_style(self.name), description_override=tool_description or "", ) async def run_agent(context: RunContextWrapper, input: str) -> str: diff --git a/src/agents/agent_output.py b/src/agents/agent_output.py index 0c28800..3262c57 100644 --- a/src/agents/agent_output.py +++ b/src/agents/agent_output.py @@ -4,10 +4,10 @@ from typing import Any from pydantic import BaseModel, TypeAdapter from typing_extensions import TypedDict, get_args, get_origin -from . import _utils from .exceptions import ModelBehaviorError, UserError from .strict_schema import ensure_strict_json_schema from .tracing import SpanError +from .util import _error_tracing, _json _WRAPPER_DICT_KEY = "response" @@ -87,10 +87,10 @@ class AgentOutputSchema: """Validate a JSON string against the output type. Returns the validated object, or raises a `ModelBehaviorError` if the JSON is invalid. """ - validated = _utils.validate_json(json_str, self._type_adapter, partial) + validated = _json.validate_json(json_str, self._type_adapter, partial) if self._is_wrapped: if not isinstance(validated, dict): - _utils.attach_error_to_current_span( + _error_tracing.attach_error_to_current_span( SpanError( message="Invalid JSON", data={"details": f"Expected a dict, got {type(validated)}"}, @@ -101,7 +101,7 @@ class AgentOutputSchema: ) if _WRAPPER_DICT_KEY not in validated: - _utils.attach_error_to_current_span( + _error_tracing.attach_error_to_current_span( SpanError( message="Invalid JSON", data={"details": f"Could not find key {_WRAPPER_DICT_KEY} in JSON"}, diff --git a/src/agents/function_schema.py b/src/agents/function_schema.py index a4b5767..681affc 100644 --- a/src/agents/function_schema.py +++ b/src/agents/function_schema.py @@ -33,6 +33,9 @@ class FuncSchema: """The signature of the function.""" takes_context: bool = False """Whether the function takes a RunContextWrapper argument (must be the first argument).""" + strict_json_schema: bool = True + """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True, + as it increases the likelihood of correct JSON input.""" def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]: """ @@ -337,4 +340,5 @@ def function_schema( params_json_schema=json_schema, signature=sig, takes_context=takes_context, + strict_json_schema=strict_json_schema, ) diff --git a/src/agents/guardrail.py b/src/agents/guardrail.py index 5bebcd6..a96f0f7 100644 --- a/src/agents/guardrail.py +++ b/src/agents/guardrail.py @@ -7,10 +7,10 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, Union, overload from typing_extensions import TypeVar -from ._utils import MaybeAwaitable from .exceptions import UserError from .items import TResponseInputItem from .run_context import RunContextWrapper, TContext +from .util._types import MaybeAwaitable if TYPE_CHECKING: from .agent import Agent diff --git a/src/agents/handoffs.py b/src/agents/handoffs.py index ac15740..686191f 100644 --- a/src/agents/handoffs.py +++ b/src/agents/handoffs.py @@ -8,12 +8,12 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, cast, overload from pydantic import TypeAdapter from typing_extensions import TypeAlias, TypeVar -from . import _utils from .exceptions import ModelBehaviorError, UserError from .items import RunItem, TResponseInputItem from .run_context import RunContextWrapper, TContext from .strict_schema import ensure_strict_json_schema from .tracing.spans import SpanError +from .util import _error_tracing, _json, _transforms if TYPE_CHECKING: from .agent import Agent @@ -104,7 +104,7 @@ class Handoff(Generic[TContext]): @classmethod def default_tool_name(cls, agent: Agent[Any]) -> str: - return _utils.transform_string_function_style(f"transfer_to_{agent.name}") + return _transforms.transform_string_function_style(f"transfer_to_{agent.name}") @classmethod def default_tool_description(cls, agent: Agent[Any]) -> str: @@ -192,7 +192,7 @@ def handoff( ) -> Agent[Any]: if input_type is not None and type_adapter is not None: if input_json is None: - _utils.attach_error_to_current_span( + _error_tracing.attach_error_to_current_span( SpanError( message="Handoff function expected non-null input, but got None", data={"details": "input_json is None"}, @@ -200,7 +200,7 @@ def handoff( ) raise ModelBehaviorError("Handoff function expected non-null input, but got None") - validated_input = _utils.validate_json( + validated_input = _json.validate_json( json_str=input_json, type_adapter=type_adapter, partial=False, diff --git a/src/agents/result.py b/src/agents/result.py index 6e806b7..40a6480 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -17,6 +17,7 @@ from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem from .logger import logger from .stream_events import StreamEvent from .tracing import Trace +from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming if TYPE_CHECKING: from ._run_impl import QueueCompleteSentinel @@ -89,6 +90,9 @@ class RunResult(RunResultBase): """The last agent that was run.""" return self._last_agent + def __str__(self) -> str: + return pretty_print_result(self) + @dataclass class RunResultStreaming(RunResultBase): @@ -216,3 +220,6 @@ class RunResultStreaming(RunResultBase): if self._output_guardrails_task and not self._output_guardrails_task.done(): self._output_guardrails_task.cancel() + + def __str__(self) -> str: + return pretty_print_run_result_streaming(self) diff --git a/src/agents/run.py b/src/agents/run.py index dfff7e3..934400f 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -7,7 +7,6 @@ from typing import Any, cast from openai.types.responses import ResponseCompletedEvent -from . import Model, _utils from ._run_impl import ( NextStepFinalOutput, NextStepHandoff, @@ -33,7 +32,7 @@ from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem from .lifecycle import RunHooks from .logger import logger from .model_settings import ModelSettings -from .models.interface import ModelProvider +from .models.interface import Model, ModelProvider from .models.openai_provider import OpenAIProvider from .result import RunResult, RunResultStreaming from .run_context import RunContextWrapper, TContext @@ -41,6 +40,7 @@ from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent from .tracing import Span, SpanError, agent_span, get_current_trace, trace from .tracing.span_data import AgentSpanData from .usage import Usage +from .util import _coro, _error_tracing DEFAULT_MAX_TURNS = 10 @@ -193,7 +193,7 @@ class Runner: current_turn += 1 if current_turn > max_turns: - _utils.attach_error_to_span( + _error_tracing.attach_error_to_span( current_span, SpanError( message="Max turns exceeded", @@ -447,7 +447,7 @@ class Runner: for done in asyncio.as_completed(guardrail_tasks): result = await done if result.output.tripwire_triggered: - _utils.attach_error_to_span( + _error_tracing.attach_error_to_span( parent_span, SpanError( message="Guardrail tripwire triggered", @@ -511,7 +511,7 @@ class Runner: streamed_result.current_turn = current_turn if current_turn > max_turns: - _utils.attach_error_to_span( + _error_tracing.attach_error_to_span( current_span, SpanError( message="Max turns exceeded", @@ -583,7 +583,7 @@ class Runner: pass except Exception as e: if current_span: - _utils.attach_error_to_span( + _error_tracing.attach_error_to_span( current_span, SpanError( message="Error in agent run", @@ -615,7 +615,7 @@ class Runner: ( agent.hooks.on_start(context_wrapper, agent) if agent.hooks - else _utils.noop_coroutine() + else _coro.noop_coroutine() ), ) @@ -705,7 +705,7 @@ class Runner: ( agent.hooks.on_start(context_wrapper, agent) if agent.hooks - else _utils.noop_coroutine() + else _coro.noop_coroutine() ), ) @@ -796,7 +796,7 @@ class Runner: # Cancel all guardrail tasks if a tripwire is triggered. for t in guardrail_tasks: t.cancel() - _utils.attach_error_to_current_span( + _error_tracing.attach_error_to_current_span( SpanError( message="Guardrail tripwire triggered", data={"guardrail": result.guardrail.get_name()}, @@ -834,7 +834,7 @@ class Runner: # Cancel all guardrail tasks if a tripwire is triggered. for t in guardrail_tasks: t.cancel() - _utils.attach_error_to_current_span( + _error_tracing.attach_error_to_current_span( SpanError( message="Guardrail tripwire triggered", data={"guardrail": result.guardrail.get_name()}, diff --git a/src/agents/tool.py b/src/agents/tool.py index 7587268..0baf2c0 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -11,14 +11,15 @@ from openai.types.responses.web_search_tool_param import UserLocation from pydantic import ValidationError from typing_extensions import Concatenate, ParamSpec -from . import _debug, _utils -from ._utils import MaybeAwaitable +from . import _debug from .computer import AsyncComputer, Computer from .exceptions import ModelBehaviorError from .function_schema import DocstringStyle, function_schema from .logger import logger from .run_context import RunContextWrapper from .tracing import SpanError +from .util import _error_tracing +from .util._types import MaybeAwaitable ToolParams = ParamSpec("ToolParams") @@ -137,6 +138,7 @@ def function_tool( docstring_style: DocstringStyle | None = None, use_docstring_info: bool = True, failure_error_function: ToolErrorFunction | None = None, + strict_mode: bool = True, ) -> FunctionTool: """Overload for usage as @function_tool (no parentheses).""" ... @@ -150,6 +152,7 @@ def function_tool( docstring_style: DocstringStyle | None = None, use_docstring_info: bool = True, failure_error_function: ToolErrorFunction | None = None, + strict_mode: bool = True, ) -> Callable[[ToolFunction[...]], FunctionTool]: """Overload for usage as @function_tool(...).""" ... @@ -163,6 +166,7 @@ def function_tool( docstring_style: DocstringStyle | None = None, use_docstring_info: bool = True, failure_error_function: ToolErrorFunction | None = default_tool_error_function, + strict_mode: bool = True, ) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]: """ Decorator to create a FunctionTool from a function. By default, we will: @@ -186,6 +190,8 @@ def function_tool( failure_error_function: If provided, use this function to generate an error message when the tool call fails. The error message is sent to the LLM. If you pass None, then no error message will be sent and instead an Exception will be raised. + strict_mode: If False, parameters with default values become optional in the + function schema. """ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool: @@ -195,6 +201,7 @@ def function_tool( description_override=description_override, docstring_style=docstring_style, use_docstring_info=use_docstring_info, + strict_json_schema=strict_mode, ) async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> str: @@ -257,7 +264,7 @@ def function_tool( if inspect.isawaitable(result): return await result - _utils.attach_error_to_current_span( + _error_tracing.attach_error_to_current_span( SpanError( message="Error running tool (non-fatal)", data={ @@ -273,6 +280,7 @@ def function_tool( description=schema.description or "", params_json_schema=schema.params_json_schema, on_invoke_tool=_on_invoke_tool, + strict_json_schema=strict_mode, ) # If func is actually a callable, we were used as @function_tool with no parentheses diff --git a/src/agents/util/__init__.py b/src/agents/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/agents/util/_coro.py b/src/agents/util/_coro.py new file mode 100644 index 0000000..647ab86 --- /dev/null +++ b/src/agents/util/_coro.py @@ -0,0 +1,2 @@ +async def noop_coroutine() -> None: + pass diff --git a/src/agents/util/_error_tracing.py b/src/agents/util/_error_tracing.py new file mode 100644 index 0000000..09dbb1d --- /dev/null +++ b/src/agents/util/_error_tracing.py @@ -0,0 +1,16 @@ +from typing import Any + +from ..logger import logger +from ..tracing import Span, SpanError, get_current_span + + +def attach_error_to_span(span: Span[Any], error: SpanError) -> None: + span.set_error(error) + + +def attach_error_to_current_span(error: SpanError) -> None: + span = get_current_span() + if span: + attach_error_to_span(span, error) + else: + logger.warning(f"No span to add error {error} to") diff --git a/src/agents/util/_json.py b/src/agents/util/_json.py new file mode 100644 index 0000000..1e081f6 --- /dev/null +++ b/src/agents/util/_json.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import Literal + +from pydantic import TypeAdapter, ValidationError +from typing_extensions import TypeVar + +from ..exceptions import ModelBehaviorError +from ..tracing import SpanError +from ._error_tracing import attach_error_to_current_span + +T = TypeVar("T") + + +def validate_json(json_str: str, type_adapter: TypeAdapter[T], partial: bool) -> T: + partial_setting: bool | Literal["off", "on", "trailing-strings"] = ( + "trailing-strings" if partial else False + ) + try: + validated = type_adapter.validate_json(json_str, experimental_allow_partial=partial_setting) + return validated + except ValidationError as e: + attach_error_to_current_span( + SpanError( + message="Invalid JSON provided", + data={}, + ) + ) + raise ModelBehaviorError( + f"Invalid JSON when parsing {json_str} for {type_adapter}; {e}" + ) from e diff --git a/src/agents/util/_pretty_print.py b/src/agents/util/_pretty_print.py new file mode 100644 index 0000000..afd3e2b --- /dev/null +++ b/src/agents/util/_pretty_print.py @@ -0,0 +1,56 @@ +from typing import TYPE_CHECKING + +from pydantic import BaseModel + +if TYPE_CHECKING: + from ..result import RunResult, RunResultBase, RunResultStreaming + + +def _indent(text: str, indent_level: int) -> str: + indent_string = " " * indent_level + return "\n".join(f"{indent_string}{line}" for line in text.splitlines()) + + +def _final_output_str(result: "RunResultBase") -> str: + if result.final_output is None: + return "None" + elif isinstance(result.final_output, str): + return result.final_output + elif isinstance(result.final_output, BaseModel): + return result.final_output.model_dump_json(indent=2) + else: + return str(result.final_output) + + +def pretty_print_result(result: "RunResult") -> str: + output = "RunResult:" + output += f'\n- Last agent: Agent(name="{result.last_agent.name}", ...)' + output += ( + f"\n- Final output ({type(result.final_output).__name__}):\n" + f"{_indent(_final_output_str(result), 2)}" + ) + output += f"\n- {len(result.new_items)} new item(s)" + output += f"\n- {len(result.raw_responses)} raw response(s)" + output += f"\n- {len(result.input_guardrail_results)} input guardrail result(s)" + output += f"\n- {len(result.output_guardrail_results)} output guardrail result(s)" + output += "\n(See `RunResult` for more details)" + + return output + + +def pretty_print_run_result_streaming(result: "RunResultStreaming") -> str: + output = "RunResultStreaming:" + output += f'\n- Current agent: Agent(name="{result.current_agent.name}", ...)' + output += f"\n- Current turn: {result.current_turn}" + output += f"\n- Max turns: {result.max_turns}" + output += f"\n- Is complete: {result.is_complete}" + output += ( + f"\n- Final output ({type(result.final_output).__name__}):\n" + f"{_indent(_final_output_str(result), 2)}" + ) + output += f"\n- {len(result.new_items)} new item(s)" + output += f"\n- {len(result.raw_responses)} raw response(s)" + output += f"\n- {len(result.input_guardrail_results)} input guardrail result(s)" + output += f"\n- {len(result.output_guardrail_results)} output guardrail result(s)" + output += "\n(See `RunResultStreaming` for more details)" + return output diff --git a/src/agents/util/_transforms.py b/src/agents/util/_transforms.py new file mode 100644 index 0000000..b303074 --- /dev/null +++ b/src/agents/util/_transforms.py @@ -0,0 +1,11 @@ +import re + + +def transform_string_function_style(name: str) -> str: + # Replace spaces with underscores + name = name.replace(" ", "_") + + # Replace non-alphanumeric characters with underscores + name = re.sub(r"[^a-zA-Z0-9]", "_", name) + + return name.lower() diff --git a/src/agents/util/_types.py b/src/agents/util/_types.py new file mode 100644 index 0000000..8571a69 --- /dev/null +++ b/src/agents/util/_types.py @@ -0,0 +1,7 @@ +from collections.abc import Awaitable +from typing import Union + +from typing_extensions import TypeVar + +T = TypeVar("T") +MaybeAwaitable = Union[Awaitable[T], T] diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000..d68e067 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,25 @@ +# Tests + +Before running any tests, make sure you have `uv` installed (and ideally run `make sync` after). + +## Running tests + +``` +make tests +``` + +## Snapshots + +We use [inline-snapshots](https://15r10nk.github.io/inline-snapshot/latest/) for some tests. If your code adds new snapshot tests or breaks existing ones, you can fix/create them. After fixing/creating snapshots, run `make tests` again to verify the tests pass. + +### Fixing snapshots + +``` +make snapshots-fix +``` + +### Creating snapshots + +``` +make snapshots-update +``` diff --git a/tests/test_function_tool_decorator.py b/tests/test_function_tool_decorator.py index 3a47deb..f146ec7 100644 --- a/tests/test_function_tool_decorator.py +++ b/tests/test_function_tool_decorator.py @@ -1,6 +1,6 @@ import asyncio import json -from typing import Any +from typing import Any, Optional import pytest @@ -142,3 +142,51 @@ async def test_no_error_on_invalid_json_async(): tool = will_not_fail_on_bad_json_async result = await tool.on_invoke_tool(ctx_wrapper(), "{not valid json}") assert result == "error_ModelBehaviorError" + + +@function_tool(strict_mode=False) +def optional_param_function(a: int, b: Optional[int] = None) -> str: + if b is None: + return f"{a}_no_b" + return f"{a}_{b}" + + +@pytest.mark.asyncio +async def test_optional_param_function(): + tool = optional_param_function + + input_data = {"a": 5} + output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data)) + assert output == "5_no_b" + + input_data = {"a": 5, "b": 10} + output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data)) + assert output == "5_10" + + +@function_tool(strict_mode=False) +def multiple_optional_params_function( + x: int = 42, + y: str = "hello", + z: Optional[int] = None, +) -> str: + if z is None: + return f"{x}_{y}_no_z" + return f"{x}_{y}_{z}" + + +@pytest.mark.asyncio +async def test_multiple_optional_params_function(): + tool = multiple_optional_params_function + + input_data: dict[str, Any] = {} + output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data)) + assert output == "42_hello_no_z" + + input_data = {"x": 10, "y": "world"} + output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data)) + assert output == "10_world_no_z" + + input_data = {"x": 10, "y": "world", "z": 99} + output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data)) + assert output == "10_world_99" diff --git a/tests/test_output_tool.py b/tests/test_output_tool.py index 31ac984..86c4b3b 100644 --- a/tests/test_output_tool.py +++ b/tests/test_output_tool.py @@ -4,8 +4,9 @@ import pytest from pydantic import BaseModel from typing_extensions import TypedDict -from agents import Agent, AgentOutputSchema, ModelBehaviorError, Runner, UserError, _utils +from agents import Agent, AgentOutputSchema, ModelBehaviorError, Runner, UserError from agents.agent_output import _WRAPPER_DICT_KEY +from agents.util import _json def test_plain_text_output(): @@ -77,7 +78,7 @@ def test_bad_json_raises_error(mocker): output_schema = Runner._get_output_schema(agent) assert output_schema, "Should have an output tool config with a structured output type" - mock_validate_json = mocker.patch.object(_utils, "validate_json") + mock_validate_json = mocker.patch.object(_json, "validate_json") mock_validate_json.return_value = ["foo"] with pytest.raises(ModelBehaviorError): @@ -111,3 +112,4 @@ def test_setting_strict_false_works(): output_wrapper = AgentOutputSchema(output_type=Foo, strict_json_schema=False) assert not output_wrapper.strict_json_schema assert output_wrapper.json_schema() == Foo.model_json_schema() + assert output_wrapper.json_schema() == Foo.model_json_schema() diff --git a/tests/test_pretty_print.py b/tests/test_pretty_print.py new file mode 100644 index 0000000..b2218a2 --- /dev/null +++ b/tests/test_pretty_print.py @@ -0,0 +1,201 @@ +import json + +import pytest +from inline_snapshot import snapshot +from pydantic import BaseModel + +from agents import Agent, Runner +from agents.agent_output import _WRAPPER_DICT_KEY +from agents.util._pretty_print import pretty_print_result, pretty_print_run_result_streaming +from tests.fake_model import FakeModel + +from .test_responses import get_final_output_message, get_text_message + + +@pytest.mark.asyncio +async def test_pretty_result(): + model = FakeModel() + model.set_next_output([get_text_message("Hi there")]) + + agent = Agent(name="test_agent", model=model) + result = await Runner.run(agent, input="Hello") + + assert pretty_print_result(result) == snapshot("""\ +RunResult: +- Last agent: Agent(name="test_agent", ...) +- Final output (str): + Hi there +- 1 new item(s) +- 1 raw response(s) +- 0 input guardrail result(s) +- 0 output guardrail result(s) +(See `RunResult` for more details)\ +""") + + +@pytest.mark.asyncio +async def test_pretty_run_result_streaming(): + model = FakeModel() + model.set_next_output([get_text_message("Hi there")]) + + agent = Agent(name="test_agent", model=model) + result = Runner.run_streamed(agent, input="Hello") + async for _ in result.stream_events(): + pass + + assert pretty_print_run_result_streaming(result) == snapshot("""\ +RunResultStreaming: +- Current agent: Agent(name="test_agent", ...) +- Current turn: 1 +- Max turns: 10 +- Is complete: True +- Final output (str): + Hi there +- 1 new item(s) +- 1 raw response(s) +- 0 input guardrail result(s) +- 0 output guardrail result(s) +(See `RunResultStreaming` for more details)\ +""") + + +class Foo(BaseModel): + bar: str + + +@pytest.mark.asyncio +async def test_pretty_run_result_structured_output(): + model = FakeModel() + model.set_next_output( + [ + get_text_message("Test"), + get_final_output_message(Foo(bar="Hi there").model_dump_json()), + ] + ) + + agent = Agent(name="test_agent", model=model, output_type=Foo) + result = await Runner.run(agent, input="Hello") + + assert pretty_print_result(result) == snapshot("""\ +RunResult: +- Last agent: Agent(name="test_agent", ...) +- Final output (Foo): + { + "bar": "Hi there" + } +- 2 new item(s) +- 1 raw response(s) +- 0 input guardrail result(s) +- 0 output guardrail result(s) +(See `RunResult` for more details)\ +""") + + +@pytest.mark.asyncio +async def test_pretty_run_result_streaming_structured_output(): + model = FakeModel() + model.set_next_output( + [ + get_text_message("Test"), + get_final_output_message(Foo(bar="Hi there").model_dump_json()), + ] + ) + + agent = Agent(name="test_agent", model=model, output_type=Foo) + result = Runner.run_streamed(agent, input="Hello") + + async for _ in result.stream_events(): + pass + + assert pretty_print_run_result_streaming(result) == snapshot("""\ +RunResultStreaming: +- Current agent: Agent(name="test_agent", ...) +- Current turn: 1 +- Max turns: 10 +- Is complete: True +- Final output (Foo): + { + "bar": "Hi there" + } +- 2 new item(s) +- 1 raw response(s) +- 0 input guardrail result(s) +- 0 output guardrail result(s) +(See `RunResultStreaming` for more details)\ +""") + + +@pytest.mark.asyncio +async def test_pretty_run_result_list_structured_output(): + model = FakeModel() + model.set_next_output( + [ + get_text_message("Test"), + get_final_output_message( + json.dumps( + { + _WRAPPER_DICT_KEY: [ + Foo(bar="Hi there").model_dump(), + Foo(bar="Hi there 2").model_dump(), + ] + } + ) + ), + ] + ) + + agent = Agent(name="test_agent", model=model, output_type=list[Foo]) + result = await Runner.run(agent, input="Hello") + + assert pretty_print_result(result) == snapshot("""\ +RunResult: +- Last agent: Agent(name="test_agent", ...) +- Final output (list): + [Foo(bar='Hi there'), Foo(bar='Hi there 2')] +- 2 new item(s) +- 1 raw response(s) +- 0 input guardrail result(s) +- 0 output guardrail result(s) +(See `RunResult` for more details)\ +""") + + +@pytest.mark.asyncio +async def test_pretty_run_result_streaming_list_structured_output(): + model = FakeModel() + model.set_next_output( + [ + get_text_message("Test"), + get_final_output_message( + json.dumps( + { + _WRAPPER_DICT_KEY: [ + Foo(bar="Test").model_dump(), + Foo(bar="Test 2").model_dump(), + ] + } + ) + ), + ] + ) + + agent = Agent(name="test_agent", model=model, output_type=list[Foo]) + result = Runner.run_streamed(agent, input="Hello") + + async for _ in result.stream_events(): + pass + + assert pretty_print_run_result_streaming(result) == snapshot("""\ +RunResultStreaming: +- Current agent: Agent(name="test_agent", ...) +- Current turn: 1 +- Max turns: 10 +- Is complete: True +- Final output (list): + [Foo(bar='Test'), Foo(bar='Test 2')] +- 2 new item(s) +- 1 raw response(s) +- 0 input guardrail result(s) +- 0 output guardrail result(s) +(See `RunResultStreaming` for more details)\ +""") diff --git a/uv.lock b/uv.lock index 40f0553..2c2e05b 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.9" [[package]] @@ -411,7 +412,7 @@ wheels = [ [[package]] name = "inline-snapshot" -version = "0.20.5" +version = "0.20.7" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "asttokens" }, @@ -419,9 +420,9 @@ dependencies = [ { name = "rich" }, { name = "tomli", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/3b/95/9b85a63031c168dd1c479f8cfd5cae42d42d6ac41c18dd760a104bc87ddc/inline_snapshot-0.20.5.tar.gz", hash = "sha256:d8b67c6d533c0a3f566e72608144b54da65dc3da5d0dba4169b2c56b75530fb5", size = 92215 } +sdist = { url = "https://files.pythonhosted.org/packages/b0/41/9bd2ecd10ef789e8aff6fb68dcc7677dc31b33b2d27c306c0d40fc982fbc/inline_snapshot-0.20.7.tar.gz", hash = "sha256:d55bbb6254d0727dc304729ca7998cde1c1e984c4bf50281514aa9d727a56cf2", size = 92643 } wheels = [ - { url = "https://files.pythonhosted.org/packages/d4/71/34e775bbf0bcf81d588d80a1df93437f937b0df9a841f246606a03fc5eff/inline_snapshot-0.20.5-py3-none-any.whl", hash = "sha256:3aa56acf5985d89f17ebd4df4aef00faacc49f10cdf4e6b42be701ffc9702b5a", size = 48071 }, + { url = "https://files.pythonhosted.org/packages/01/8f/1bf23da63ad1a0b14ca2d9114700123ef76732e375548f4f9ca94052817e/inline_snapshot-0.20.7-py3-none-any.whl", hash = "sha256:2df6dd8710d1f0def2c1f9d6c25fd03d7beba01f3addf52fc370343d9ee9959f", size = 48108 }, ] [[package]] @@ -855,7 +856,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "coverage", specifier = ">=7.6.12" }, - { name = "inline-snapshot", specifier = ">=0.20.5" }, + { name = "inline-snapshot", specifier = ">=0.20.7" }, { name = "mkdocs", specifier = ">=1.6.0" }, { name = "mkdocs-material", specifier = ">=9.6.0" }, { name = "mkdocstrings", extras = ["python"], specifier = ">=0.28.0" },