Allow replacing AgentRunner and TraceProvider (#720)

This commit is contained in:
pakrym-oai 2025-06-17 17:41:10 -07:00 committed by GitHub
parent 901d2ac57c
commit 0cf503e1c2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 2290 additions and 2036 deletions

View file

@ -104,6 +104,7 @@ from .tracing import (
handoff_span, handoff_span,
mcp_tools_span, mcp_tools_span,
set_trace_processors, set_trace_processors,
set_trace_provider,
set_tracing_disabled, set_tracing_disabled,
set_tracing_export_api_key, set_tracing_export_api_key,
speech_group_span, speech_group_span,
@ -246,6 +247,7 @@ __all__ = [
"guardrail_span", "guardrail_span",
"handoff_span", "handoff_span",
"set_trace_processors", "set_trace_processors",
"set_trace_provider",
"set_tracing_disabled", "set_tracing_disabled",
"speech_group_span", "speech_group_span",
"transcription_span", "transcription_span",

View file

@ -3,12 +3,13 @@ from __future__ import annotations
import asyncio import asyncio
import copy import copy
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, cast from typing import Any, Generic, cast
from openai.types.responses import ResponseCompletedEvent from openai.types.responses import ResponseCompletedEvent
from openai.types.responses.response_prompt_param import ( from openai.types.responses.response_prompt_param import (
ResponsePromptParam, ResponsePromptParam,
) )
from typing_extensions import NotRequired, TypedDict, Unpack
from ._run_impl import ( from ._run_impl import (
AgentToolUseTracker, AgentToolUseTracker,
@ -31,7 +32,12 @@ from .exceptions import (
OutputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered,
RunErrorDetails, RunErrorDetails,
) )
from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult from .guardrail import (
InputGuardrail,
InputGuardrailResult,
OutputGuardrail,
OutputGuardrailResult,
)
from .handoffs import Handoff, HandoffInputFilter, handoff from .handoffs import Handoff, HandoffInputFilter, handoff
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
from .lifecycle import RunHooks from .lifecycle import RunHooks
@ -50,6 +56,27 @@ from .util import _coro, _error_tracing
DEFAULT_MAX_TURNS = 10 DEFAULT_MAX_TURNS = 10
DEFAULT_AGENT_RUNNER: AgentRunner = None # type: ignore
# the value is set at the end of the module
def set_default_agent_runner(runner: AgentRunner | None) -> None:
    """Install the runner that the ``Runner`` entry points delegate to.

    WARNING: this function is experimental and not part of the public API.
    It should not be used directly.

    Args:
        runner: The runner to install. Passing ``None`` installs a freshly
            constructed ``AgentRunner`` instead.
    """
    global DEFAULT_AGENT_RUNNER
    # Compare against None explicitly: a custom runner that happens to be
    # falsy (for example one defining __bool__ or __len__) must not be
    # silently swapped for a new AgentRunner.
    DEFAULT_AGENT_RUNNER = runner if runner is not None else AgentRunner()
def get_default_agent_runner() -> AgentRunner:
    """Return the runner stored in the module-level ``DEFAULT_AGENT_RUNNER``.

    WARNING: this function is experimental and not part of the public API.
    It should not be used directly.
    """
    # Reading a module-level name needs no ``global`` declaration; that
    # statement is only required when rebinding the name.
    return DEFAULT_AGENT_RUNNER
@dataclass @dataclass
class RunConfig: class RunConfig:
@ -110,6 +137,25 @@ class RunConfig:
""" """
class RunOptions(TypedDict, Generic[TContext]):
    """Arguments for ``AgentRunner`` methods.

    Every key is optional (``NotRequired``). ``AgentRunner.run``,
    ``AgentRunner.run_sync`` and ``AgentRunner.run_streamed`` accept these as
    ``**kwargs`` and read them with ``kwargs.get(...)``: an omitted key is
    read as ``None``, except ``max_turns``, which falls back to
    ``DEFAULT_MAX_TURNS``.
    """

    context: NotRequired[TContext | None]
    """The context for the run."""

    max_turns: NotRequired[int]
    """The maximum number of turns to run for."""

    hooks: NotRequired[RunHooks[TContext] | None]
    """Lifecycle hooks for the run."""

    run_config: NotRequired[RunConfig | None]
    """Run configuration."""

    previous_response_id: NotRequired[str | None]
    """The ID of the previous response, if any."""
class Runner: class Runner:
@classmethod @classmethod
async def run( async def run(
@ -130,13 +176,10 @@ class Runner:
`agent.output_type`, the loop terminates. `agent.output_type`, the loop terminates.
3. If there's a handoff, we run the loop again, with the new agent. 3. If there's a handoff, we run the loop again, with the new agent.
4. Else, we run tool calls (if any), and re-run the loop. 4. Else, we run tool calls (if any), and re-run the loop.
In two cases, the agent may raise an exception: In two cases, the agent may raise an exception:
1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised. 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
Note that only the first agent's input guardrails are run. Note that only the first agent's input guardrails are run.
Args: Args:
starting_agent: The starting agent to run. starting_agent: The starting agent to run.
input: The initial input to the agent. You can pass a single string for a user message, input: The initial input to the agent. You can pass a single string for a user message,
@ -148,11 +191,139 @@ class Runner:
run_config: Global settings for the entire agent run. run_config: Global settings for the entire agent run.
previous_response_id: The ID of the previous response, if using OpenAI models via the previous_response_id: The ID of the previous response, if using OpenAI models via the
Responses API, this allows you to skip passing in input from the previous turn. Responses API, this allows you to skip passing in input from the previous turn.
Returns: Returns:
A run result containing all the inputs, guardrail results and the output of the last A run result containing all the inputs, guardrail results and the output of the last
agent. Agents may perform handoffs, so we don't know the specific type of the output. agent. Agents may perform handoffs, so we don't know the specific type of the output.
""" """
runner = DEFAULT_AGENT_RUNNER
return await runner.run(
starting_agent,
input,
context=context,
max_turns=max_turns,
hooks=hooks,
run_config=run_config,
previous_response_id=previous_response_id,
)
@classmethod
def run_sync(
    cls,
    starting_agent: Agent[TContext],
    input: str | list[TResponseInputItem],
    *,
    context: TContext | None = None,
    max_turns: int = DEFAULT_MAX_TURNS,
    hooks: RunHooks[TContext] | None = None,
    run_config: RunConfig | None = None,
    previous_response_id: str | None = None,
) -> RunResult:
    """Blocking entry point: run a workflow from ``starting_agent`` to completion.

    This is a synchronous wrapper around `run`. It will not work when an
    event loop is already running (inside an async function, a Jupyter
    notebook, or an async context such as FastAPI); use `run` in those cases.

    The call is forwarded unchanged to ``DEFAULT_AGENT_RUNNER.run_sync``.

    The agent loop repeats until a final output exists:

    1. Invoke the agent with the given input.
    2. If it produces a final output (something of type
       `agent.output_type`), stop.
    3. If it hands off, repeat the loop with the new agent.
    4. Otherwise run any tool calls and repeat the loop.

    Only the first agent's input guardrails are run.

    Args:
        starting_agent: The starting agent to run.
        input: The initial input to the agent. Either a single string, which
            is treated as a user message, or a list of input items.
        context: The context to run the agent with.
        max_turns: The maximum number of turns to run the agent for. A turn
            is one AI invocation (including any tool calls that might occur).
        hooks: An object that receives callbacks on various lifecycle events.
        run_config: Global settings for the entire agent run.
        previous_response_id: The ID of the previous response. When using
            OpenAI models via the Responses API, this lets you skip passing
            in input from the previous turn.

    Returns:
        A run result containing all the inputs, guardrail results and the
        output of the last agent. Agents may perform handoffs, so the
        specific type of the output is not known.

    Raises:
        MaxTurnsExceeded: If ``max_turns`` is exceeded.
        GuardrailTripwireTriggered: If a guardrail tripwire is triggered.
    """
    # Collect the keyword-only options once, then hand them to the runner.
    options: RunOptions[TContext] = {
        "context": context,
        "max_turns": max_turns,
        "hooks": hooks,
        "run_config": run_config,
        "previous_response_id": previous_response_id,
    }
    return DEFAULT_AGENT_RUNNER.run_sync(starting_agent, input, **options)
@classmethod
def run_streamed(
    cls,
    starting_agent: Agent[TContext],
    input: str | list[TResponseInputItem],
    context: TContext | None = None,
    max_turns: int = DEFAULT_MAX_TURNS,
    hooks: RunHooks[TContext] | None = None,
    run_config: RunConfig | None = None,
    previous_response_id: str | None = None,
) -> RunResultStreaming:
    """Streaming entry point: start a workflow at ``starting_agent``.

    The returned result object exposes a method for streaming semantic
    events as they are generated. The call is forwarded unchanged to
    ``DEFAULT_AGENT_RUNNER.run_streamed``.

    The agent loop repeats until a final output exists:

    1. Invoke the agent with the given input.
    2. If it produces a final output (something of type
       `agent.output_type`), stop.
    3. If it hands off, repeat the loop with the new agent.
    4. Otherwise run any tool calls and repeat the loop.

    Two conditions raise: exceeding ``max_turns`` (``MaxTurnsExceeded``) and
    a triggered guardrail tripwire (``GuardrailTripwireTriggered``).
    Only the first agent's input guardrails are run.

    Args:
        starting_agent: The starting agent to run.
        input: The initial input to the agent. Either a single string, which
            is treated as a user message, or a list of input items.
        context: The context to run the agent with.
        max_turns: The maximum number of turns to run the agent for. A turn
            is one AI invocation (including any tool calls that might occur).
        hooks: An object that receives callbacks on various lifecycle events.
        run_config: Global settings for the entire agent run.
        previous_response_id: The ID of the previous response. When using
            OpenAI models via the Responses API, this lets you skip passing
            in input from the previous turn.

    Returns:
        A result object that contains data about the run, as well as a
        method to stream events.
    """
    # Collect the options once, then hand them to the runner as keywords.
    options: RunOptions[TContext] = {
        "context": context,
        "max_turns": max_turns,
        "hooks": hooks,
        "run_config": run_config,
        "previous_response_id": previous_response_id,
    }
    return DEFAULT_AGENT_RUNNER.run_streamed(starting_agent, input, **options)
class AgentRunner:
"""
WARNING: this class is experimental and not part of the public API
It should not be used directly or subclassed.
"""
async def run(
self,
starting_agent: Agent[TContext],
input: str | list[TResponseInputItem],
**kwargs: Unpack[RunOptions[TContext]],
) -> RunResult:
context = kwargs.get("context")
max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
hooks = kwargs.get("hooks")
run_config = kwargs.get("run_config")
previous_response_id = kwargs.get("previous_response_id")
if hooks is None: if hooks is None:
hooks = RunHooks[Any]() hooks = RunHooks[Any]()
if run_config is None: if run_config is None:
@ -184,13 +355,15 @@ class Runner:
try: try:
while True: while True:
all_tools = await cls._get_all_tools(current_agent, context_wrapper) all_tools = await AgentRunner._get_all_tools(current_agent, context_wrapper)
# Start an agent span if we don't have one. This span is ended if the current # Start an agent span if we don't have one. This span is ended if the current
# agent changes, or if the agent loop ends. # agent changes, or if the agent loop ends.
if current_span is None: if current_span is None:
handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)] handoff_names = [
if output_schema := cls._get_output_schema(current_agent): h.agent_name for h in AgentRunner._get_handoffs(current_agent)
]
if output_schema := AgentRunner._get_output_schema(current_agent):
output_type_name = output_schema.name() output_type_name = output_schema.name()
else: else:
output_type_name = "str" output_type_name = "str"
@ -220,14 +393,14 @@ class Runner:
if current_turn == 1: if current_turn == 1:
input_guardrail_results, turn_result = await asyncio.gather( input_guardrail_results, turn_result = await asyncio.gather(
cls._run_input_guardrails( self._run_input_guardrails(
starting_agent, starting_agent,
starting_agent.input_guardrails starting_agent.input_guardrails
+ (run_config.input_guardrails or []), + (run_config.input_guardrails or []),
copy.deepcopy(input), copy.deepcopy(input),
context_wrapper, context_wrapper,
), ),
cls._run_single_turn( self._run_single_turn(
agent=current_agent, agent=current_agent,
all_tools=all_tools, all_tools=all_tools,
original_input=original_input, original_input=original_input,
@ -241,7 +414,7 @@ class Runner:
), ),
) )
else: else:
turn_result = await cls._run_single_turn( turn_result = await self._run_single_turn(
agent=current_agent, agent=current_agent,
all_tools=all_tools, all_tools=all_tools,
original_input=original_input, original_input=original_input,
@ -260,7 +433,7 @@ class Runner:
generated_items = turn_result.generated_items generated_items = turn_result.generated_items
if isinstance(turn_result.next_step, NextStepFinalOutput): if isinstance(turn_result.next_step, NextStepFinalOutput):
output_guardrail_results = await cls._run_output_guardrails( output_guardrail_results = await self._run_output_guardrails(
current_agent.output_guardrails + (run_config.output_guardrails or []), current_agent.output_guardrails + (run_config.output_guardrails or []),
current_agent, current_agent,
turn_result.next_step.output, turn_result.next_step.output,
@ -302,54 +475,19 @@ class Runner:
if current_span: if current_span:
current_span.finish(reset_current=True) current_span.finish(reset_current=True)
@classmethod
def run_sync( def run_sync(
cls, self,
starting_agent: Agent[TContext], starting_agent: Agent[TContext],
input: str | list[TResponseInputItem], input: str | list[TResponseInputItem],
*, **kwargs: Unpack[RunOptions[TContext]],
context: TContext | None = None,
max_turns: int = DEFAULT_MAX_TURNS,
hooks: RunHooks[TContext] | None = None,
run_config: RunConfig | None = None,
previous_response_id: str | None = None,
) -> RunResult: ) -> RunResult:
"""Run a workflow synchronously, starting at the given agent. Note that this just wraps the context = kwargs.get("context")
`run` method, so it will not work if there's already an event loop (e.g. inside an async max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
function, or in a Jupyter notebook or async context like FastAPI). For those cases, use hooks = kwargs.get("hooks")
the `run` method instead. run_config = kwargs.get("run_config")
previous_response_id = kwargs.get("previous_response_id")
The agent will run in a loop until a final output is generated. The loop runs like so:
1. The agent is invoked with the given input.
2. If there is a final output (i.e. the agent produces something of type
`agent.output_type`, the loop terminates.
3. If there's a handoff, we run the loop again, with the new agent.
4. Else, we run tool calls (if any), and re-run the loop.
In two cases, the agent may raise an exception:
1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
Note that only the first agent's input guardrails are run.
Args:
starting_agent: The starting agent to run.
input: The initial input to the agent. You can pass a single string for a user message,
or a list of input items.
context: The context to run the agent with.
max_turns: The maximum number of turns to run the agent for. A turn is defined as one
AI invocation (including any tool calls that might occur).
hooks: An object that receives callbacks on various lifecycle events.
run_config: Global settings for the entire agent run.
previous_response_id: The ID of the previous response, if using OpenAI models via the
Responses API, this allows you to skip passing in input from the previous turn.
Returns:
A run result containing all the inputs, guardrail results and the output of the last
agent. Agents may perform handoffs, so we don't know the specific type of the output.
"""
return asyncio.get_event_loop().run_until_complete( return asyncio.get_event_loop().run_until_complete(
cls.run( self.run(
starting_agent, starting_agent,
input, input,
context=context, context=context,
@ -360,47 +498,17 @@ class Runner:
) )
) )
@classmethod
def run_streamed( def run_streamed(
cls, self,
starting_agent: Agent[TContext], starting_agent: Agent[TContext],
input: str | list[TResponseInputItem], input: str | list[TResponseInputItem],
context: TContext | None = None, **kwargs: Unpack[RunOptions[TContext]],
max_turns: int = DEFAULT_MAX_TURNS,
hooks: RunHooks[TContext] | None = None,
run_config: RunConfig | None = None,
previous_response_id: str | None = None,
) -> RunResultStreaming: ) -> RunResultStreaming:
"""Run a workflow starting at the given agent in streaming mode. The returned result object context = kwargs.get("context")
contains a method you can use to stream semantic events as they are generated. max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
hooks = kwargs.get("hooks")
The agent will run in a loop until a final output is generated. The loop runs like so: run_config = kwargs.get("run_config")
1. The agent is invoked with the given input. previous_response_id = kwargs.get("previous_response_id")
2. If there is a final output (i.e. the agent produces something of type
`agent.output_type`, the loop terminates.
3. If there's a handoff, we run the loop again, with the new agent.
4. Else, we run tool calls (if any), and re-run the loop.
In two cases, the agent may raise an exception:
1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
Note that only the first agent's input guardrails are run.
Args:
starting_agent: The starting agent to run.
input: The initial input to the agent. You can pass a single string for a user message,
or a list of input items.
context: The context to run the agent with.
max_turns: The maximum number of turns to run the agent for. A turn is defined as one
AI invocation (including any tool calls that might occur).
hooks: An object that receives callbacks on various lifecycle events.
run_config: Global settings for the entire agent run.
previous_response_id: The ID of the previous response, if using OpenAI models via the
Responses API, this allows you to skip passing in input from the previous turn.
Returns:
A result object that contains data about the run, as well as a method to stream events.
"""
if hooks is None: if hooks is None:
hooks = RunHooks[Any]() hooks = RunHooks[Any]()
if run_config is None: if run_config is None:
@ -421,7 +529,7 @@ class Runner:
) )
) )
output_schema = cls._get_output_schema(starting_agent) output_schema = AgentRunner._get_output_schema(starting_agent)
context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( context_wrapper: RunContextWrapper[TContext] = RunContextWrapper(
context=context # type: ignore context=context # type: ignore
) )
@ -444,7 +552,7 @@ class Runner:
# Kick off the actual agent loop in the background and return the streamed result object. # Kick off the actual agent loop in the background and return the streamed result object.
streamed_result._run_impl_task = asyncio.create_task( streamed_result._run_impl_task = asyncio.create_task(
cls._run_streamed_impl( self._start_streaming(
starting_input=input, starting_input=input,
streamed_result=streamed_result, streamed_result=streamed_result,
starting_agent=starting_agent, starting_agent=starting_agent,
@ -501,7 +609,7 @@ class Runner:
streamed_result.input_guardrail_results = guardrail_results streamed_result.input_guardrail_results = guardrail_results
@classmethod @classmethod
async def _run_streamed_impl( async def _start_streaming(
cls, cls,
starting_input: str | list[TResponseInputItem], starting_input: str | list[TResponseInputItem],
streamed_result: RunResultStreaming, streamed_result: RunResultStreaming,
@ -1008,3 +1116,6 @@ class Runner:
return agent.model return agent.model
return run_config.model_provider.get_model(agent.model) return run_config.model_provider.get_model(agent.model)
DEFAULT_AGENT_RUNNER = AgentRunner()

View file

@ -1,5 +1,7 @@
import atexit import atexit
from agents.tracing.provider import DefaultTraceProvider, TraceProvider
from .create import ( from .create import (
agent_span, agent_span,
custom_span, custom_span,
@ -18,7 +20,7 @@ from .create import (
) )
from .processor_interface import TracingProcessor from .processor_interface import TracingProcessor
from .processors import default_exporter, default_processor from .processors import default_exporter, default_processor
from .setup import GLOBAL_TRACE_PROVIDER from .setup import get_trace_provider, set_trace_provider
from .span_data import ( from .span_data import (
AgentSpanData, AgentSpanData,
CustomSpanData, CustomSpanData,
@ -45,10 +47,12 @@ __all__ = [
"generation_span", "generation_span",
"get_current_span", "get_current_span",
"get_current_trace", "get_current_trace",
"get_trace_provider",
"guardrail_span", "guardrail_span",
"handoff_span", "handoff_span",
"response_span", "response_span",
"set_trace_processors", "set_trace_processors",
"set_trace_provider",
"set_tracing_disabled", "set_tracing_disabled",
"trace", "trace",
"Trace", "Trace",
@ -67,6 +71,7 @@ __all__ = [
"SpeechSpanData", "SpeechSpanData",
"TranscriptionSpanData", "TranscriptionSpanData",
"TracingProcessor", "TracingProcessor",
"TraceProvider",
"gen_trace_id", "gen_trace_id",
"gen_span_id", "gen_span_id",
"speech_group_span", "speech_group_span",
@ -80,21 +85,21 @@ def add_trace_processor(span_processor: TracingProcessor) -> None:
""" """
Adds a new trace processor. This processor will receive all traces/spans. Adds a new trace processor. This processor will receive all traces/spans.
""" """
GLOBAL_TRACE_PROVIDER.register_processor(span_processor) get_trace_provider().register_processor(span_processor)
def set_trace_processors(processors: list[TracingProcessor]) -> None: def set_trace_processors(processors: list[TracingProcessor]) -> None:
""" """
Set the list of trace processors. This will replace the current list of processors. Set the list of trace processors. This will replace the current list of processors.
""" """
GLOBAL_TRACE_PROVIDER.set_processors(processors) get_trace_provider().set_processors(processors)
def set_tracing_disabled(disabled: bool) -> None: def set_tracing_disabled(disabled: bool) -> None:
""" """
Set whether tracing is globally disabled. Set whether tracing is globally disabled.
""" """
GLOBAL_TRACE_PROVIDER.set_disabled(disabled) get_trace_provider().set_disabled(disabled)
def set_tracing_export_api_key(api_key: str) -> None: def set_tracing_export_api_key(api_key: str) -> None:
@ -104,10 +109,11 @@ def set_tracing_export_api_key(api_key: str) -> None:
default_exporter().set_api_key(api_key) default_exporter().set_api_key(api_key)
set_trace_provider(DefaultTraceProvider())
# Add the default processor, which exports traces and spans to the backend in batches. You can # Add the default processor, which exports traces and spans to the backend in batches. You can
# change the default behavior by either: # change the default behavior by either:
# 1. calling add_trace_processor(), which adds additional processors, or # 1. calling add_trace_processor(), which adds additional processors, or
# 2. calling set_trace_processors(), which replaces the default processor. # 2. calling set_trace_processors(), which replaces the default processor.
add_trace_processor(default_processor()) add_trace_processor(default_processor())
atexit.register(GLOBAL_TRACE_PROVIDER.shutdown) atexit.register(get_trace_provider().shutdown)

View file

@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from ..logger import logger from ..logger import logger
from .setup import GLOBAL_TRACE_PROVIDER from .setup import get_trace_provider
from .span_data import ( from .span_data import (
AgentSpanData, AgentSpanData,
CustomSpanData, CustomSpanData,
@ -56,13 +56,13 @@ def trace(
Returns: Returns:
The newly created trace object. The newly created trace object.
""" """
current_trace = GLOBAL_TRACE_PROVIDER.get_current_trace() current_trace = get_trace_provider().get_current_trace()
if current_trace: if current_trace:
logger.warning( logger.warning(
"Trace already exists. Creating a new trace, but this is probably a mistake." "Trace already exists. Creating a new trace, but this is probably a mistake."
) )
return GLOBAL_TRACE_PROVIDER.create_trace( return get_trace_provider().create_trace(
name=workflow_name, name=workflow_name,
trace_id=trace_id, trace_id=trace_id,
group_id=group_id, group_id=group_id,
@ -73,12 +73,12 @@ def trace(
def get_current_trace() -> Trace | None: def get_current_trace() -> Trace | None:
"""Returns the currently active trace, if present.""" """Returns the currently active trace, if present."""
return GLOBAL_TRACE_PROVIDER.get_current_trace() return get_trace_provider().get_current_trace()
def get_current_span() -> Span[Any] | None: def get_current_span() -> Span[Any] | None:
"""Returns the currently active span, if present.""" """Returns the currently active span, if present."""
return GLOBAL_TRACE_PROVIDER.get_current_span() return get_trace_provider().get_current_span()
def agent_span( def agent_span(
@ -108,7 +108,7 @@ def agent_span(
Returns: Returns:
The newly created agent span. The newly created agent span.
""" """
return GLOBAL_TRACE_PROVIDER.create_span( return get_trace_provider().create_span(
span_data=AgentSpanData(name=name, handoffs=handoffs, tools=tools, output_type=output_type), span_data=AgentSpanData(name=name, handoffs=handoffs, tools=tools, output_type=output_type),
span_id=span_id, span_id=span_id,
parent=parent, parent=parent,
@ -141,7 +141,7 @@ def function_span(
Returns: Returns:
The newly created function span. The newly created function span.
""" """
return GLOBAL_TRACE_PROVIDER.create_span( return get_trace_provider().create_span(
span_data=FunctionSpanData(name=name, input=input, output=output), span_data=FunctionSpanData(name=name, input=input, output=output),
span_id=span_id, span_id=span_id,
parent=parent, parent=parent,
@ -183,7 +183,7 @@ def generation_span(
Returns: Returns:
The newly created generation span. The newly created generation span.
""" """
return GLOBAL_TRACE_PROVIDER.create_span( return get_trace_provider().create_span(
span_data=GenerationSpanData( span_data=GenerationSpanData(
input=input, input=input,
output=output, output=output,
@ -215,7 +215,7 @@ def response_span(
trace/span as the parent. trace/span as the parent.
disabled: If True, we will return a Span but the Span will not be recorded. disabled: If True, we will return a Span but the Span will not be recorded.
""" """
return GLOBAL_TRACE_PROVIDER.create_span( return get_trace_provider().create_span(
span_data=ResponseSpanData(response=response), span_data=ResponseSpanData(response=response),
span_id=span_id, span_id=span_id,
parent=parent, parent=parent,
@ -246,7 +246,7 @@ def handoff_span(
Returns: Returns:
The newly created handoff span. The newly created handoff span.
""" """
return GLOBAL_TRACE_PROVIDER.create_span( return get_trace_provider().create_span(
span_data=HandoffSpanData(from_agent=from_agent, to_agent=to_agent), span_data=HandoffSpanData(from_agent=from_agent, to_agent=to_agent),
span_id=span_id, span_id=span_id,
parent=parent, parent=parent,
@ -278,7 +278,7 @@ def custom_span(
Returns: Returns:
The newly created custom span. The newly created custom span.
""" """
return GLOBAL_TRACE_PROVIDER.create_span( return get_trace_provider().create_span(
span_data=CustomSpanData(name=name, data=data or {}), span_data=CustomSpanData(name=name, data=data or {}),
span_id=span_id, span_id=span_id,
parent=parent, parent=parent,
@ -306,7 +306,7 @@ def guardrail_span(
trace/span as the parent. trace/span as the parent.
disabled: If True, we will return a Span but the Span will not be recorded. disabled: If True, we will return a Span but the Span will not be recorded.
""" """
return GLOBAL_TRACE_PROVIDER.create_span( return get_trace_provider().create_span(
span_data=GuardrailSpanData(name=name, triggered=triggered), span_data=GuardrailSpanData(name=name, triggered=triggered),
span_id=span_id, span_id=span_id,
parent=parent, parent=parent,
@ -344,7 +344,7 @@ def transcription_span(
Returns: Returns:
The newly created speech-to-text span. The newly created speech-to-text span.
""" """
return GLOBAL_TRACE_PROVIDER.create_span( return get_trace_provider().create_span(
span_data=TranscriptionSpanData( span_data=TranscriptionSpanData(
input=input, input=input,
input_format=input_format, input_format=input_format,
@ -386,7 +386,7 @@ def speech_span(
trace/span as the parent. trace/span as the parent.
disabled: If True, we will return a Span but the Span will not be recorded. disabled: If True, we will return a Span but the Span will not be recorded.
""" """
return GLOBAL_TRACE_PROVIDER.create_span( return get_trace_provider().create_span(
span_data=SpeechSpanData( span_data=SpeechSpanData(
model=model, model=model,
input=input, input=input,
@ -419,7 +419,7 @@ def speech_group_span(
trace/span as the parent. trace/span as the parent.
disabled: If True, we will return a Span but the Span will not be recorded. disabled: If True, we will return a Span but the Span will not be recorded.
""" """
return GLOBAL_TRACE_PROVIDER.create_span( return get_trace_provider().create_span(
span_data=SpeechGroupSpanData(input=input), span_data=SpeechGroupSpanData(input=input),
span_id=span_id, span_id=span_id,
parent=parent, parent=parent,
@ -447,7 +447,7 @@ def mcp_tools_span(
trace/span as the parent. trace/span as the parent.
disabled: If True, we will return a Span but the Span will not be recorded. disabled: If True, we will return a Span but the Span will not be recorded.
""" """
return GLOBAL_TRACE_PROVIDER.create_span( return get_trace_provider().create_span(
span_data=MCPListToolsSpanData(server=server, result=result), span_data=MCPListToolsSpanData(server=server, result=result),
span_id=span_id, span_id=span_id,
parent=parent, parent=parent,

View file

@ -0,0 +1,294 @@
from __future__ import annotations
import os
import threading
import uuid
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import Any
from ..logger import logger
from .processor_interface import TracingProcessor
from .scope import Scope
from .spans import NoOpSpan, Span, SpanImpl, TSpanData
from .traces import NoOpTrace, Trace, TraceImpl
class SynchronousMultiTracingProcessor(TracingProcessor):
    """
    Fan-out processor: relays every callback to each registered
    TracingProcessor, in the order the processors were registered.
    """

    def __init__(self):
        self._lock = threading.Lock()
        # Stored as a tuple that is rebound (never mutated) on change, to
        # avoid race conditions when iterating over processors.
        self._processors: tuple[TracingProcessor, ...] = ()

    def add_tracing_processor(self, tracing_processor: TracingProcessor):
        """
        Append one processor. It will receive all traces/spans from now on.
        """
        with self._lock:
            self._processors = (*self._processors, tracing_processor)

    def set_processors(self, processors: list[TracingProcessor]):
        """
        Replace the current processors with the given list.
        """
        with self._lock:
            replacement = tuple(processors)
            self._processors = replacement

    def on_trace_start(self, trace: Trace) -> None:
        """
        Relay the start of a trace to every processor.
        """
        registered = self._processors
        for proc in registered:
            proc.on_trace_start(trace)

    def on_trace_end(self, trace: Trace) -> None:
        """
        Relay the end of a trace to every processor.
        """
        registered = self._processors
        for proc in registered:
            proc.on_trace_end(trace)

    def on_span_start(self, span: Span[Any]) -> None:
        """
        Relay the start of a span to every processor.
        """
        registered = self._processors
        for proc in registered:
            proc.on_span_start(span)

    def on_span_end(self, span: Span[Any]) -> None:
        """
        Relay the end of a span to every processor.
        """
        registered = self._processors
        for proc in registered:
            proc.on_span_end(span)

    def shutdown(self) -> None:
        """
        Shut down every processor, logging each one at debug level first.
        """
        registered = self._processors
        for processor in registered:
            logger.debug(f"Shutting down trace processor {processor}")
            processor.shutdown()

    def force_flush(self):
        """
        Ask every processor to flush its buffers.
        """
        registered = self._processors
        for proc in registered:
            proc.force_flush()
class TraceProvider(ABC):
    """Abstract contract for the object that creates traces and spans.

    Every method is abstract; a concrete provider must implement all of them
    before it can be instantiated.
    """

    @abstractmethod
    def register_processor(self, processor: TracingProcessor) -> None:
        """Register ``processor`` so that it receives all traces and spans."""

    @abstractmethod
    def set_processors(self, processors: list[TracingProcessor]) -> None:
        """Swap the current processors for exactly those in ``processors``."""

    @abstractmethod
    def get_current_trace(self) -> Trace | None:
        """Return the trace that is active right now, or ``None`` if there is none."""

    @abstractmethod
    def get_current_span(self) -> Span[Any] | None:
        """Return the span that is active right now, or ``None`` if there is none."""

    @abstractmethod
    def set_disabled(self, disabled: bool) -> None:
        """Turn tracing off (``True``) or on (``False``) globally."""

    @abstractmethod
    def time_iso(self) -> str:
        """Return the current time as an ISO 8601 string."""

    @abstractmethod
    def gen_trace_id(self) -> str:
        """Produce a fresh identifier for a trace."""

    @abstractmethod
    def gen_span_id(self) -> str:
        """Produce a fresh identifier for a span."""

    @abstractmethod
    def gen_group_id(self) -> str:
        """Produce a fresh identifier for a group."""

    @abstractmethod
    def create_trace(
        self,
        name: str,
        trace_id: str | None = None,
        group_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        disabled: bool = False,
    ) -> Trace:
        """Build and return a new trace."""

    @abstractmethod
    def create_span(
        self,
        span_data: TSpanData,
        span_id: str | None = None,
        parent: Trace | Span[Any] | None = None,
        disabled: bool = False,
    ) -> Span[TSpanData]:
        """Build and return a new span."""

    @abstractmethod
    def shutdown(self) -> None:
        """Release whatever resources the provider is holding."""
class DefaultTraceProvider(TraceProvider):
    """Default ``TraceProvider``: builds real traces/spans wired to a multi-processor.

    Tracing starts out disabled when the ``OPENAI_AGENTS_DISABLE_TRACING`` environment
    variable is set to ``"true"`` or ``"1"`` (case-insensitive); the flag can be changed
    later with ``set_disabled``.
    """

    def __init__(self) -> None:
        # Every trace/span created by this provider reports to this single fan-out processor.
        self._multi_processor = SynchronousMultiTracingProcessor()
        self._disabled = os.environ.get("OPENAI_AGENTS_DISABLE_TRACING", "false").lower() in (
            "true",
            "1",
        )

    def register_processor(self, processor: TracingProcessor):
        """
        Add a processor to the list of processors. Each processor will receive all traces/spans.
        """
        self._multi_processor.add_tracing_processor(processor)

    def set_processors(self, processors: list[TracingProcessor]):
        """
        Set the list of processors. This will replace the current list of processors.
        """
        self._multi_processor.set_processors(processors)

    def get_current_trace(self) -> Trace | None:
        """
        Returns the currently active trace, if any.
        """
        return Scope.get_current_trace()

    def get_current_span(self) -> Span[Any] | None:
        """
        Returns the currently active span, if any.
        """
        return Scope.get_current_span()

    def set_disabled(self, disabled: bool) -> None:
        """
        Set whether tracing is disabled.
        """
        self._disabled = disabled

    def time_iso(self) -> str:
        """Return the current UTC time in ISO 8601 format."""
        return datetime.now(timezone.utc).isoformat()

    def gen_trace_id(self) -> str:
        """Generate a new trace ID: ``trace_`` followed by 32 hex characters."""
        return f"trace_{uuid.uuid4().hex}"

    def gen_span_id(self) -> str:
        """Generate a new span ID: ``span_`` followed by 24 hex characters."""
        return f"span_{uuid.uuid4().hex[:24]}"

    def gen_group_id(self) -> str:
        """Generate a new group ID: ``group_`` followed by 24 hex characters."""
        return f"group_{uuid.uuid4().hex[:24]}"

    def create_trace(
        self,
        name: str,
        trace_id: str | None = None,
        group_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        disabled: bool = False,
    ) -> Trace:
        """
        Create a new trace.

        Args:
            name: Name of the trace.
            trace_id: ID to use; a new one is generated when omitted or empty.
            group_id: Optional group ID, passed through to the trace.
            metadata: Optional metadata, passed through to the trace.
            disabled: When true, a ``NoOpTrace`` is returned for this call only.

        Returns:
            A ``NoOpTrace`` if tracing is disabled (globally or via ``disabled``),
            otherwise a ``TraceImpl`` reporting to this provider's processors.
        """
        if self._disabled or disabled:
            logger.debug(f"Tracing is disabled. Not creating trace {name}")
            return NoOpTrace()

        trace_id = trace_id or self.gen_trace_id()

        logger.debug(f"Creating trace {name} with id {trace_id}")

        return TraceImpl(
            name=name,
            trace_id=trace_id,
            group_id=group_id,
            metadata=metadata,
            processor=self._multi_processor,
        )

    def create_span(
        self,
        span_data: TSpanData,
        span_id: str | None = None,
        parent: Trace | Span[Any] | None = None,
        disabled: bool = False,
    ) -> Span[TSpanData]:
        """
        Create a new span.

        Args:
            span_data: Payload describing the span.
            span_id: ID to use; a new one is generated when omitted or empty.
            parent: Trace or span to attach to. When not given, the current trace and
                current span from ``Scope`` are used.
            disabled: When true, a ``NoOpSpan`` is returned for this call only.

        Returns:
            A ``NoOpSpan`` if tracing is disabled, if there is no active trace, or if the
            parent (explicit or current) is itself a no-op; otherwise a ``SpanImpl``
            reporting to this provider's processors.
        """
        if self._disabled or disabled:
            logger.debug(f"Tracing is disabled. Not creating span {span_data}")
            return NoOpSpan(span_data)

        if not parent:
            # No explicit parent: attach to whatever is currently active in this scope.
            current_span = Scope.get_current_span()
            current_trace = Scope.get_current_trace()
            if current_trace is None:
                # The two literals are joined by the compiler, so the separator must be
                # spelled out inside the first one.
                logger.error(
                    "No active trace. Make sure to start a trace with `trace()` first. "
                    "Returning NoOpSpan."
                )
                return NoOpSpan(span_data)
            elif isinstance(current_trace, NoOpTrace) or isinstance(current_span, NoOpSpan):
                logger.debug(
                    f"Parent {current_span} or {current_trace} is no-op, returning NoOpSpan"
                )
                return NoOpSpan(span_data)

            parent_id = current_span.span_id if current_span else None
            trace_id = current_trace.trace_id

        elif isinstance(parent, Trace):
            if isinstance(parent, NoOpTrace):
                logger.debug(f"Parent {parent} is no-op, returning NoOpSpan")
                return NoOpSpan(span_data)
            # A span created directly under a trace has no parent span.
            trace_id = parent.trace_id
            parent_id = None
        elif isinstance(parent, Span):
            if isinstance(parent, NoOpSpan):
                logger.debug(f"Parent {parent} is no-op, returning NoOpSpan")
                return NoOpSpan(span_data)
            parent_id = parent.span_id
            trace_id = parent.trace_id

        # Resolve the ID before logging so the log shows the ID the span really gets,
        # not ``None`` when the caller left it to be generated.
        span_id = span_id or self.gen_span_id()

        logger.debug(f"Creating span {span_data} with id {span_id}")

        return SpanImpl(
            trace_id=trace_id,
            span_id=span_id,
            parent_id=parent_id,
            processor=self._multi_processor,
            span_data=span_data,
        )

    def shutdown(self) -> None:
        """
        Shut down the registered processors.

        Does nothing when tracing is disabled. Any exception raised while shutting down
        is logged as an error instead of being propagated.
        """
        if self._disabled:
            return

        try:
            logger.debug("Shutting down trace provider")
            self._multi_processor.shutdown()
        except Exception as e:
            logger.error(f"Error shutting down trace provider: {e}")

View file

@ -1,214 +1,21 @@
from __future__ import annotations from __future__ import annotations
import os from typing import TYPE_CHECKING
import threading
from typing import Any
from ..logger import logger if TYPE_CHECKING:
from . import util from .provider import TraceProvider
from .processor_interface import TracingProcessor
from .scope import Scope GLOBAL_TRACE_PROVIDER: TraceProvider | None = None
from .spans import NoOpSpan, Span, SpanImpl, TSpanData
from .traces import NoOpTrace, Trace, TraceImpl
class SynchronousMultiTracingProcessor(TracingProcessor): def set_trace_provider(provider: TraceProvider) -> None:
""" """Set the global trace provider used by tracing utilities."""
Forwards all calls to a list of TracingProcessors, in order of registration. global GLOBAL_TRACE_PROVIDER
""" GLOBAL_TRACE_PROVIDER = provider
def __init__(self):
# Using a tuple to avoid race conditions when iterating over processors
self._processors: tuple[TracingProcessor, ...] = ()
self._lock = threading.Lock()
def add_tracing_processor(self, tracing_processor: TracingProcessor):
"""
Add a processor to the list of processors. Each processor will receive all traces/spans.
"""
with self._lock:
self._processors += (tracing_processor,)
def set_processors(self, processors: list[TracingProcessor]):
"""
Set the list of processors. This will replace the current list of processors.
"""
with self._lock:
self._processors = tuple(processors)
def on_trace_start(self, trace: Trace) -> None:
"""
Called when a trace is started.
"""
for processor in self._processors:
processor.on_trace_start(trace)
def on_trace_end(self, trace: Trace) -> None:
"""
Called when a trace is finished.
"""
for processor in self._processors:
processor.on_trace_end(trace)
def on_span_start(self, span: Span[Any]) -> None:
"""
Called when a span is started.
"""
for processor in self._processors:
processor.on_span_start(span)
def on_span_end(self, span: Span[Any]) -> None:
"""
Called when a span is finished.
"""
for processor in self._processors:
processor.on_span_end(span)
def shutdown(self) -> None:
"""
Called when the application stops.
"""
for processor in self._processors:
logger.debug(f"Shutting down trace processor {processor}")
processor.shutdown()
def force_flush(self):
"""
Force the processors to flush their buffers.
"""
for processor in self._processors:
processor.force_flush()
class TraceProvider: def get_trace_provider() -> TraceProvider:
def __init__(self): """Get the global trace provider used by tracing utilities."""
self._multi_processor = SynchronousMultiTracingProcessor() if GLOBAL_TRACE_PROVIDER is None:
self._disabled = os.environ.get("OPENAI_AGENTS_DISABLE_TRACING", "false").lower() in ( raise RuntimeError("Trace provider not set")
"true", return GLOBAL_TRACE_PROVIDER
"1",
)
def register_processor(self, processor: TracingProcessor):
"""
Add a processor to the list of processors. Each processor will receive all traces/spans.
"""
self._multi_processor.add_tracing_processor(processor)
def set_processors(self, processors: list[TracingProcessor]):
"""
Set the list of processors. This will replace the current list of processors.
"""
self._multi_processor.set_processors(processors)
def get_current_trace(self) -> Trace | None:
"""
Returns the currently active trace, if any.
"""
return Scope.get_current_trace()
def get_current_span(self) -> Span[Any] | None:
"""
Returns the currently active span, if any.
"""
return Scope.get_current_span()
def set_disabled(self, disabled: bool) -> None:
"""
Set whether tracing is disabled.
"""
self._disabled = disabled
def create_trace(
self,
name: str,
trace_id: str | None = None,
group_id: str | None = None,
metadata: dict[str, Any] | None = None,
disabled: bool = False,
) -> Trace:
"""
Create a new trace.
"""
if self._disabled or disabled:
logger.debug(f"Tracing is disabled. Not creating trace {name}")
return NoOpTrace()
trace_id = trace_id or util.gen_trace_id()
logger.debug(f"Creating trace {name} with id {trace_id}")
return TraceImpl(
name=name,
trace_id=trace_id,
group_id=group_id,
metadata=metadata,
processor=self._multi_processor,
)
def create_span(
self,
span_data: TSpanData,
span_id: str | None = None,
parent: Trace | Span[Any] | None = None,
disabled: bool = False,
) -> Span[TSpanData]:
"""
Create a new span.
"""
if self._disabled or disabled:
logger.debug(f"Tracing is disabled. Not creating span {span_data}")
return NoOpSpan(span_data)
if not parent:
current_span = Scope.get_current_span()
current_trace = Scope.get_current_trace()
if current_trace is None:
logger.error(
"No active trace. Make sure to start a trace with `trace()` first"
"Returning NoOpSpan."
)
return NoOpSpan(span_data)
elif isinstance(current_trace, NoOpTrace) or isinstance(current_span, NoOpSpan):
logger.debug(
f"Parent {current_span} or {current_trace} is no-op, returning NoOpSpan"
)
return NoOpSpan(span_data)
parent_id = current_span.span_id if current_span else None
trace_id = current_trace.trace_id
elif isinstance(parent, Trace):
if isinstance(parent, NoOpTrace):
logger.debug(f"Parent {parent} is no-op, returning NoOpSpan")
return NoOpSpan(span_data)
trace_id = parent.trace_id
parent_id = None
elif isinstance(parent, Span):
if isinstance(parent, NoOpSpan):
logger.debug(f"Parent {parent} is no-op, returning NoOpSpan")
return NoOpSpan(span_data)
parent_id = parent.span_id
trace_id = parent.trace_id
logger.debug(f"Creating span {span_data} with id {span_id}")
return SpanImpl(
trace_id=trace_id,
span_id=span_id,
parent_id=parent_id,
processor=self._multi_processor,
span_data=span_data,
)
def shutdown(self) -> None:
if self._disabled:
return
try:
logger.debug("Shutting down trace provider")
self._multi_processor.shutdown()
except Exception as e:
logger.error(f"Error shutting down trace provider: {e}")
GLOBAL_TRACE_PROVIDER = TraceProvider()

View file

@ -1,22 +1,21 @@
import uuid from .setup import get_trace_provider
from datetime import datetime, timezone
def time_iso() -> str: def time_iso() -> str:
"""Returns the current time in ISO 8601 format.""" """Return the current time in ISO 8601 format."""
return datetime.now(timezone.utc).isoformat() return get_trace_provider().time_iso()
def gen_trace_id() -> str: def gen_trace_id() -> str:
"""Generates a new trace ID.""" """Generate a new trace ID."""
return f"trace_{uuid.uuid4().hex}" return get_trace_provider().gen_trace_id()
def gen_span_id() -> str: def gen_span_id() -> str:
"""Generates a new span ID.""" """Generate a new span ID."""
return f"span_{uuid.uuid4().hex[:24]}" return get_trace_provider().gen_span_id()
def gen_group_id() -> str: def gen_group_id() -> str:
"""Generates a new group ID.""" """Generate a new group ID."""
return f"group_{uuid.uuid4().hex[:24]}" return get_trace_provider().gen_group_id()

View file

@ -5,8 +5,9 @@ import pytest
from agents.models import _openai_shared from agents.models import _openai_shared
from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel
from agents.models.openai_responses import OpenAIResponsesModel from agents.models.openai_responses import OpenAIResponsesModel
from agents.run import set_default_agent_runner
from agents.tracing import set_trace_processors from agents.tracing import set_trace_processors
from agents.tracing.setup import GLOBAL_TRACE_PROVIDER from agents.tracing.setup import get_trace_provider
from .testing_processor import SPAN_PROCESSOR_TESTING from .testing_processor import SPAN_PROCESSOR_TESTING
@ -33,11 +34,16 @@ def clear_openai_settings():
_openai_shared._use_responses_by_default = True _openai_shared._use_responses_by_default = True
@pytest.fixture(autouse=True)
def clear_default_runner():
    """Call ``set_default_agent_runner(None)`` before every test (autouse fixture).

    NOTE(review): presumably this drops any runner a previous test installed so the
    built-in default is used again -- confirm against ``agents.run``.
    """
    set_default_agent_runner(None)
# This fixture will run after all tests end # This fixture will run after all tests end
@pytest.fixture(autouse=True, scope="session") @pytest.fixture(autouse=True, scope="session")
def shutdown_trace_provider(): def shutdown_trace_provider():
yield yield
GLOBAL_TRACE_PROVIDER.shutdown() get_trace_provider().shutdown()
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)

View file

@ -1,20 +1,21 @@
from agents import Agent, OpenAIResponsesModel, RunConfig, Runner from agents import Agent, OpenAIResponsesModel, RunConfig
from agents.extensions.models.litellm_model import LitellmModel from agents.extensions.models.litellm_model import LitellmModel
from agents.run import AgentRunner
def test_no_prefix_is_openai(): def test_no_prefix_is_openai():
agent = Agent(model="gpt-4o", instructions="", name="test") agent = Agent(model="gpt-4o", instructions="", name="test")
model = Runner._get_model(agent, RunConfig()) model = AgentRunner._get_model(agent, RunConfig())
assert isinstance(model, OpenAIResponsesModel) assert isinstance(model, OpenAIResponsesModel)
def openai_prefix_is_openai(): def openai_prefix_is_openai():
agent = Agent(model="openai/gpt-4o", instructions="", name="test") agent = Agent(model="openai/gpt-4o", instructions="", name="test")
model = Runner._get_model(agent, RunConfig()) model = AgentRunner._get_model(agent, RunConfig())
assert isinstance(model, OpenAIResponsesModel) assert isinstance(model, OpenAIResponsesModel)
def test_litellm_prefix_is_litellm(): def test_litellm_prefix_is_litellm():
agent = Agent(model="litellm/foo/bar", instructions="", name="test") agent = Agent(model="litellm/foo/bar", instructions="", name="test")
model = Runner._get_model(agent, RunConfig()) model = AgentRunner._get_model(agent, RunConfig())
assert isinstance(model, LitellmModel) assert isinstance(model, LitellmModel)

View file

@ -1,7 +1,8 @@
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
from agents import Agent, AgentOutputSchema, Handoff, RunContextWrapper, Runner, handoff from agents import Agent, AgentOutputSchema, Handoff, RunContextWrapper, handoff
from agents.run import AgentRunner
@pytest.mark.asyncio @pytest.mark.asyncio
@ -42,7 +43,7 @@ async def test_handoff_with_agents():
handoffs=[agent_1, agent_2], handoffs=[agent_1, agent_2],
) )
handoffs = Runner._get_handoffs(agent_3) handoffs = AgentRunner._get_handoffs(agent_3)
assert len(handoffs) == 2 assert len(handoffs) == 2
assert handoffs[0].agent_name == "agent_1" assert handoffs[0].agent_name == "agent_1"
@ -77,7 +78,7 @@ async def test_handoff_with_handoff_obj():
], ],
) )
handoffs = Runner._get_handoffs(agent_3) handoffs = AgentRunner._get_handoffs(agent_3)
assert len(handoffs) == 2 assert len(handoffs) == 2
assert handoffs[0].agent_name == "agent_1" assert handoffs[0].agent_name == "agent_1"
@ -111,7 +112,7 @@ async def test_handoff_with_handoff_obj_and_agent():
handoffs=[handoff(agent_1), agent_2], handoffs=[handoff(agent_1), agent_2],
) )
handoffs = Runner._get_handoffs(agent_3) handoffs = AgentRunner._get_handoffs(agent_3)
assert len(handoffs) == 2 assert len(handoffs) == 2
assert handoffs[0].agent_name == "agent_1" assert handoffs[0].agent_name == "agent_1"
@ -159,7 +160,7 @@ async def test_agent_final_output():
output_type=Foo, output_type=Foo,
) )
schema = Runner._get_output_schema(agent) schema = AgentRunner._get_output_schema(agent)
assert isinstance(schema, AgentOutputSchema) assert isinstance(schema, AgentOutputSchema)
assert schema is not None assert schema is not None
assert schema.output_type == Foo assert schema.output_type == Foo

View file

@ -12,10 +12,10 @@ from agents import (
MessageOutputItem, MessageOutputItem,
ModelBehaviorError, ModelBehaviorError,
RunContextWrapper, RunContextWrapper,
Runner,
UserError, UserError,
handoff, handoff,
) )
from agents.run import AgentRunner
def message_item(content: str, agent: Agent[Any]) -> MessageOutputItem: def message_item(content: str, agent: Agent[Any]) -> MessageOutputItem:
@ -45,9 +45,9 @@ def test_single_handoff_setup():
assert not agent_1.handoffs assert not agent_1.handoffs
assert agent_2.handoffs == [agent_1] assert agent_2.handoffs == [agent_1]
assert not Runner._get_handoffs(agent_1) assert not AgentRunner._get_handoffs(agent_1)
handoff_objects = Runner._get_handoffs(agent_2) handoff_objects = AgentRunner._get_handoffs(agent_2)
assert len(handoff_objects) == 1 assert len(handoff_objects) == 1
obj = handoff_objects[0] obj = handoff_objects[0]
assert obj.tool_name == Handoff.default_tool_name(agent_1) assert obj.tool_name == Handoff.default_tool_name(agent_1)
@ -64,7 +64,7 @@ def test_multiple_handoffs_setup():
assert not agent_1.handoffs assert not agent_1.handoffs
assert not agent_2.handoffs assert not agent_2.handoffs
handoff_objects = Runner._get_handoffs(agent_3) handoff_objects = AgentRunner._get_handoffs(agent_3)
assert len(handoff_objects) == 2 assert len(handoff_objects) == 2
assert handoff_objects[0].tool_name == Handoff.default_tool_name(agent_1) assert handoff_objects[0].tool_name == Handoff.default_tool_name(agent_1)
assert handoff_objects[1].tool_name == Handoff.default_tool_name(agent_2) assert handoff_objects[1].tool_name == Handoff.default_tool_name(agent_2)
@ -95,7 +95,7 @@ def test_custom_handoff_setup():
assert not agent_1.handoffs assert not agent_1.handoffs
assert not agent_2.handoffs assert not agent_2.handoffs
handoff_objects = Runner._get_handoffs(agent_3) handoff_objects = AgentRunner._get_handoffs(agent_3)
assert len(handoff_objects) == 2 assert len(handoff_objects) == 2
first_handoff = handoff_objects[0] first_handoff = handoff_objects[0]

View file

@ -10,16 +10,16 @@ from agents import (
AgentOutputSchema, AgentOutputSchema,
AgentOutputSchemaBase, AgentOutputSchemaBase,
ModelBehaviorError, ModelBehaviorError,
Runner,
UserError, UserError,
) )
from agents.agent_output import _WRAPPER_DICT_KEY from agents.agent_output import _WRAPPER_DICT_KEY
from agents.run import AgentRunner
from agents.util import _json from agents.util import _json
def test_plain_text_output(): def test_plain_text_output():
agent = Agent(name="test") agent = Agent(name="test")
output_schema = Runner._get_output_schema(agent) output_schema = AgentRunner._get_output_schema(agent)
assert not output_schema, "Shouldn't have an output tool config without an output type" assert not output_schema, "Shouldn't have an output tool config without an output type"
agent = Agent(name="test", output_type=str) agent = Agent(name="test", output_type=str)
@ -32,7 +32,7 @@ class Foo(BaseModel):
def test_structured_output_pydantic(): def test_structured_output_pydantic():
agent = Agent(name="test", output_type=Foo) agent = Agent(name="test", output_type=Foo)
output_schema = Runner._get_output_schema(agent) output_schema = AgentRunner._get_output_schema(agent)
assert output_schema, "Should have an output tool config with a structured output type" assert output_schema, "Should have an output tool config with a structured output type"
assert isinstance(output_schema, AgentOutputSchema) assert isinstance(output_schema, AgentOutputSchema)
@ -52,7 +52,7 @@ class Bar(TypedDict):
def test_structured_output_typed_dict(): def test_structured_output_typed_dict():
agent = Agent(name="test", output_type=Bar) agent = Agent(name="test", output_type=Bar)
output_schema = Runner._get_output_schema(agent) output_schema = AgentRunner._get_output_schema(agent)
assert output_schema, "Should have an output tool config with a structured output type" assert output_schema, "Should have an output tool config with a structured output type"
assert isinstance(output_schema, AgentOutputSchema) assert isinstance(output_schema, AgentOutputSchema)
assert output_schema.output_type == Bar, "Should have the correct output type" assert output_schema.output_type == Bar, "Should have the correct output type"
@ -65,7 +65,7 @@ def test_structured_output_typed_dict():
def test_structured_output_list(): def test_structured_output_list():
agent = Agent(name="test", output_type=list[str]) agent = Agent(name="test", output_type=list[str])
output_schema = Runner._get_output_schema(agent) output_schema = AgentRunner._get_output_schema(agent)
assert output_schema, "Should have an output tool config with a structured output type" assert output_schema, "Should have an output tool config with a structured output type"
assert isinstance(output_schema, AgentOutputSchema) assert isinstance(output_schema, AgentOutputSchema)
assert output_schema.output_type == list[str], "Should have the correct output type" assert output_schema.output_type == list[str], "Should have the correct output type"
@ -79,14 +79,14 @@ def test_structured_output_list():
def test_bad_json_raises_error(mocker): def test_bad_json_raises_error(mocker):
agent = Agent(name="test", output_type=Foo) agent = Agent(name="test", output_type=Foo)
output_schema = Runner._get_output_schema(agent) output_schema = AgentRunner._get_output_schema(agent)
assert output_schema, "Should have an output tool config with a structured output type" assert output_schema, "Should have an output tool config with a structured output type"
with pytest.raises(ModelBehaviorError): with pytest.raises(ModelBehaviorError):
output_schema.validate_json("not valid json") output_schema.validate_json("not valid json")
agent = Agent(name="test", output_type=list[str]) agent = Agent(name="test", output_type=list[str])
output_schema = Runner._get_output_schema(agent) output_schema = AgentRunner._get_output_schema(agent)
assert output_schema, "Should have an output tool config with a structured output type" assert output_schema, "Should have an output tool config with a structured output type"
mock_validate_json = mocker.patch.object(_json, "validate_json") mock_validate_json = mocker.patch.object(_json, "validate_json")
@ -155,7 +155,7 @@ class CustomOutputSchema(AgentOutputSchemaBase):
def test_custom_output_schema(): def test_custom_output_schema():
custom_output_schema = CustomOutputSchema() custom_output_schema = CustomOutputSchema()
agent = Agent(name="test", output_type=custom_output_schema) agent = Agent(name="test", output_type=custom_output_schema)
output_schema = Runner._get_output_schema(agent) output_schema = AgentRunner._get_output_schema(agent)
assert output_schema, "Should have an output tool config with a structured output type" assert output_schema, "Should have an output tool config with a structured output type"
assert isinstance(output_schema, CustomOutputSchema) assert isinstance(output_schema, CustomOutputSchema)

26
tests/test_run.py Normal file
View file

@ -0,0 +1,26 @@
from __future__ import annotations
from unittest import mock
import pytest
from agents import Agent, Runner
from agents.run import AgentRunner, set_default_agent_runner
from .fake_model import FakeModel
@pytest.mark.asyncio
async def test_static_run_methods_call_into_default_runner() -> None:
    """Each static ``Runner`` entry point should call the matching method on the default runner."""
    # A spec'd mock is installed as the default runner so only the delegation is observed.
    stub_runner = mock.Mock(spec=AgentRunner)
    set_default_agent_runner(stub_runner)

    test_agent = Agent(name="test", model=FakeModel())

    # Async entry point.
    await Runner.run(test_agent, input="test")
    stub_runner.run.assert_called_once()

    # Streaming entry point.
    Runner.run_streamed(test_agent, input="test")
    stub_runner.run_streamed.assert_called_once()

    # Synchronous entry point.
    Runner.run_sync(test_agent, input="test")
    stub_runner.run_sync.assert_called_once()

View file

@ -60,7 +60,7 @@ async def test_run_config_model_name_override_takes_precedence() -> None:
async def test_run_config_model_override_object_takes_precedence() -> None: async def test_run_config_model_override_object_takes_precedence() -> None:
""" """
When a concrete Model instance is set on the RunConfig, then that instance should be When a concrete Model instance is set on the RunConfig, then that instance should be
returned by Runner._get_model regardless of the agent's model. returned by AgentRunner._get_model regardless of the agent's model.
""" """
fake_model = FakeModel(initial_output=[get_text_message("override-object")]) fake_model = FakeModel(initial_output=[get_text_message("override-object")])
agent = Agent(name="test", model="agent-model") agent = Agent(name="test", model="agent-model")

View file

@ -14,7 +14,6 @@ from agents import (
RunContextWrapper, RunContextWrapper,
RunHooks, RunHooks,
RunItem, RunItem,
Runner,
ToolCallItem, ToolCallItem,
ToolCallOutputItem, ToolCallOutputItem,
TResponseInputItem, TResponseInputItem,
@ -27,6 +26,7 @@ from agents._run_impl import (
RunImpl, RunImpl,
SingleStepResult, SingleStepResult,
) )
from agents.run import AgentRunner
from agents.tool import function_tool from agents.tool import function_tool
from agents.tool_context import ToolContext from agents.tool_context import ToolContext
@ -324,8 +324,8 @@ async def get_execute_result(
context_wrapper: RunContextWrapper[Any] | None = None, context_wrapper: RunContextWrapper[Any] | None = None,
run_config: RunConfig | None = None, run_config: RunConfig | None = None,
) -> SingleStepResult: ) -> SingleStepResult:
output_schema = Runner._get_output_schema(agent) output_schema = AgentRunner._get_output_schema(agent)
handoffs = Runner._get_handoffs(agent) handoffs = AgentRunner._get_handoffs(agent)
processed_response = RunImpl.process_model_response( processed_response = RunImpl.process_model_response(
agent=agent, agent=agent,

View file

@ -19,11 +19,11 @@ from agents import (
ModelResponse, ModelResponse,
ReasoningItem, ReasoningItem,
RunContextWrapper, RunContextWrapper,
Runner,
ToolCallItem, ToolCallItem,
Usage, Usage,
) )
from agents._run_impl import RunImpl from agents._run_impl import RunImpl
from agents.run import AgentRunner
from .test_responses import ( from .test_responses import (
get_final_output_message, get_final_output_message,
@ -186,7 +186,7 @@ async def test_handoffs_parsed_correctly():
agent=agent_3, agent=agent_3,
response=response, response=response,
output_schema=None, output_schema=None,
handoffs=Runner._get_handoffs(agent_3), handoffs=AgentRunner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(_dummy_ctx()), all_tools=await agent_3.get_all_tools(_dummy_ctx()),
) )
assert len(result.handoffs) == 1, "Should have a handoff here" assert len(result.handoffs) == 1, "Should have a handoff here"
@ -216,7 +216,7 @@ async def test_missing_handoff_fails():
agent=agent_3, agent=agent_3,
response=response, response=response,
output_schema=None, output_schema=None,
handoffs=Runner._get_handoffs(agent_3), handoffs=AgentRunner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(_dummy_ctx()), all_tools=await agent_3.get_all_tools(_dummy_ctx()),
) )
@ -239,7 +239,7 @@ async def test_multiple_handoffs_doesnt_error():
agent=agent_3, agent=agent_3,
response=response, response=response,
output_schema=None, output_schema=None,
handoffs=Runner._get_handoffs(agent_3), handoffs=AgentRunner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(_dummy_ctx()), all_tools=await agent_3.get_all_tools(_dummy_ctx()),
) )
assert len(result.handoffs) == 2, "Should have multiple handoffs here" assert len(result.handoffs) == 2, "Should have multiple handoffs here"
@ -264,7 +264,7 @@ async def test_final_output_parsed_correctly():
RunImpl.process_model_response( RunImpl.process_model_response(
agent=agent, agent=agent,
response=response, response=response,
output_schema=Runner._get_output_schema(agent), output_schema=AgentRunner._get_output_schema(agent),
handoffs=[], handoffs=[],
all_tools=await agent.get_all_tools(_dummy_ctx()), all_tools=await agent.get_all_tools(_dummy_ctx()),
) )
@ -471,7 +471,7 @@ async def test_tool_and_handoff_parsed_correctly():
agent=agent_3, agent=agent_3,
response=response, response=response,
output_schema=None, output_schema=None,
handoffs=Runner._get_handoffs(agent_3), handoffs=AgentRunner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(_dummy_ctx()), all_tools=await agent_3.get_all_tools(_dummy_ctx()),
) )
assert result.functions and len(result.functions) == 1 assert result.functions and len(result.functions) == 1

View file

@ -19,11 +19,12 @@ from agents.items import (
TResponseStreamEvent, TResponseStreamEvent,
) )
from ..fake_model import get_response_obj
from ..test_responses import get_function_tool, get_function_tool_call, get_text_message
try: try:
from agents.voice import SingleAgentVoiceWorkflow from agents.voice import SingleAgentVoiceWorkflow
from ..fake_model import get_response_obj
from ..test_responses import get_function_tool, get_function_tool_call, get_text_message
except ImportError: except ImportError:
pass pass

3332
uv.lock

File diff suppressed because it is too large Load diff