Allow replacing AgentRunner and TraceProvider (#720)
This commit is contained in:
parent
901d2ac57c
commit
0cf503e1c2
18 changed files with 2290 additions and 2036 deletions
|
|
@ -104,6 +104,7 @@ from .tracing import (
|
|||
handoff_span,
|
||||
mcp_tools_span,
|
||||
set_trace_processors,
|
||||
set_trace_provider,
|
||||
set_tracing_disabled,
|
||||
set_tracing_export_api_key,
|
||||
speech_group_span,
|
||||
|
|
@ -246,6 +247,7 @@ __all__ = [
|
|||
"guardrail_span",
|
||||
"handoff_span",
|
||||
"set_trace_processors",
|
||||
"set_trace_provider",
|
||||
"set_tracing_disabled",
|
||||
"speech_group_span",
|
||||
"transcription_span",
|
||||
|
|
|
|||
|
|
@ -3,12 +3,13 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import copy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, cast
|
||||
from typing import Any, Generic, cast
|
||||
|
||||
from openai.types.responses import ResponseCompletedEvent
|
||||
from openai.types.responses.response_prompt_param import (
|
||||
ResponsePromptParam,
|
||||
)
|
||||
from typing_extensions import NotRequired, TypedDict, Unpack
|
||||
|
||||
from ._run_impl import (
|
||||
AgentToolUseTracker,
|
||||
|
|
@ -31,7 +32,12 @@ from .exceptions import (
|
|||
OutputGuardrailTripwireTriggered,
|
||||
RunErrorDetails,
|
||||
)
|
||||
from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
|
||||
from .guardrail import (
|
||||
InputGuardrail,
|
||||
InputGuardrailResult,
|
||||
OutputGuardrail,
|
||||
OutputGuardrailResult,
|
||||
)
|
||||
from .handoffs import Handoff, HandoffInputFilter, handoff
|
||||
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
|
||||
from .lifecycle import RunHooks
|
||||
|
|
@ -50,6 +56,27 @@ from .util import _coro, _error_tracing
|
|||
|
||||
DEFAULT_MAX_TURNS = 10
|
||||
|
||||
DEFAULT_AGENT_RUNNER: AgentRunner = None # type: ignore
|
||||
# the value is set at the end of the module
|
||||
|
||||
|
||||
def set_default_agent_runner(runner: AgentRunner | None) -> None:
|
||||
"""
|
||||
WARNING: this class is experimental and not part of the public API
|
||||
It should not be used directly.
|
||||
"""
|
||||
global DEFAULT_AGENT_RUNNER
|
||||
DEFAULT_AGENT_RUNNER = runner or AgentRunner()
|
||||
|
||||
|
||||
def get_default_agent_runner() -> AgentRunner:
|
||||
"""
|
||||
WARNING: this class is experimental and not part of the public API
|
||||
It should not be used directly.
|
||||
"""
|
||||
global DEFAULT_AGENT_RUNNER
|
||||
return DEFAULT_AGENT_RUNNER
|
||||
|
||||
|
||||
@dataclass
|
||||
class RunConfig:
|
||||
|
|
@ -110,6 +137,25 @@ class RunConfig:
|
|||
"""
|
||||
|
||||
|
||||
class RunOptions(TypedDict, Generic[TContext]):
|
||||
"""Arguments for ``AgentRunner`` methods."""
|
||||
|
||||
context: NotRequired[TContext | None]
|
||||
"""The context for the run."""
|
||||
|
||||
max_turns: NotRequired[int]
|
||||
"""The maximum number of turns to run for."""
|
||||
|
||||
hooks: NotRequired[RunHooks[TContext] | None]
|
||||
"""Lifecycle hooks for the run."""
|
||||
|
||||
run_config: NotRequired[RunConfig | None]
|
||||
"""Run configuration."""
|
||||
|
||||
previous_response_id: NotRequired[str | None]
|
||||
"""The ID of the previous response, if any."""
|
||||
|
||||
|
||||
class Runner:
|
||||
@classmethod
|
||||
async def run(
|
||||
|
|
@ -130,13 +176,10 @@ class Runner:
|
|||
`agent.output_type`, the loop terminates.
|
||||
3. If there's a handoff, we run the loop again, with the new agent.
|
||||
4. Else, we run tool calls (if any), and re-run the loop.
|
||||
|
||||
In two cases, the agent may raise an exception:
|
||||
1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
|
||||
2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
|
||||
|
||||
Note that only the first agent's input guardrails are run.
|
||||
|
||||
Args:
|
||||
starting_agent: The starting agent to run.
|
||||
input: The initial input to the agent. You can pass a single string for a user message,
|
||||
|
|
@ -148,11 +191,139 @@ class Runner:
|
|||
run_config: Global settings for the entire agent run.
|
||||
previous_response_id: The ID of the previous response, if using OpenAI models via the
|
||||
Responses API, this allows you to skip passing in input from the previous turn.
|
||||
|
||||
Returns:
|
||||
A run result containing all the inputs, guardrail results and the output of the last
|
||||
agent. Agents may perform handoffs, so we don't know the specific type of the output.
|
||||
"""
|
||||
runner = DEFAULT_AGENT_RUNNER
|
||||
return await runner.run(
|
||||
starting_agent,
|
||||
input,
|
||||
context=context,
|
||||
max_turns=max_turns,
|
||||
hooks=hooks,
|
||||
run_config=run_config,
|
||||
previous_response_id=previous_response_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def run_sync(
|
||||
cls,
|
||||
starting_agent: Agent[TContext],
|
||||
input: str | list[TResponseInputItem],
|
||||
*,
|
||||
context: TContext | None = None,
|
||||
max_turns: int = DEFAULT_MAX_TURNS,
|
||||
hooks: RunHooks[TContext] | None = None,
|
||||
run_config: RunConfig | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
) -> RunResult:
|
||||
"""Run a workflow synchronously, starting at the given agent. Note that this just wraps the
|
||||
`run` method, so it will not work if there's already an event loop (e.g. inside an async
|
||||
function, or in a Jupyter notebook or async context like FastAPI). For those cases, use
|
||||
the `run` method instead.
|
||||
The agent will run in a loop until a final output is generated. The loop runs like so:
|
||||
1. The agent is invoked with the given input.
|
||||
2. If there is a final output (i.e. the agent produces something of type
|
||||
`agent.output_type`, the loop terminates.
|
||||
3. If there's a handoff, we run the loop again, with the new agent.
|
||||
4. Else, we run tool calls (if any), and re-run the loop.
|
||||
In two cases, the agent may raise an exception:
|
||||
1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
|
||||
2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
|
||||
Note that only the first agent's input guardrails are run.
|
||||
Args:
|
||||
starting_agent: The starting agent to run.
|
||||
input: The initial input to the agent. You can pass a single string for a user message,
|
||||
or a list of input items.
|
||||
context: The context to run the agent with.
|
||||
max_turns: The maximum number of turns to run the agent for. A turn is defined as one
|
||||
AI invocation (including any tool calls that might occur).
|
||||
hooks: An object that receives callbacks on various lifecycle events.
|
||||
run_config: Global settings for the entire agent run.
|
||||
previous_response_id: The ID of the previous response, if using OpenAI models via the
|
||||
Responses API, this allows you to skip passing in input from the previous turn.
|
||||
Returns:
|
||||
A run result containing all the inputs, guardrail results and the output of the last
|
||||
agent. Agents may perform handoffs, so we don't know the specific type of the output.
|
||||
"""
|
||||
runner = DEFAULT_AGENT_RUNNER
|
||||
return runner.run_sync(
|
||||
starting_agent,
|
||||
input,
|
||||
context=context,
|
||||
max_turns=max_turns,
|
||||
hooks=hooks,
|
||||
run_config=run_config,
|
||||
previous_response_id=previous_response_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def run_streamed(
|
||||
cls,
|
||||
starting_agent: Agent[TContext],
|
||||
input: str | list[TResponseInputItem],
|
||||
context: TContext | None = None,
|
||||
max_turns: int = DEFAULT_MAX_TURNS,
|
||||
hooks: RunHooks[TContext] | None = None,
|
||||
run_config: RunConfig | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
) -> RunResultStreaming:
|
||||
"""Run a workflow starting at the given agent in streaming mode. The returned result object
|
||||
contains a method you can use to stream semantic events as they are generated.
|
||||
The agent will run in a loop until a final output is generated. The loop runs like so:
|
||||
1. The agent is invoked with the given input.
|
||||
2. If there is a final output (i.e. the agent produces something of type
|
||||
`agent.output_type`, the loop terminates.
|
||||
3. If there's a handoff, we run the loop again, with the new agent.
|
||||
4. Else, we run tool calls (if any), and re-run the loop.
|
||||
In two cases, the agent may raise an exception:
|
||||
1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
|
||||
2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
|
||||
Note that only the first agent's input guardrails are run.
|
||||
Args:
|
||||
starting_agent: The starting agent to run.
|
||||
input: The initial input to the agent. You can pass a single string for a user message,
|
||||
or a list of input items.
|
||||
context: The context to run the agent with.
|
||||
max_turns: The maximum number of turns to run the agent for. A turn is defined as one
|
||||
AI invocation (including any tool calls that might occur).
|
||||
hooks: An object that receives callbacks on various lifecycle events.
|
||||
run_config: Global settings for the entire agent run.
|
||||
previous_response_id: The ID of the previous response, if using OpenAI models via the
|
||||
Responses API, this allows you to skip passing in input from the previous turn.
|
||||
Returns:
|
||||
A result object that contains data about the run, as well as a method to stream events.
|
||||
"""
|
||||
runner = DEFAULT_AGENT_RUNNER
|
||||
return runner.run_streamed(
|
||||
starting_agent,
|
||||
input,
|
||||
context=context,
|
||||
max_turns=max_turns,
|
||||
hooks=hooks,
|
||||
run_config=run_config,
|
||||
previous_response_id=previous_response_id,
|
||||
)
|
||||
|
||||
|
||||
class AgentRunner:
|
||||
"""
|
||||
WARNING: this class is experimental and not part of the public API
|
||||
It should not be used directly or subclassed.
|
||||
"""
|
||||
|
||||
async def run(
|
||||
self,
|
||||
starting_agent: Agent[TContext],
|
||||
input: str | list[TResponseInputItem],
|
||||
**kwargs: Unpack[RunOptions[TContext]],
|
||||
) -> RunResult:
|
||||
context = kwargs.get("context")
|
||||
max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
|
||||
hooks = kwargs.get("hooks")
|
||||
run_config = kwargs.get("run_config")
|
||||
previous_response_id = kwargs.get("previous_response_id")
|
||||
if hooks is None:
|
||||
hooks = RunHooks[Any]()
|
||||
if run_config is None:
|
||||
|
|
@ -184,13 +355,15 @@ class Runner:
|
|||
|
||||
try:
|
||||
while True:
|
||||
all_tools = await cls._get_all_tools(current_agent, context_wrapper)
|
||||
all_tools = await AgentRunner._get_all_tools(current_agent, context_wrapper)
|
||||
|
||||
# Start an agent span if we don't have one. This span is ended if the current
|
||||
# agent changes, or if the agent loop ends.
|
||||
if current_span is None:
|
||||
handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)]
|
||||
if output_schema := cls._get_output_schema(current_agent):
|
||||
handoff_names = [
|
||||
h.agent_name for h in AgentRunner._get_handoffs(current_agent)
|
||||
]
|
||||
if output_schema := AgentRunner._get_output_schema(current_agent):
|
||||
output_type_name = output_schema.name()
|
||||
else:
|
||||
output_type_name = "str"
|
||||
|
|
@ -220,14 +393,14 @@ class Runner:
|
|||
|
||||
if current_turn == 1:
|
||||
input_guardrail_results, turn_result = await asyncio.gather(
|
||||
cls._run_input_guardrails(
|
||||
self._run_input_guardrails(
|
||||
starting_agent,
|
||||
starting_agent.input_guardrails
|
||||
+ (run_config.input_guardrails or []),
|
||||
copy.deepcopy(input),
|
||||
context_wrapper,
|
||||
),
|
||||
cls._run_single_turn(
|
||||
self._run_single_turn(
|
||||
agent=current_agent,
|
||||
all_tools=all_tools,
|
||||
original_input=original_input,
|
||||
|
|
@ -241,7 +414,7 @@ class Runner:
|
|||
),
|
||||
)
|
||||
else:
|
||||
turn_result = await cls._run_single_turn(
|
||||
turn_result = await self._run_single_turn(
|
||||
agent=current_agent,
|
||||
all_tools=all_tools,
|
||||
original_input=original_input,
|
||||
|
|
@ -260,7 +433,7 @@ class Runner:
|
|||
generated_items = turn_result.generated_items
|
||||
|
||||
if isinstance(turn_result.next_step, NextStepFinalOutput):
|
||||
output_guardrail_results = await cls._run_output_guardrails(
|
||||
output_guardrail_results = await self._run_output_guardrails(
|
||||
current_agent.output_guardrails + (run_config.output_guardrails or []),
|
||||
current_agent,
|
||||
turn_result.next_step.output,
|
||||
|
|
@ -302,54 +475,19 @@ class Runner:
|
|||
if current_span:
|
||||
current_span.finish(reset_current=True)
|
||||
|
||||
@classmethod
|
||||
def run_sync(
|
||||
cls,
|
||||
self,
|
||||
starting_agent: Agent[TContext],
|
||||
input: str | list[TResponseInputItem],
|
||||
*,
|
||||
context: TContext | None = None,
|
||||
max_turns: int = DEFAULT_MAX_TURNS,
|
||||
hooks: RunHooks[TContext] | None = None,
|
||||
run_config: RunConfig | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
**kwargs: Unpack[RunOptions[TContext]],
|
||||
) -> RunResult:
|
||||
"""Run a workflow synchronously, starting at the given agent. Note that this just wraps the
|
||||
`run` method, so it will not work if there's already an event loop (e.g. inside an async
|
||||
function, or in a Jupyter notebook or async context like FastAPI). For those cases, use
|
||||
the `run` method instead.
|
||||
|
||||
The agent will run in a loop until a final output is generated. The loop runs like so:
|
||||
1. The agent is invoked with the given input.
|
||||
2. If there is a final output (i.e. the agent produces something of type
|
||||
`agent.output_type`, the loop terminates.
|
||||
3. If there's a handoff, we run the loop again, with the new agent.
|
||||
4. Else, we run tool calls (if any), and re-run the loop.
|
||||
|
||||
In two cases, the agent may raise an exception:
|
||||
1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
|
||||
2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
|
||||
|
||||
Note that only the first agent's input guardrails are run.
|
||||
|
||||
Args:
|
||||
starting_agent: The starting agent to run.
|
||||
input: The initial input to the agent. You can pass a single string for a user message,
|
||||
or a list of input items.
|
||||
context: The context to run the agent with.
|
||||
max_turns: The maximum number of turns to run the agent for. A turn is defined as one
|
||||
AI invocation (including any tool calls that might occur).
|
||||
hooks: An object that receives callbacks on various lifecycle events.
|
||||
run_config: Global settings for the entire agent run.
|
||||
previous_response_id: The ID of the previous response, if using OpenAI models via the
|
||||
Responses API, this allows you to skip passing in input from the previous turn.
|
||||
|
||||
Returns:
|
||||
A run result containing all the inputs, guardrail results and the output of the last
|
||||
agent. Agents may perform handoffs, so we don't know the specific type of the output.
|
||||
"""
|
||||
context = kwargs.get("context")
|
||||
max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
|
||||
hooks = kwargs.get("hooks")
|
||||
run_config = kwargs.get("run_config")
|
||||
previous_response_id = kwargs.get("previous_response_id")
|
||||
return asyncio.get_event_loop().run_until_complete(
|
||||
cls.run(
|
||||
self.run(
|
||||
starting_agent,
|
||||
input,
|
||||
context=context,
|
||||
|
|
@ -360,47 +498,17 @@ class Runner:
|
|||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def run_streamed(
|
||||
cls,
|
||||
self,
|
||||
starting_agent: Agent[TContext],
|
||||
input: str | list[TResponseInputItem],
|
||||
context: TContext | None = None,
|
||||
max_turns: int = DEFAULT_MAX_TURNS,
|
||||
hooks: RunHooks[TContext] | None = None,
|
||||
run_config: RunConfig | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
**kwargs: Unpack[RunOptions[TContext]],
|
||||
) -> RunResultStreaming:
|
||||
"""Run a workflow starting at the given agent in streaming mode. The returned result object
|
||||
contains a method you can use to stream semantic events as they are generated.
|
||||
|
||||
The agent will run in a loop until a final output is generated. The loop runs like so:
|
||||
1. The agent is invoked with the given input.
|
||||
2. If there is a final output (i.e. the agent produces something of type
|
||||
`agent.output_type`, the loop terminates.
|
||||
3. If there's a handoff, we run the loop again, with the new agent.
|
||||
4. Else, we run tool calls (if any), and re-run the loop.
|
||||
|
||||
In two cases, the agent may raise an exception:
|
||||
1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
|
||||
2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
|
||||
|
||||
Note that only the first agent's input guardrails are run.
|
||||
|
||||
Args:
|
||||
starting_agent: The starting agent to run.
|
||||
input: The initial input to the agent. You can pass a single string for a user message,
|
||||
or a list of input items.
|
||||
context: The context to run the agent with.
|
||||
max_turns: The maximum number of turns to run the agent for. A turn is defined as one
|
||||
AI invocation (including any tool calls that might occur).
|
||||
hooks: An object that receives callbacks on various lifecycle events.
|
||||
run_config: Global settings for the entire agent run.
|
||||
previous_response_id: The ID of the previous response, if using OpenAI models via the
|
||||
Responses API, this allows you to skip passing in input from the previous turn.
|
||||
Returns:
|
||||
A result object that contains data about the run, as well as a method to stream events.
|
||||
"""
|
||||
context = kwargs.get("context")
|
||||
max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
|
||||
hooks = kwargs.get("hooks")
|
||||
run_config = kwargs.get("run_config")
|
||||
previous_response_id = kwargs.get("previous_response_id")
|
||||
if hooks is None:
|
||||
hooks = RunHooks[Any]()
|
||||
if run_config is None:
|
||||
|
|
@ -421,7 +529,7 @@ class Runner:
|
|||
)
|
||||
)
|
||||
|
||||
output_schema = cls._get_output_schema(starting_agent)
|
||||
output_schema = AgentRunner._get_output_schema(starting_agent)
|
||||
context_wrapper: RunContextWrapper[TContext] = RunContextWrapper(
|
||||
context=context # type: ignore
|
||||
)
|
||||
|
|
@ -444,7 +552,7 @@ class Runner:
|
|||
|
||||
# Kick off the actual agent loop in the background and return the streamed result object.
|
||||
streamed_result._run_impl_task = asyncio.create_task(
|
||||
cls._run_streamed_impl(
|
||||
self._start_streaming(
|
||||
starting_input=input,
|
||||
streamed_result=streamed_result,
|
||||
starting_agent=starting_agent,
|
||||
|
|
@ -501,7 +609,7 @@ class Runner:
|
|||
streamed_result.input_guardrail_results = guardrail_results
|
||||
|
||||
@classmethod
|
||||
async def _run_streamed_impl(
|
||||
async def _start_streaming(
|
||||
cls,
|
||||
starting_input: str | list[TResponseInputItem],
|
||||
streamed_result: RunResultStreaming,
|
||||
|
|
@ -1008,3 +1116,6 @@ class Runner:
|
|||
return agent.model
|
||||
|
||||
return run_config.model_provider.get_model(agent.model)
|
||||
|
||||
|
||||
DEFAULT_AGENT_RUNNER = AgentRunner()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import atexit
|
||||
|
||||
from agents.tracing.provider import DefaultTraceProvider, TraceProvider
|
||||
|
||||
from .create import (
|
||||
agent_span,
|
||||
custom_span,
|
||||
|
|
@ -18,7 +20,7 @@ from .create import (
|
|||
)
|
||||
from .processor_interface import TracingProcessor
|
||||
from .processors import default_exporter, default_processor
|
||||
from .setup import GLOBAL_TRACE_PROVIDER
|
||||
from .setup import get_trace_provider, set_trace_provider
|
||||
from .span_data import (
|
||||
AgentSpanData,
|
||||
CustomSpanData,
|
||||
|
|
@ -45,10 +47,12 @@ __all__ = [
|
|||
"generation_span",
|
||||
"get_current_span",
|
||||
"get_current_trace",
|
||||
"get_trace_provider",
|
||||
"guardrail_span",
|
||||
"handoff_span",
|
||||
"response_span",
|
||||
"set_trace_processors",
|
||||
"set_trace_provider",
|
||||
"set_tracing_disabled",
|
||||
"trace",
|
||||
"Trace",
|
||||
|
|
@ -67,6 +71,7 @@ __all__ = [
|
|||
"SpeechSpanData",
|
||||
"TranscriptionSpanData",
|
||||
"TracingProcessor",
|
||||
"TraceProvider",
|
||||
"gen_trace_id",
|
||||
"gen_span_id",
|
||||
"speech_group_span",
|
||||
|
|
@ -80,21 +85,21 @@ def add_trace_processor(span_processor: TracingProcessor) -> None:
|
|||
"""
|
||||
Adds a new trace processor. This processor will receive all traces/spans.
|
||||
"""
|
||||
GLOBAL_TRACE_PROVIDER.register_processor(span_processor)
|
||||
get_trace_provider().register_processor(span_processor)
|
||||
|
||||
|
||||
def set_trace_processors(processors: list[TracingProcessor]) -> None:
|
||||
"""
|
||||
Set the list of trace processors. This will replace the current list of processors.
|
||||
"""
|
||||
GLOBAL_TRACE_PROVIDER.set_processors(processors)
|
||||
get_trace_provider().set_processors(processors)
|
||||
|
||||
|
||||
def set_tracing_disabled(disabled: bool) -> None:
|
||||
"""
|
||||
Set whether tracing is globally disabled.
|
||||
"""
|
||||
GLOBAL_TRACE_PROVIDER.set_disabled(disabled)
|
||||
get_trace_provider().set_disabled(disabled)
|
||||
|
||||
|
||||
def set_tracing_export_api_key(api_key: str) -> None:
|
||||
|
|
@ -104,10 +109,11 @@ def set_tracing_export_api_key(api_key: str) -> None:
|
|||
default_exporter().set_api_key(api_key)
|
||||
|
||||
|
||||
set_trace_provider(DefaultTraceProvider())
|
||||
# Add the default processor, which exports traces and spans to the backend in batches. You can
|
||||
# change the default behavior by either:
|
||||
# 1. calling add_trace_processor(), which adds additional processors, or
|
||||
# 2. calling set_trace_processors(), which replaces the default processor.
|
||||
add_trace_processor(default_processor())
|
||||
|
||||
atexit.register(GLOBAL_TRACE_PROVIDER.shutdown)
|
||||
atexit.register(get_trace_provider().shutdown)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
|
|||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..logger import logger
|
||||
from .setup import GLOBAL_TRACE_PROVIDER
|
||||
from .setup import get_trace_provider
|
||||
from .span_data import (
|
||||
AgentSpanData,
|
||||
CustomSpanData,
|
||||
|
|
@ -56,13 +56,13 @@ def trace(
|
|||
Returns:
|
||||
The newly created trace object.
|
||||
"""
|
||||
current_trace = GLOBAL_TRACE_PROVIDER.get_current_trace()
|
||||
current_trace = get_trace_provider().get_current_trace()
|
||||
if current_trace:
|
||||
logger.warning(
|
||||
"Trace already exists. Creating a new trace, but this is probably a mistake."
|
||||
)
|
||||
|
||||
return GLOBAL_TRACE_PROVIDER.create_trace(
|
||||
return get_trace_provider().create_trace(
|
||||
name=workflow_name,
|
||||
trace_id=trace_id,
|
||||
group_id=group_id,
|
||||
|
|
@ -73,12 +73,12 @@ def trace(
|
|||
|
||||
def get_current_trace() -> Trace | None:
|
||||
"""Returns the currently active trace, if present."""
|
||||
return GLOBAL_TRACE_PROVIDER.get_current_trace()
|
||||
return get_trace_provider().get_current_trace()
|
||||
|
||||
|
||||
def get_current_span() -> Span[Any] | None:
|
||||
"""Returns the currently active span, if present."""
|
||||
return GLOBAL_TRACE_PROVIDER.get_current_span()
|
||||
return get_trace_provider().get_current_span()
|
||||
|
||||
|
||||
def agent_span(
|
||||
|
|
@ -108,7 +108,7 @@ def agent_span(
|
|||
Returns:
|
||||
The newly created agent span.
|
||||
"""
|
||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
||||
return get_trace_provider().create_span(
|
||||
span_data=AgentSpanData(name=name, handoffs=handoffs, tools=tools, output_type=output_type),
|
||||
span_id=span_id,
|
||||
parent=parent,
|
||||
|
|
@ -141,7 +141,7 @@ def function_span(
|
|||
Returns:
|
||||
The newly created function span.
|
||||
"""
|
||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
||||
return get_trace_provider().create_span(
|
||||
span_data=FunctionSpanData(name=name, input=input, output=output),
|
||||
span_id=span_id,
|
||||
parent=parent,
|
||||
|
|
@ -183,7 +183,7 @@ def generation_span(
|
|||
Returns:
|
||||
The newly created generation span.
|
||||
"""
|
||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
||||
return get_trace_provider().create_span(
|
||||
span_data=GenerationSpanData(
|
||||
input=input,
|
||||
output=output,
|
||||
|
|
@ -215,7 +215,7 @@ def response_span(
|
|||
trace/span as the parent.
|
||||
disabled: If True, we will return a Span but the Span will not be recorded.
|
||||
"""
|
||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
||||
return get_trace_provider().create_span(
|
||||
span_data=ResponseSpanData(response=response),
|
||||
span_id=span_id,
|
||||
parent=parent,
|
||||
|
|
@ -246,7 +246,7 @@ def handoff_span(
|
|||
Returns:
|
||||
The newly created handoff span.
|
||||
"""
|
||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
||||
return get_trace_provider().create_span(
|
||||
span_data=HandoffSpanData(from_agent=from_agent, to_agent=to_agent),
|
||||
span_id=span_id,
|
||||
parent=parent,
|
||||
|
|
@ -278,7 +278,7 @@ def custom_span(
|
|||
Returns:
|
||||
The newly created custom span.
|
||||
"""
|
||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
||||
return get_trace_provider().create_span(
|
||||
span_data=CustomSpanData(name=name, data=data or {}),
|
||||
span_id=span_id,
|
||||
parent=parent,
|
||||
|
|
@ -306,7 +306,7 @@ def guardrail_span(
|
|||
trace/span as the parent.
|
||||
disabled: If True, we will return a Span but the Span will not be recorded.
|
||||
"""
|
||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
||||
return get_trace_provider().create_span(
|
||||
span_data=GuardrailSpanData(name=name, triggered=triggered),
|
||||
span_id=span_id,
|
||||
parent=parent,
|
||||
|
|
@ -344,7 +344,7 @@ def transcription_span(
|
|||
Returns:
|
||||
The newly created speech-to-text span.
|
||||
"""
|
||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
||||
return get_trace_provider().create_span(
|
||||
span_data=TranscriptionSpanData(
|
||||
input=input,
|
||||
input_format=input_format,
|
||||
|
|
@ -386,7 +386,7 @@ def speech_span(
|
|||
trace/span as the parent.
|
||||
disabled: If True, we will return a Span but the Span will not be recorded.
|
||||
"""
|
||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
||||
return get_trace_provider().create_span(
|
||||
span_data=SpeechSpanData(
|
||||
model=model,
|
||||
input=input,
|
||||
|
|
@ -419,7 +419,7 @@ def speech_group_span(
|
|||
trace/span as the parent.
|
||||
disabled: If True, we will return a Span but the Span will not be recorded.
|
||||
"""
|
||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
||||
return get_trace_provider().create_span(
|
||||
span_data=SpeechGroupSpanData(input=input),
|
||||
span_id=span_id,
|
||||
parent=parent,
|
||||
|
|
@ -447,7 +447,7 @@ def mcp_tools_span(
|
|||
trace/span as the parent.
|
||||
disabled: If True, we will return a Span but the Span will not be recorded.
|
||||
"""
|
||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
||||
return get_trace_provider().create_span(
|
||||
span_data=MCPListToolsSpanData(server=server, result=result),
|
||||
span_id=span_id,
|
||||
parent=parent,
|
||||
|
|
|
|||
294
src/agents/tracing/provider.py
Normal file
294
src/agents/tracing/provider.py
Normal file
|
|
@ -0,0 +1,294 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import threading
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from ..logger import logger
|
||||
from .processor_interface import TracingProcessor
|
||||
from .scope import Scope
|
||||
from .spans import NoOpSpan, Span, SpanImpl, TSpanData
|
||||
from .traces import NoOpTrace, Trace, TraceImpl
|
||||
|
||||
|
||||
class SynchronousMultiTracingProcessor(TracingProcessor):
|
||||
"""
|
||||
Forwards all calls to a list of TracingProcessors, in order of registration.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Using a tuple to avoid race conditions when iterating over processors
|
||||
self._processors: tuple[TracingProcessor, ...] = ()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def add_tracing_processor(self, tracing_processor: TracingProcessor):
|
||||
"""
|
||||
Add a processor to the list of processors. Each processor will receive all traces/spans.
|
||||
"""
|
||||
with self._lock:
|
||||
self._processors += (tracing_processor,)
|
||||
|
||||
def set_processors(self, processors: list[TracingProcessor]):
|
||||
"""
|
||||
Set the list of processors. This will replace the current list of processors.
|
||||
"""
|
||||
with self._lock:
|
||||
self._processors = tuple(processors)
|
||||
|
||||
def on_trace_start(self, trace: Trace) -> None:
|
||||
"""
|
||||
Called when a trace is started.
|
||||
"""
|
||||
for processor in self._processors:
|
||||
processor.on_trace_start(trace)
|
||||
|
||||
def on_trace_end(self, trace: Trace) -> None:
|
||||
"""
|
||||
Called when a trace is finished.
|
||||
"""
|
||||
for processor in self._processors:
|
||||
processor.on_trace_end(trace)
|
||||
|
||||
def on_span_start(self, span: Span[Any]) -> None:
|
||||
"""
|
||||
Called when a span is started.
|
||||
"""
|
||||
for processor in self._processors:
|
||||
processor.on_span_start(span)
|
||||
|
||||
def on_span_end(self, span: Span[Any]) -> None:
|
||||
"""
|
||||
Called when a span is finished.
|
||||
"""
|
||||
for processor in self._processors:
|
||||
processor.on_span_end(span)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""
|
||||
Called when the application stops.
|
||||
"""
|
||||
for processor in self._processors:
|
||||
logger.debug(f"Shutting down trace processor {processor}")
|
||||
processor.shutdown()
|
||||
|
||||
def force_flush(self):
|
||||
"""
|
||||
Force the processors to flush their buffers.
|
||||
"""
|
||||
for processor in self._processors:
|
||||
processor.force_flush()
|
||||
|
||||
|
||||
class TraceProvider(ABC):
|
||||
"""Interface for creating traces and spans."""
|
||||
|
||||
@abstractmethod
|
||||
def register_processor(self, processor: TracingProcessor) -> None:
|
||||
"""Add a processor that will receive all traces and spans."""
|
||||
|
||||
@abstractmethod
|
||||
def set_processors(self, processors: list[TracingProcessor]) -> None:
|
||||
"""Replace the list of processors with ``processors``."""
|
||||
|
||||
@abstractmethod
|
||||
def get_current_trace(self) -> Trace | None:
|
||||
"""Return the currently active trace, if any."""
|
||||
|
||||
@abstractmethod
|
||||
def get_current_span(self) -> Span[Any] | None:
|
||||
"""Return the currently active span, if any."""
|
||||
|
||||
@abstractmethod
|
||||
def set_disabled(self, disabled: bool) -> None:
|
||||
"""Enable or disable tracing globally."""
|
||||
|
||||
@abstractmethod
|
||||
def time_iso(self) -> str:
|
||||
"""Return the current time in ISO 8601 format."""
|
||||
|
||||
@abstractmethod
|
||||
def gen_trace_id(self) -> str:
|
||||
"""Generate a new trace identifier."""
|
||||
|
||||
@abstractmethod
|
||||
def gen_span_id(self) -> str:
|
||||
"""Generate a new span identifier."""
|
||||
|
||||
@abstractmethod
|
||||
def gen_group_id(self) -> str:
|
||||
"""Generate a new group identifier."""
|
||||
|
||||
@abstractmethod
|
||||
def create_trace(
|
||||
self,
|
||||
name: str,
|
||||
trace_id: str | None = None,
|
||||
group_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
disabled: bool = False,
|
||||
) -> Trace:
|
||||
"""Create a new trace."""
|
||||
|
||||
@abstractmethod
|
||||
def create_span(
|
||||
self,
|
||||
span_data: TSpanData,
|
||||
span_id: str | None = None,
|
||||
parent: Trace | Span[Any] | None = None,
|
||||
disabled: bool = False,
|
||||
) -> Span[TSpanData]:
|
||||
"""Create a new span."""
|
||||
|
||||
@abstractmethod
|
||||
def shutdown(self) -> None:
|
||||
"""Clean up any resources used by the provider."""
|
||||
|
||||
|
||||
class DefaultTraceProvider(TraceProvider):
|
||||
def __init__(self) -> None:
|
||||
self._multi_processor = SynchronousMultiTracingProcessor()
|
||||
self._disabled = os.environ.get("OPENAI_AGENTS_DISABLE_TRACING", "false").lower() in (
|
||||
"true",
|
||||
"1",
|
||||
)
|
||||
|
||||
def register_processor(self, processor: TracingProcessor):
|
||||
"""
|
||||
Add a processor to the list of processors. Each processor will receive all traces/spans.
|
||||
"""
|
||||
self._multi_processor.add_tracing_processor(processor)
|
||||
|
||||
def set_processors(self, processors: list[TracingProcessor]):
|
||||
"""
|
||||
Set the list of processors. This will replace the current list of processors.
|
||||
"""
|
||||
self._multi_processor.set_processors(processors)
|
||||
|
||||
def get_current_trace(self) -> Trace | None:
|
||||
"""
|
||||
Returns the currently active trace, if any.
|
||||
"""
|
||||
return Scope.get_current_trace()
|
||||
|
||||
def get_current_span(self) -> Span[Any] | None:
|
||||
"""
|
||||
Returns the currently active span, if any.
|
||||
"""
|
||||
return Scope.get_current_span()
|
||||
|
||||
def set_disabled(self, disabled: bool) -> None:
|
||||
"""
|
||||
Set whether tracing is disabled.
|
||||
"""
|
||||
self._disabled = disabled
|
||||
|
||||
def time_iso(self) -> str:
|
||||
"""Return the current time in ISO 8601 format."""
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
def gen_trace_id(self) -> str:
|
||||
"""Generate a new trace ID."""
|
||||
return f"trace_{uuid.uuid4().hex}"
|
||||
|
||||
def gen_span_id(self) -> str:
|
||||
"""Generate a new span ID."""
|
||||
return f"span_{uuid.uuid4().hex[:24]}"
|
||||
|
||||
def gen_group_id(self) -> str:
|
||||
"""Generate a new group ID."""
|
||||
return f"group_{uuid.uuid4().hex[:24]}"
|
||||
|
||||
def create_trace(
|
||||
self,
|
||||
name: str,
|
||||
trace_id: str | None = None,
|
||||
group_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
disabled: bool = False,
|
||||
) -> Trace:
|
||||
"""
|
||||
Create a new trace.
|
||||
"""
|
||||
if self._disabled or disabled:
|
||||
logger.debug(f"Tracing is disabled. Not creating trace {name}")
|
||||
return NoOpTrace()
|
||||
|
||||
trace_id = trace_id or self.gen_trace_id()
|
||||
|
||||
logger.debug(f"Creating trace {name} with id {trace_id}")
|
||||
|
||||
return TraceImpl(
|
||||
name=name,
|
||||
trace_id=trace_id,
|
||||
group_id=group_id,
|
||||
metadata=metadata,
|
||||
processor=self._multi_processor,
|
||||
)
|
||||
|
||||
def create_span(
|
||||
self,
|
||||
span_data: TSpanData,
|
||||
span_id: str | None = None,
|
||||
parent: Trace | Span[Any] | None = None,
|
||||
disabled: bool = False,
|
||||
) -> Span[TSpanData]:
|
||||
"""
|
||||
Create a new span.
|
||||
"""
|
||||
if self._disabled or disabled:
|
||||
logger.debug(f"Tracing is disabled. Not creating span {span_data}")
|
||||
return NoOpSpan(span_data)
|
||||
|
||||
if not parent:
|
||||
current_span = Scope.get_current_span()
|
||||
current_trace = Scope.get_current_trace()
|
||||
if current_trace is None:
|
||||
logger.error(
|
||||
"No active trace. Make sure to start a trace with `trace()` first"
|
||||
"Returning NoOpSpan."
|
||||
)
|
||||
return NoOpSpan(span_data)
|
||||
elif isinstance(current_trace, NoOpTrace) or isinstance(current_span, NoOpSpan):
|
||||
logger.debug(
|
||||
f"Parent {current_span} or {current_trace} is no-op, returning NoOpSpan"
|
||||
)
|
||||
return NoOpSpan(span_data)
|
||||
|
||||
parent_id = current_span.span_id if current_span else None
|
||||
trace_id = current_trace.trace_id
|
||||
|
||||
elif isinstance(parent, Trace):
|
||||
if isinstance(parent, NoOpTrace):
|
||||
logger.debug(f"Parent {parent} is no-op, returning NoOpSpan")
|
||||
return NoOpSpan(span_data)
|
||||
trace_id = parent.trace_id
|
||||
parent_id = None
|
||||
elif isinstance(parent, Span):
|
||||
if isinstance(parent, NoOpSpan):
|
||||
logger.debug(f"Parent {parent} is no-op, returning NoOpSpan")
|
||||
return NoOpSpan(span_data)
|
||||
parent_id = parent.span_id
|
||||
trace_id = parent.trace_id
|
||||
|
||||
logger.debug(f"Creating span {span_data} with id {span_id}")
|
||||
|
||||
return SpanImpl(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id or self.gen_span_id(),
|
||||
parent_id=parent_id,
|
||||
processor=self._multi_processor,
|
||||
span_data=span_data,
|
||||
)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if self._disabled:
|
||||
return
|
||||
|
||||
try:
|
||||
logger.debug("Shutting down trace provider")
|
||||
self._multi_processor.shutdown()
|
||||
except Exception as e:
|
||||
logger.error(f"Error shutting down trace provider: {e}")
|
||||
|
|
@ -1,214 +1,21 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import threading
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..logger import logger
|
||||
from . import util
|
||||
from .processor_interface import TracingProcessor
|
||||
from .scope import Scope
|
||||
from .spans import NoOpSpan, Span, SpanImpl, TSpanData
|
||||
from .traces import NoOpTrace, Trace, TraceImpl
|
||||
if TYPE_CHECKING:
|
||||
from .provider import TraceProvider
|
||||
|
||||
GLOBAL_TRACE_PROVIDER: TraceProvider | None = None
|
||||
|
||||
|
||||
class SynchronousMultiTracingProcessor(TracingProcessor):
|
||||
"""
|
||||
Forwards all calls to a list of TracingProcessors, in order of registration.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Using a tuple to avoid race conditions when iterating over processors
|
||||
self._processors: tuple[TracingProcessor, ...] = ()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def add_tracing_processor(self, tracing_processor: TracingProcessor):
|
||||
"""
|
||||
Add a processor to the list of processors. Each processor will receive all traces/spans.
|
||||
"""
|
||||
with self._lock:
|
||||
self._processors += (tracing_processor,)
|
||||
|
||||
def set_processors(self, processors: list[TracingProcessor]):
|
||||
"""
|
||||
Set the list of processors. This will replace the current list of processors.
|
||||
"""
|
||||
with self._lock:
|
||||
self._processors = tuple(processors)
|
||||
|
||||
def on_trace_start(self, trace: Trace) -> None:
|
||||
"""
|
||||
Called when a trace is started.
|
||||
"""
|
||||
for processor in self._processors:
|
||||
processor.on_trace_start(trace)
|
||||
|
||||
def on_trace_end(self, trace: Trace) -> None:
|
||||
"""
|
||||
Called when a trace is finished.
|
||||
"""
|
||||
for processor in self._processors:
|
||||
processor.on_trace_end(trace)
|
||||
|
||||
def on_span_start(self, span: Span[Any]) -> None:
|
||||
"""
|
||||
Called when a span is started.
|
||||
"""
|
||||
for processor in self._processors:
|
||||
processor.on_span_start(span)
|
||||
|
||||
def on_span_end(self, span: Span[Any]) -> None:
|
||||
"""
|
||||
Called when a span is finished.
|
||||
"""
|
||||
for processor in self._processors:
|
||||
processor.on_span_end(span)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""
|
||||
Called when the application stops.
|
||||
"""
|
||||
for processor in self._processors:
|
||||
logger.debug(f"Shutting down trace processor {processor}")
|
||||
processor.shutdown()
|
||||
|
||||
def force_flush(self):
|
||||
"""
|
||||
Force the processors to flush their buffers.
|
||||
"""
|
||||
for processor in self._processors:
|
||||
processor.force_flush()
|
||||
def set_trace_provider(provider: TraceProvider) -> None:
|
||||
"""Set the global trace provider used by tracing utilities."""
|
||||
global GLOBAL_TRACE_PROVIDER
|
||||
GLOBAL_TRACE_PROVIDER = provider
|
||||
|
||||
|
||||
class TraceProvider:
|
||||
def __init__(self):
|
||||
self._multi_processor = SynchronousMultiTracingProcessor()
|
||||
self._disabled = os.environ.get("OPENAI_AGENTS_DISABLE_TRACING", "false").lower() in (
|
||||
"true",
|
||||
"1",
|
||||
)
|
||||
|
||||
def register_processor(self, processor: TracingProcessor):
|
||||
"""
|
||||
Add a processor to the list of processors. Each processor will receive all traces/spans.
|
||||
"""
|
||||
self._multi_processor.add_tracing_processor(processor)
|
||||
|
||||
def set_processors(self, processors: list[TracingProcessor]):
|
||||
"""
|
||||
Set the list of processors. This will replace the current list of processors.
|
||||
"""
|
||||
self._multi_processor.set_processors(processors)
|
||||
|
||||
def get_current_trace(self) -> Trace | None:
|
||||
"""
|
||||
Returns the currently active trace, if any.
|
||||
"""
|
||||
return Scope.get_current_trace()
|
||||
|
||||
def get_current_span(self) -> Span[Any] | None:
|
||||
"""
|
||||
Returns the currently active span, if any.
|
||||
"""
|
||||
return Scope.get_current_span()
|
||||
|
||||
def set_disabled(self, disabled: bool) -> None:
|
||||
"""
|
||||
Set whether tracing is disabled.
|
||||
"""
|
||||
self._disabled = disabled
|
||||
|
||||
def create_trace(
|
||||
self,
|
||||
name: str,
|
||||
trace_id: str | None = None,
|
||||
group_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
disabled: bool = False,
|
||||
) -> Trace:
|
||||
"""
|
||||
Create a new trace.
|
||||
"""
|
||||
if self._disabled or disabled:
|
||||
logger.debug(f"Tracing is disabled. Not creating trace {name}")
|
||||
return NoOpTrace()
|
||||
|
||||
trace_id = trace_id or util.gen_trace_id()
|
||||
|
||||
logger.debug(f"Creating trace {name} with id {trace_id}")
|
||||
|
||||
return TraceImpl(
|
||||
name=name,
|
||||
trace_id=trace_id,
|
||||
group_id=group_id,
|
||||
metadata=metadata,
|
||||
processor=self._multi_processor,
|
||||
)
|
||||
|
||||
def create_span(
|
||||
self,
|
||||
span_data: TSpanData,
|
||||
span_id: str | None = None,
|
||||
parent: Trace | Span[Any] | None = None,
|
||||
disabled: bool = False,
|
||||
) -> Span[TSpanData]:
|
||||
"""
|
||||
Create a new span.
|
||||
"""
|
||||
if self._disabled or disabled:
|
||||
logger.debug(f"Tracing is disabled. Not creating span {span_data}")
|
||||
return NoOpSpan(span_data)
|
||||
|
||||
if not parent:
|
||||
current_span = Scope.get_current_span()
|
||||
current_trace = Scope.get_current_trace()
|
||||
if current_trace is None:
|
||||
logger.error(
|
||||
"No active trace. Make sure to start a trace with `trace()` first"
|
||||
"Returning NoOpSpan."
|
||||
)
|
||||
return NoOpSpan(span_data)
|
||||
elif isinstance(current_trace, NoOpTrace) or isinstance(current_span, NoOpSpan):
|
||||
logger.debug(
|
||||
f"Parent {current_span} or {current_trace} is no-op, returning NoOpSpan"
|
||||
)
|
||||
return NoOpSpan(span_data)
|
||||
|
||||
parent_id = current_span.span_id if current_span else None
|
||||
trace_id = current_trace.trace_id
|
||||
|
||||
elif isinstance(parent, Trace):
|
||||
if isinstance(parent, NoOpTrace):
|
||||
logger.debug(f"Parent {parent} is no-op, returning NoOpSpan")
|
||||
return NoOpSpan(span_data)
|
||||
trace_id = parent.trace_id
|
||||
parent_id = None
|
||||
elif isinstance(parent, Span):
|
||||
if isinstance(parent, NoOpSpan):
|
||||
logger.debug(f"Parent {parent} is no-op, returning NoOpSpan")
|
||||
return NoOpSpan(span_data)
|
||||
parent_id = parent.span_id
|
||||
trace_id = parent.trace_id
|
||||
|
||||
logger.debug(f"Creating span {span_data} with id {span_id}")
|
||||
|
||||
return SpanImpl(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
parent_id=parent_id,
|
||||
processor=self._multi_processor,
|
||||
span_data=span_data,
|
||||
)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if self._disabled:
|
||||
return
|
||||
|
||||
try:
|
||||
logger.debug("Shutting down trace provider")
|
||||
self._multi_processor.shutdown()
|
||||
except Exception as e:
|
||||
logger.error(f"Error shutting down trace provider: {e}")
|
||||
|
||||
|
||||
GLOBAL_TRACE_PROVIDER = TraceProvider()
|
||||
def get_trace_provider() -> TraceProvider:
|
||||
"""Get the global trace provider used by tracing utilities."""
|
||||
if GLOBAL_TRACE_PROVIDER is None:
|
||||
raise RuntimeError("Trace provider not set")
|
||||
return GLOBAL_TRACE_PROVIDER
|
||||
|
|
|
|||
|
|
@ -1,22 +1,21 @@
|
|||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from .setup import get_trace_provider
|
||||
|
||||
|
||||
def time_iso() -> str:
|
||||
"""Returns the current time in ISO 8601 format."""
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
"""Return the current time in ISO 8601 format."""
|
||||
return get_trace_provider().time_iso()
|
||||
|
||||
|
||||
def gen_trace_id() -> str:
|
||||
"""Generates a new trace ID."""
|
||||
return f"trace_{uuid.uuid4().hex}"
|
||||
"""Generate a new trace ID."""
|
||||
return get_trace_provider().gen_trace_id()
|
||||
|
||||
|
||||
def gen_span_id() -> str:
|
||||
"""Generates a new span ID."""
|
||||
return f"span_{uuid.uuid4().hex[:24]}"
|
||||
"""Generate a new span ID."""
|
||||
return get_trace_provider().gen_span_id()
|
||||
|
||||
|
||||
def gen_group_id() -> str:
|
||||
"""Generates a new group ID."""
|
||||
return f"group_{uuid.uuid4().hex[:24]}"
|
||||
"""Generate a new group ID."""
|
||||
return get_trace_provider().gen_group_id()
|
||||
|
|
|
|||
|
|
@ -5,8 +5,9 @@ import pytest
|
|||
from agents.models import _openai_shared
|
||||
from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel
|
||||
from agents.models.openai_responses import OpenAIResponsesModel
|
||||
from agents.run import set_default_agent_runner
|
||||
from agents.tracing import set_trace_processors
|
||||
from agents.tracing.setup import GLOBAL_TRACE_PROVIDER
|
||||
from agents.tracing.setup import get_trace_provider
|
||||
|
||||
from .testing_processor import SPAN_PROCESSOR_TESTING
|
||||
|
||||
|
|
@ -33,11 +34,16 @@ def clear_openai_settings():
|
|||
_openai_shared._use_responses_by_default = True
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_default_runner():
|
||||
set_default_agent_runner(None)
|
||||
|
||||
|
||||
# This fixture will run after all tests end
|
||||
@pytest.fixture(autouse=True, scope="session")
|
||||
def shutdown_trace_provider():
|
||||
yield
|
||||
GLOBAL_TRACE_PROVIDER.shutdown()
|
||||
get_trace_provider().shutdown()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
|
|
|||
|
|
@ -1,20 +1,21 @@
|
|||
from agents import Agent, OpenAIResponsesModel, RunConfig, Runner
|
||||
from agents import Agent, OpenAIResponsesModel, RunConfig
|
||||
from agents.extensions.models.litellm_model import LitellmModel
|
||||
from agents.run import AgentRunner
|
||||
|
||||
|
||||
def test_no_prefix_is_openai():
|
||||
agent = Agent(model="gpt-4o", instructions="", name="test")
|
||||
model = Runner._get_model(agent, RunConfig())
|
||||
model = AgentRunner._get_model(agent, RunConfig())
|
||||
assert isinstance(model, OpenAIResponsesModel)
|
||||
|
||||
|
||||
def openai_prefix_is_openai():
|
||||
agent = Agent(model="openai/gpt-4o", instructions="", name="test")
|
||||
model = Runner._get_model(agent, RunConfig())
|
||||
model = AgentRunner._get_model(agent, RunConfig())
|
||||
assert isinstance(model, OpenAIResponsesModel)
|
||||
|
||||
|
||||
def test_litellm_prefix_is_litellm():
|
||||
agent = Agent(model="litellm/foo/bar", instructions="", name="test")
|
||||
model = Runner._get_model(agent, RunConfig())
|
||||
model = AgentRunner._get_model(agent, RunConfig())
|
||||
assert isinstance(model, LitellmModel)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agents import Agent, AgentOutputSchema, Handoff, RunContextWrapper, Runner, handoff
|
||||
from agents import Agent, AgentOutputSchema, Handoff, RunContextWrapper, handoff
|
||||
from agents.run import AgentRunner
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -42,7 +43,7 @@ async def test_handoff_with_agents():
|
|||
handoffs=[agent_1, agent_2],
|
||||
)
|
||||
|
||||
handoffs = Runner._get_handoffs(agent_3)
|
||||
handoffs = AgentRunner._get_handoffs(agent_3)
|
||||
assert len(handoffs) == 2
|
||||
|
||||
assert handoffs[0].agent_name == "agent_1"
|
||||
|
|
@ -77,7 +78,7 @@ async def test_handoff_with_handoff_obj():
|
|||
],
|
||||
)
|
||||
|
||||
handoffs = Runner._get_handoffs(agent_3)
|
||||
handoffs = AgentRunner._get_handoffs(agent_3)
|
||||
assert len(handoffs) == 2
|
||||
|
||||
assert handoffs[0].agent_name == "agent_1"
|
||||
|
|
@ -111,7 +112,7 @@ async def test_handoff_with_handoff_obj_and_agent():
|
|||
handoffs=[handoff(agent_1), agent_2],
|
||||
)
|
||||
|
||||
handoffs = Runner._get_handoffs(agent_3)
|
||||
handoffs = AgentRunner._get_handoffs(agent_3)
|
||||
assert len(handoffs) == 2
|
||||
|
||||
assert handoffs[0].agent_name == "agent_1"
|
||||
|
|
@ -159,7 +160,7 @@ async def test_agent_final_output():
|
|||
output_type=Foo,
|
||||
)
|
||||
|
||||
schema = Runner._get_output_schema(agent)
|
||||
schema = AgentRunner._get_output_schema(agent)
|
||||
assert isinstance(schema, AgentOutputSchema)
|
||||
assert schema is not None
|
||||
assert schema.output_type == Foo
|
||||
|
|
|
|||
|
|
@ -12,10 +12,10 @@ from agents import (
|
|||
MessageOutputItem,
|
||||
ModelBehaviorError,
|
||||
RunContextWrapper,
|
||||
Runner,
|
||||
UserError,
|
||||
handoff,
|
||||
)
|
||||
from agents.run import AgentRunner
|
||||
|
||||
|
||||
def message_item(content: str, agent: Agent[Any]) -> MessageOutputItem:
|
||||
|
|
@ -45,9 +45,9 @@ def test_single_handoff_setup():
|
|||
assert not agent_1.handoffs
|
||||
assert agent_2.handoffs == [agent_1]
|
||||
|
||||
assert not Runner._get_handoffs(agent_1)
|
||||
assert not AgentRunner._get_handoffs(agent_1)
|
||||
|
||||
handoff_objects = Runner._get_handoffs(agent_2)
|
||||
handoff_objects = AgentRunner._get_handoffs(agent_2)
|
||||
assert len(handoff_objects) == 1
|
||||
obj = handoff_objects[0]
|
||||
assert obj.tool_name == Handoff.default_tool_name(agent_1)
|
||||
|
|
@ -64,7 +64,7 @@ def test_multiple_handoffs_setup():
|
|||
assert not agent_1.handoffs
|
||||
assert not agent_2.handoffs
|
||||
|
||||
handoff_objects = Runner._get_handoffs(agent_3)
|
||||
handoff_objects = AgentRunner._get_handoffs(agent_3)
|
||||
assert len(handoff_objects) == 2
|
||||
assert handoff_objects[0].tool_name == Handoff.default_tool_name(agent_1)
|
||||
assert handoff_objects[1].tool_name == Handoff.default_tool_name(agent_2)
|
||||
|
|
@ -95,7 +95,7 @@ def test_custom_handoff_setup():
|
|||
assert not agent_1.handoffs
|
||||
assert not agent_2.handoffs
|
||||
|
||||
handoff_objects = Runner._get_handoffs(agent_3)
|
||||
handoff_objects = AgentRunner._get_handoffs(agent_3)
|
||||
assert len(handoff_objects) == 2
|
||||
|
||||
first_handoff = handoff_objects[0]
|
||||
|
|
|
|||
|
|
@ -10,16 +10,16 @@ from agents import (
|
|||
AgentOutputSchema,
|
||||
AgentOutputSchemaBase,
|
||||
ModelBehaviorError,
|
||||
Runner,
|
||||
UserError,
|
||||
)
|
||||
from agents.agent_output import _WRAPPER_DICT_KEY
|
||||
from agents.run import AgentRunner
|
||||
from agents.util import _json
|
||||
|
||||
|
||||
def test_plain_text_output():
|
||||
agent = Agent(name="test")
|
||||
output_schema = Runner._get_output_schema(agent)
|
||||
output_schema = AgentRunner._get_output_schema(agent)
|
||||
assert not output_schema, "Shouldn't have an output tool config without an output type"
|
||||
|
||||
agent = Agent(name="test", output_type=str)
|
||||
|
|
@ -32,7 +32,7 @@ class Foo(BaseModel):
|
|||
|
||||
def test_structured_output_pydantic():
|
||||
agent = Agent(name="test", output_type=Foo)
|
||||
output_schema = Runner._get_output_schema(agent)
|
||||
output_schema = AgentRunner._get_output_schema(agent)
|
||||
assert output_schema, "Should have an output tool config with a structured output type"
|
||||
|
||||
assert isinstance(output_schema, AgentOutputSchema)
|
||||
|
|
@ -52,7 +52,7 @@ class Bar(TypedDict):
|
|||
|
||||
def test_structured_output_typed_dict():
|
||||
agent = Agent(name="test", output_type=Bar)
|
||||
output_schema = Runner._get_output_schema(agent)
|
||||
output_schema = AgentRunner._get_output_schema(agent)
|
||||
assert output_schema, "Should have an output tool config with a structured output type"
|
||||
assert isinstance(output_schema, AgentOutputSchema)
|
||||
assert output_schema.output_type == Bar, "Should have the correct output type"
|
||||
|
|
@ -65,7 +65,7 @@ def test_structured_output_typed_dict():
|
|||
|
||||
def test_structured_output_list():
|
||||
agent = Agent(name="test", output_type=list[str])
|
||||
output_schema = Runner._get_output_schema(agent)
|
||||
output_schema = AgentRunner._get_output_schema(agent)
|
||||
assert output_schema, "Should have an output tool config with a structured output type"
|
||||
assert isinstance(output_schema, AgentOutputSchema)
|
||||
assert output_schema.output_type == list[str], "Should have the correct output type"
|
||||
|
|
@ -79,14 +79,14 @@ def test_structured_output_list():
|
|||
|
||||
def test_bad_json_raises_error(mocker):
|
||||
agent = Agent(name="test", output_type=Foo)
|
||||
output_schema = Runner._get_output_schema(agent)
|
||||
output_schema = AgentRunner._get_output_schema(agent)
|
||||
assert output_schema, "Should have an output tool config with a structured output type"
|
||||
|
||||
with pytest.raises(ModelBehaviorError):
|
||||
output_schema.validate_json("not valid json")
|
||||
|
||||
agent = Agent(name="test", output_type=list[str])
|
||||
output_schema = Runner._get_output_schema(agent)
|
||||
output_schema = AgentRunner._get_output_schema(agent)
|
||||
assert output_schema, "Should have an output tool config with a structured output type"
|
||||
|
||||
mock_validate_json = mocker.patch.object(_json, "validate_json")
|
||||
|
|
@ -155,7 +155,7 @@ class CustomOutputSchema(AgentOutputSchemaBase):
|
|||
def test_custom_output_schema():
|
||||
custom_output_schema = CustomOutputSchema()
|
||||
agent = Agent(name="test", output_type=custom_output_schema)
|
||||
output_schema = Runner._get_output_schema(agent)
|
||||
output_schema = AgentRunner._get_output_schema(agent)
|
||||
|
||||
assert output_schema, "Should have an output tool config with a structured output type"
|
||||
assert isinstance(output_schema, CustomOutputSchema)
|
||||
|
|
|
|||
26
tests/test_run.py
Normal file
26
tests/test_run.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from agents import Agent, Runner
|
||||
from agents.run import AgentRunner, set_default_agent_runner
|
||||
|
||||
from .fake_model import FakeModel
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_static_run_methods_call_into_default_runner() -> None:
|
||||
runner = mock.Mock(spec=AgentRunner)
|
||||
set_default_agent_runner(runner)
|
||||
|
||||
agent = Agent(name="test", model=FakeModel())
|
||||
await Runner.run(agent, input="test")
|
||||
runner.run.assert_called_once()
|
||||
|
||||
Runner.run_streamed(agent, input="test")
|
||||
runner.run_streamed.assert_called_once()
|
||||
|
||||
Runner.run_sync(agent, input="test")
|
||||
runner.run_sync.assert_called_once()
|
||||
|
|
@ -60,7 +60,7 @@ async def test_run_config_model_name_override_takes_precedence() -> None:
|
|||
async def test_run_config_model_override_object_takes_precedence() -> None:
|
||||
"""
|
||||
When a concrete Model instance is set on the RunConfig, then that instance should be
|
||||
returned by Runner._get_model regardless of the agent's model.
|
||||
returned by AgentRunner._get_model regardless of the agent's model.
|
||||
"""
|
||||
fake_model = FakeModel(initial_output=[get_text_message("override-object")])
|
||||
agent = Agent(name="test", model="agent-model")
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ from agents import (
|
|||
RunContextWrapper,
|
||||
RunHooks,
|
||||
RunItem,
|
||||
Runner,
|
||||
ToolCallItem,
|
||||
ToolCallOutputItem,
|
||||
TResponseInputItem,
|
||||
|
|
@ -27,6 +26,7 @@ from agents._run_impl import (
|
|||
RunImpl,
|
||||
SingleStepResult,
|
||||
)
|
||||
from agents.run import AgentRunner
|
||||
from agents.tool import function_tool
|
||||
from agents.tool_context import ToolContext
|
||||
|
||||
|
|
@ -324,8 +324,8 @@ async def get_execute_result(
|
|||
context_wrapper: RunContextWrapper[Any] | None = None,
|
||||
run_config: RunConfig | None = None,
|
||||
) -> SingleStepResult:
|
||||
output_schema = Runner._get_output_schema(agent)
|
||||
handoffs = Runner._get_handoffs(agent)
|
||||
output_schema = AgentRunner._get_output_schema(agent)
|
||||
handoffs = AgentRunner._get_handoffs(agent)
|
||||
|
||||
processed_response = RunImpl.process_model_response(
|
||||
agent=agent,
|
||||
|
|
|
|||
|
|
@ -19,11 +19,11 @@ from agents import (
|
|||
ModelResponse,
|
||||
ReasoningItem,
|
||||
RunContextWrapper,
|
||||
Runner,
|
||||
ToolCallItem,
|
||||
Usage,
|
||||
)
|
||||
from agents._run_impl import RunImpl
|
||||
from agents.run import AgentRunner
|
||||
|
||||
from .test_responses import (
|
||||
get_final_output_message,
|
||||
|
|
@ -186,7 +186,7 @@ async def test_handoffs_parsed_correctly():
|
|||
agent=agent_3,
|
||||
response=response,
|
||||
output_schema=None,
|
||||
handoffs=Runner._get_handoffs(agent_3),
|
||||
handoffs=AgentRunner._get_handoffs(agent_3),
|
||||
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
|
||||
)
|
||||
assert len(result.handoffs) == 1, "Should have a handoff here"
|
||||
|
|
@ -216,7 +216,7 @@ async def test_missing_handoff_fails():
|
|||
agent=agent_3,
|
||||
response=response,
|
||||
output_schema=None,
|
||||
handoffs=Runner._get_handoffs(agent_3),
|
||||
handoffs=AgentRunner._get_handoffs(agent_3),
|
||||
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
|
||||
)
|
||||
|
||||
|
|
@ -239,7 +239,7 @@ async def test_multiple_handoffs_doesnt_error():
|
|||
agent=agent_3,
|
||||
response=response,
|
||||
output_schema=None,
|
||||
handoffs=Runner._get_handoffs(agent_3),
|
||||
handoffs=AgentRunner._get_handoffs(agent_3),
|
||||
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
|
||||
)
|
||||
assert len(result.handoffs) == 2, "Should have multiple handoffs here"
|
||||
|
|
@ -264,7 +264,7 @@ async def test_final_output_parsed_correctly():
|
|||
RunImpl.process_model_response(
|
||||
agent=agent,
|
||||
response=response,
|
||||
output_schema=Runner._get_output_schema(agent),
|
||||
output_schema=AgentRunner._get_output_schema(agent),
|
||||
handoffs=[],
|
||||
all_tools=await agent.get_all_tools(_dummy_ctx()),
|
||||
)
|
||||
|
|
@ -471,7 +471,7 @@ async def test_tool_and_handoff_parsed_correctly():
|
|||
agent=agent_3,
|
||||
response=response,
|
||||
output_schema=None,
|
||||
handoffs=Runner._get_handoffs(agent_3),
|
||||
handoffs=AgentRunner._get_handoffs(agent_3),
|
||||
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
|
||||
)
|
||||
assert result.functions and len(result.functions) == 1
|
||||
|
|
|
|||
|
|
@ -19,11 +19,12 @@ from agents.items import (
|
|||
TResponseStreamEvent,
|
||||
)
|
||||
|
||||
from ..fake_model import get_response_obj
|
||||
from ..test_responses import get_function_tool, get_function_tool_call, get_text_message
|
||||
|
||||
try:
|
||||
from agents.voice import SingleAgentVoiceWorkflow
|
||||
|
||||
from ..fake_model import get_response_obj
|
||||
from ..test_responses import get_function_tool, get_function_tool_call, get_text_message
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue