Allow replacing AgentRunner and TraceProvider (#720)
This commit is contained in:
parent
901d2ac57c
commit
0cf503e1c2
18 changed files with 2290 additions and 2036 deletions
|
|
@ -104,6 +104,7 @@ from .tracing import (
|
||||||
handoff_span,
|
handoff_span,
|
||||||
mcp_tools_span,
|
mcp_tools_span,
|
||||||
set_trace_processors,
|
set_trace_processors,
|
||||||
|
set_trace_provider,
|
||||||
set_tracing_disabled,
|
set_tracing_disabled,
|
||||||
set_tracing_export_api_key,
|
set_tracing_export_api_key,
|
||||||
speech_group_span,
|
speech_group_span,
|
||||||
|
|
@ -246,6 +247,7 @@ __all__ = [
|
||||||
"guardrail_span",
|
"guardrail_span",
|
||||||
"handoff_span",
|
"handoff_span",
|
||||||
"set_trace_processors",
|
"set_trace_processors",
|
||||||
|
"set_trace_provider",
|
||||||
"set_tracing_disabled",
|
"set_tracing_disabled",
|
||||||
"speech_group_span",
|
"speech_group_span",
|
||||||
"transcription_span",
|
"transcription_span",
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,13 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, cast
|
from typing import Any, Generic, cast
|
||||||
|
|
||||||
from openai.types.responses import ResponseCompletedEvent
|
from openai.types.responses import ResponseCompletedEvent
|
||||||
from openai.types.responses.response_prompt_param import (
|
from openai.types.responses.response_prompt_param import (
|
||||||
ResponsePromptParam,
|
ResponsePromptParam,
|
||||||
)
|
)
|
||||||
|
from typing_extensions import NotRequired, TypedDict, Unpack
|
||||||
|
|
||||||
from ._run_impl import (
|
from ._run_impl import (
|
||||||
AgentToolUseTracker,
|
AgentToolUseTracker,
|
||||||
|
|
@ -31,7 +32,12 @@ from .exceptions import (
|
||||||
OutputGuardrailTripwireTriggered,
|
OutputGuardrailTripwireTriggered,
|
||||||
RunErrorDetails,
|
RunErrorDetails,
|
||||||
)
|
)
|
||||||
from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
|
from .guardrail import (
|
||||||
|
InputGuardrail,
|
||||||
|
InputGuardrailResult,
|
||||||
|
OutputGuardrail,
|
||||||
|
OutputGuardrailResult,
|
||||||
|
)
|
||||||
from .handoffs import Handoff, HandoffInputFilter, handoff
|
from .handoffs import Handoff, HandoffInputFilter, handoff
|
||||||
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
|
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
|
||||||
from .lifecycle import RunHooks
|
from .lifecycle import RunHooks
|
||||||
|
|
@ -50,6 +56,27 @@ from .util import _coro, _error_tracing
|
||||||
|
|
||||||
DEFAULT_MAX_TURNS = 10
|
DEFAULT_MAX_TURNS = 10
|
||||||
|
|
||||||
|
DEFAULT_AGENT_RUNNER: AgentRunner = None # type: ignore
|
||||||
|
# the value is set at the end of the module
|
||||||
|
|
||||||
|
|
||||||
|
def set_default_agent_runner(runner: AgentRunner | None) -> None:
|
||||||
|
"""
|
||||||
|
WARNING: this class is experimental and not part of the public API
|
||||||
|
It should not be used directly.
|
||||||
|
"""
|
||||||
|
global DEFAULT_AGENT_RUNNER
|
||||||
|
DEFAULT_AGENT_RUNNER = runner or AgentRunner()
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_agent_runner() -> AgentRunner:
|
||||||
|
"""
|
||||||
|
WARNING: this class is experimental and not part of the public API
|
||||||
|
It should not be used directly.
|
||||||
|
"""
|
||||||
|
global DEFAULT_AGENT_RUNNER
|
||||||
|
return DEFAULT_AGENT_RUNNER
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RunConfig:
|
class RunConfig:
|
||||||
|
|
@ -110,6 +137,25 @@ class RunConfig:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class RunOptions(TypedDict, Generic[TContext]):
|
||||||
|
"""Arguments for ``AgentRunner`` methods."""
|
||||||
|
|
||||||
|
context: NotRequired[TContext | None]
|
||||||
|
"""The context for the run."""
|
||||||
|
|
||||||
|
max_turns: NotRequired[int]
|
||||||
|
"""The maximum number of turns to run for."""
|
||||||
|
|
||||||
|
hooks: NotRequired[RunHooks[TContext] | None]
|
||||||
|
"""Lifecycle hooks for the run."""
|
||||||
|
|
||||||
|
run_config: NotRequired[RunConfig | None]
|
||||||
|
"""Run configuration."""
|
||||||
|
|
||||||
|
previous_response_id: NotRequired[str | None]
|
||||||
|
"""The ID of the previous response, if any."""
|
||||||
|
|
||||||
|
|
||||||
class Runner:
|
class Runner:
|
||||||
@classmethod
|
@classmethod
|
||||||
async def run(
|
async def run(
|
||||||
|
|
@ -130,13 +176,10 @@ class Runner:
|
||||||
`agent.output_type`, the loop terminates.
|
`agent.output_type`, the loop terminates.
|
||||||
3. If there's a handoff, we run the loop again, with the new agent.
|
3. If there's a handoff, we run the loop again, with the new agent.
|
||||||
4. Else, we run tool calls (if any), and re-run the loop.
|
4. Else, we run tool calls (if any), and re-run the loop.
|
||||||
|
|
||||||
In two cases, the agent may raise an exception:
|
In two cases, the agent may raise an exception:
|
||||||
1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
|
1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
|
||||||
2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
|
2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
|
||||||
|
|
||||||
Note that only the first agent's input guardrails are run.
|
Note that only the first agent's input guardrails are run.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
starting_agent: The starting agent to run.
|
starting_agent: The starting agent to run.
|
||||||
input: The initial input to the agent. You can pass a single string for a user message,
|
input: The initial input to the agent. You can pass a single string for a user message,
|
||||||
|
|
@ -148,11 +191,139 @@ class Runner:
|
||||||
run_config: Global settings for the entire agent run.
|
run_config: Global settings for the entire agent run.
|
||||||
previous_response_id: The ID of the previous response, if using OpenAI models via the
|
previous_response_id: The ID of the previous response, if using OpenAI models via the
|
||||||
Responses API, this allows you to skip passing in input from the previous turn.
|
Responses API, this allows you to skip passing in input from the previous turn.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A run result containing all the inputs, guardrail results and the output of the last
|
A run result containing all the inputs, guardrail results and the output of the last
|
||||||
agent. Agents may perform handoffs, so we don't know the specific type of the output.
|
agent. Agents may perform handoffs, so we don't know the specific type of the output.
|
||||||
"""
|
"""
|
||||||
|
runner = DEFAULT_AGENT_RUNNER
|
||||||
|
return await runner.run(
|
||||||
|
starting_agent,
|
||||||
|
input,
|
||||||
|
context=context,
|
||||||
|
max_turns=max_turns,
|
||||||
|
hooks=hooks,
|
||||||
|
run_config=run_config,
|
||||||
|
previous_response_id=previous_response_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def run_sync(
|
||||||
|
cls,
|
||||||
|
starting_agent: Agent[TContext],
|
||||||
|
input: str | list[TResponseInputItem],
|
||||||
|
*,
|
||||||
|
context: TContext | None = None,
|
||||||
|
max_turns: int = DEFAULT_MAX_TURNS,
|
||||||
|
hooks: RunHooks[TContext] | None = None,
|
||||||
|
run_config: RunConfig | None = None,
|
||||||
|
previous_response_id: str | None = None,
|
||||||
|
) -> RunResult:
|
||||||
|
"""Run a workflow synchronously, starting at the given agent. Note that this just wraps the
|
||||||
|
`run` method, so it will not work if there's already an event loop (e.g. inside an async
|
||||||
|
function, or in a Jupyter notebook or async context like FastAPI). For those cases, use
|
||||||
|
the `run` method instead.
|
||||||
|
The agent will run in a loop until a final output is generated. The loop runs like so:
|
||||||
|
1. The agent is invoked with the given input.
|
||||||
|
2. If there is a final output (i.e. the agent produces something of type
|
||||||
|
`agent.output_type`, the loop terminates.
|
||||||
|
3. If there's a handoff, we run the loop again, with the new agent.
|
||||||
|
4. Else, we run tool calls (if any), and re-run the loop.
|
||||||
|
In two cases, the agent may raise an exception:
|
||||||
|
1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
|
||||||
|
2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
|
||||||
|
Note that only the first agent's input guardrails are run.
|
||||||
|
Args:
|
||||||
|
starting_agent: The starting agent to run.
|
||||||
|
input: The initial input to the agent. You can pass a single string for a user message,
|
||||||
|
or a list of input items.
|
||||||
|
context: The context to run the agent with.
|
||||||
|
max_turns: The maximum number of turns to run the agent for. A turn is defined as one
|
||||||
|
AI invocation (including any tool calls that might occur).
|
||||||
|
hooks: An object that receives callbacks on various lifecycle events.
|
||||||
|
run_config: Global settings for the entire agent run.
|
||||||
|
previous_response_id: The ID of the previous response, if using OpenAI models via the
|
||||||
|
Responses API, this allows you to skip passing in input from the previous turn.
|
||||||
|
Returns:
|
||||||
|
A run result containing all the inputs, guardrail results and the output of the last
|
||||||
|
agent. Agents may perform handoffs, so we don't know the specific type of the output.
|
||||||
|
"""
|
||||||
|
runner = DEFAULT_AGENT_RUNNER
|
||||||
|
return runner.run_sync(
|
||||||
|
starting_agent,
|
||||||
|
input,
|
||||||
|
context=context,
|
||||||
|
max_turns=max_turns,
|
||||||
|
hooks=hooks,
|
||||||
|
run_config=run_config,
|
||||||
|
previous_response_id=previous_response_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def run_streamed(
|
||||||
|
cls,
|
||||||
|
starting_agent: Agent[TContext],
|
||||||
|
input: str | list[TResponseInputItem],
|
||||||
|
context: TContext | None = None,
|
||||||
|
max_turns: int = DEFAULT_MAX_TURNS,
|
||||||
|
hooks: RunHooks[TContext] | None = None,
|
||||||
|
run_config: RunConfig | None = None,
|
||||||
|
previous_response_id: str | None = None,
|
||||||
|
) -> RunResultStreaming:
|
||||||
|
"""Run a workflow starting at the given agent in streaming mode. The returned result object
|
||||||
|
contains a method you can use to stream semantic events as they are generated.
|
||||||
|
The agent will run in a loop until a final output is generated. The loop runs like so:
|
||||||
|
1. The agent is invoked with the given input.
|
||||||
|
2. If there is a final output (i.e. the agent produces something of type
|
||||||
|
`agent.output_type`, the loop terminates.
|
||||||
|
3. If there's a handoff, we run the loop again, with the new agent.
|
||||||
|
4. Else, we run tool calls (if any), and re-run the loop.
|
||||||
|
In two cases, the agent may raise an exception:
|
||||||
|
1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
|
||||||
|
2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
|
||||||
|
Note that only the first agent's input guardrails are run.
|
||||||
|
Args:
|
||||||
|
starting_agent: The starting agent to run.
|
||||||
|
input: The initial input to the agent. You can pass a single string for a user message,
|
||||||
|
or a list of input items.
|
||||||
|
context: The context to run the agent with.
|
||||||
|
max_turns: The maximum number of turns to run the agent for. A turn is defined as one
|
||||||
|
AI invocation (including any tool calls that might occur).
|
||||||
|
hooks: An object that receives callbacks on various lifecycle events.
|
||||||
|
run_config: Global settings for the entire agent run.
|
||||||
|
previous_response_id: The ID of the previous response, if using OpenAI models via the
|
||||||
|
Responses API, this allows you to skip passing in input from the previous turn.
|
||||||
|
Returns:
|
||||||
|
A result object that contains data about the run, as well as a method to stream events.
|
||||||
|
"""
|
||||||
|
runner = DEFAULT_AGENT_RUNNER
|
||||||
|
return runner.run_streamed(
|
||||||
|
starting_agent,
|
||||||
|
input,
|
||||||
|
context=context,
|
||||||
|
max_turns=max_turns,
|
||||||
|
hooks=hooks,
|
||||||
|
run_config=run_config,
|
||||||
|
previous_response_id=previous_response_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentRunner:
|
||||||
|
"""
|
||||||
|
WARNING: this class is experimental and not part of the public API
|
||||||
|
It should not be used directly or subclassed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
starting_agent: Agent[TContext],
|
||||||
|
input: str | list[TResponseInputItem],
|
||||||
|
**kwargs: Unpack[RunOptions[TContext]],
|
||||||
|
) -> RunResult:
|
||||||
|
context = kwargs.get("context")
|
||||||
|
max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
|
||||||
|
hooks = kwargs.get("hooks")
|
||||||
|
run_config = kwargs.get("run_config")
|
||||||
|
previous_response_id = kwargs.get("previous_response_id")
|
||||||
if hooks is None:
|
if hooks is None:
|
||||||
hooks = RunHooks[Any]()
|
hooks = RunHooks[Any]()
|
||||||
if run_config is None:
|
if run_config is None:
|
||||||
|
|
@ -184,13 +355,15 @@ class Runner:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
all_tools = await cls._get_all_tools(current_agent, context_wrapper)
|
all_tools = await AgentRunner._get_all_tools(current_agent, context_wrapper)
|
||||||
|
|
||||||
# Start an agent span if we don't have one. This span is ended if the current
|
# Start an agent span if we don't have one. This span is ended if the current
|
||||||
# agent changes, or if the agent loop ends.
|
# agent changes, or if the agent loop ends.
|
||||||
if current_span is None:
|
if current_span is None:
|
||||||
handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)]
|
handoff_names = [
|
||||||
if output_schema := cls._get_output_schema(current_agent):
|
h.agent_name for h in AgentRunner._get_handoffs(current_agent)
|
||||||
|
]
|
||||||
|
if output_schema := AgentRunner._get_output_schema(current_agent):
|
||||||
output_type_name = output_schema.name()
|
output_type_name = output_schema.name()
|
||||||
else:
|
else:
|
||||||
output_type_name = "str"
|
output_type_name = "str"
|
||||||
|
|
@ -220,14 +393,14 @@ class Runner:
|
||||||
|
|
||||||
if current_turn == 1:
|
if current_turn == 1:
|
||||||
input_guardrail_results, turn_result = await asyncio.gather(
|
input_guardrail_results, turn_result = await asyncio.gather(
|
||||||
cls._run_input_guardrails(
|
self._run_input_guardrails(
|
||||||
starting_agent,
|
starting_agent,
|
||||||
starting_agent.input_guardrails
|
starting_agent.input_guardrails
|
||||||
+ (run_config.input_guardrails or []),
|
+ (run_config.input_guardrails or []),
|
||||||
copy.deepcopy(input),
|
copy.deepcopy(input),
|
||||||
context_wrapper,
|
context_wrapper,
|
||||||
),
|
),
|
||||||
cls._run_single_turn(
|
self._run_single_turn(
|
||||||
agent=current_agent,
|
agent=current_agent,
|
||||||
all_tools=all_tools,
|
all_tools=all_tools,
|
||||||
original_input=original_input,
|
original_input=original_input,
|
||||||
|
|
@ -241,7 +414,7 @@ class Runner:
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
turn_result = await cls._run_single_turn(
|
turn_result = await self._run_single_turn(
|
||||||
agent=current_agent,
|
agent=current_agent,
|
||||||
all_tools=all_tools,
|
all_tools=all_tools,
|
||||||
original_input=original_input,
|
original_input=original_input,
|
||||||
|
|
@ -260,7 +433,7 @@ class Runner:
|
||||||
generated_items = turn_result.generated_items
|
generated_items = turn_result.generated_items
|
||||||
|
|
||||||
if isinstance(turn_result.next_step, NextStepFinalOutput):
|
if isinstance(turn_result.next_step, NextStepFinalOutput):
|
||||||
output_guardrail_results = await cls._run_output_guardrails(
|
output_guardrail_results = await self._run_output_guardrails(
|
||||||
current_agent.output_guardrails + (run_config.output_guardrails or []),
|
current_agent.output_guardrails + (run_config.output_guardrails or []),
|
||||||
current_agent,
|
current_agent,
|
||||||
turn_result.next_step.output,
|
turn_result.next_step.output,
|
||||||
|
|
@ -302,54 +475,19 @@ class Runner:
|
||||||
if current_span:
|
if current_span:
|
||||||
current_span.finish(reset_current=True)
|
current_span.finish(reset_current=True)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def run_sync(
|
def run_sync(
|
||||||
cls,
|
self,
|
||||||
starting_agent: Agent[TContext],
|
starting_agent: Agent[TContext],
|
||||||
input: str | list[TResponseInputItem],
|
input: str | list[TResponseInputItem],
|
||||||
*,
|
**kwargs: Unpack[RunOptions[TContext]],
|
||||||
context: TContext | None = None,
|
|
||||||
max_turns: int = DEFAULT_MAX_TURNS,
|
|
||||||
hooks: RunHooks[TContext] | None = None,
|
|
||||||
run_config: RunConfig | None = None,
|
|
||||||
previous_response_id: str | None = None,
|
|
||||||
) -> RunResult:
|
) -> RunResult:
|
||||||
"""Run a workflow synchronously, starting at the given agent. Note that this just wraps the
|
context = kwargs.get("context")
|
||||||
`run` method, so it will not work if there's already an event loop (e.g. inside an async
|
max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
|
||||||
function, or in a Jupyter notebook or async context like FastAPI). For those cases, use
|
hooks = kwargs.get("hooks")
|
||||||
the `run` method instead.
|
run_config = kwargs.get("run_config")
|
||||||
|
previous_response_id = kwargs.get("previous_response_id")
|
||||||
The agent will run in a loop until a final output is generated. The loop runs like so:
|
|
||||||
1. The agent is invoked with the given input.
|
|
||||||
2. If there is a final output (i.e. the agent produces something of type
|
|
||||||
`agent.output_type`, the loop terminates.
|
|
||||||
3. If there's a handoff, we run the loop again, with the new agent.
|
|
||||||
4. Else, we run tool calls (if any), and re-run the loop.
|
|
||||||
|
|
||||||
In two cases, the agent may raise an exception:
|
|
||||||
1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
|
|
||||||
2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
|
|
||||||
|
|
||||||
Note that only the first agent's input guardrails are run.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
starting_agent: The starting agent to run.
|
|
||||||
input: The initial input to the agent. You can pass a single string for a user message,
|
|
||||||
or a list of input items.
|
|
||||||
context: The context to run the agent with.
|
|
||||||
max_turns: The maximum number of turns to run the agent for. A turn is defined as one
|
|
||||||
AI invocation (including any tool calls that might occur).
|
|
||||||
hooks: An object that receives callbacks on various lifecycle events.
|
|
||||||
run_config: Global settings for the entire agent run.
|
|
||||||
previous_response_id: The ID of the previous response, if using OpenAI models via the
|
|
||||||
Responses API, this allows you to skip passing in input from the previous turn.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A run result containing all the inputs, guardrail results and the output of the last
|
|
||||||
agent. Agents may perform handoffs, so we don't know the specific type of the output.
|
|
||||||
"""
|
|
||||||
return asyncio.get_event_loop().run_until_complete(
|
return asyncio.get_event_loop().run_until_complete(
|
||||||
cls.run(
|
self.run(
|
||||||
starting_agent,
|
starting_agent,
|
||||||
input,
|
input,
|
||||||
context=context,
|
context=context,
|
||||||
|
|
@ -360,47 +498,17 @@ class Runner:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def run_streamed(
|
def run_streamed(
|
||||||
cls,
|
self,
|
||||||
starting_agent: Agent[TContext],
|
starting_agent: Agent[TContext],
|
||||||
input: str | list[TResponseInputItem],
|
input: str | list[TResponseInputItem],
|
||||||
context: TContext | None = None,
|
**kwargs: Unpack[RunOptions[TContext]],
|
||||||
max_turns: int = DEFAULT_MAX_TURNS,
|
|
||||||
hooks: RunHooks[TContext] | None = None,
|
|
||||||
run_config: RunConfig | None = None,
|
|
||||||
previous_response_id: str | None = None,
|
|
||||||
) -> RunResultStreaming:
|
) -> RunResultStreaming:
|
||||||
"""Run a workflow starting at the given agent in streaming mode. The returned result object
|
context = kwargs.get("context")
|
||||||
contains a method you can use to stream semantic events as they are generated.
|
max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
|
||||||
|
hooks = kwargs.get("hooks")
|
||||||
The agent will run in a loop until a final output is generated. The loop runs like so:
|
run_config = kwargs.get("run_config")
|
||||||
1. The agent is invoked with the given input.
|
previous_response_id = kwargs.get("previous_response_id")
|
||||||
2. If there is a final output (i.e. the agent produces something of type
|
|
||||||
`agent.output_type`, the loop terminates.
|
|
||||||
3. If there's a handoff, we run the loop again, with the new agent.
|
|
||||||
4. Else, we run tool calls (if any), and re-run the loop.
|
|
||||||
|
|
||||||
In two cases, the agent may raise an exception:
|
|
||||||
1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
|
|
||||||
2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
|
|
||||||
|
|
||||||
Note that only the first agent's input guardrails are run.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
starting_agent: The starting agent to run.
|
|
||||||
input: The initial input to the agent. You can pass a single string for a user message,
|
|
||||||
or a list of input items.
|
|
||||||
context: The context to run the agent with.
|
|
||||||
max_turns: The maximum number of turns to run the agent for. A turn is defined as one
|
|
||||||
AI invocation (including any tool calls that might occur).
|
|
||||||
hooks: An object that receives callbacks on various lifecycle events.
|
|
||||||
run_config: Global settings for the entire agent run.
|
|
||||||
previous_response_id: The ID of the previous response, if using OpenAI models via the
|
|
||||||
Responses API, this allows you to skip passing in input from the previous turn.
|
|
||||||
Returns:
|
|
||||||
A result object that contains data about the run, as well as a method to stream events.
|
|
||||||
"""
|
|
||||||
if hooks is None:
|
if hooks is None:
|
||||||
hooks = RunHooks[Any]()
|
hooks = RunHooks[Any]()
|
||||||
if run_config is None:
|
if run_config is None:
|
||||||
|
|
@ -421,7 +529,7 @@ class Runner:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
output_schema = cls._get_output_schema(starting_agent)
|
output_schema = AgentRunner._get_output_schema(starting_agent)
|
||||||
context_wrapper: RunContextWrapper[TContext] = RunContextWrapper(
|
context_wrapper: RunContextWrapper[TContext] = RunContextWrapper(
|
||||||
context=context # type: ignore
|
context=context # type: ignore
|
||||||
)
|
)
|
||||||
|
|
@ -444,7 +552,7 @@ class Runner:
|
||||||
|
|
||||||
# Kick off the actual agent loop in the background and return the streamed result object.
|
# Kick off the actual agent loop in the background and return the streamed result object.
|
||||||
streamed_result._run_impl_task = asyncio.create_task(
|
streamed_result._run_impl_task = asyncio.create_task(
|
||||||
cls._run_streamed_impl(
|
self._start_streaming(
|
||||||
starting_input=input,
|
starting_input=input,
|
||||||
streamed_result=streamed_result,
|
streamed_result=streamed_result,
|
||||||
starting_agent=starting_agent,
|
starting_agent=starting_agent,
|
||||||
|
|
@ -501,7 +609,7 @@ class Runner:
|
||||||
streamed_result.input_guardrail_results = guardrail_results
|
streamed_result.input_guardrail_results = guardrail_results
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _run_streamed_impl(
|
async def _start_streaming(
|
||||||
cls,
|
cls,
|
||||||
starting_input: str | list[TResponseInputItem],
|
starting_input: str | list[TResponseInputItem],
|
||||||
streamed_result: RunResultStreaming,
|
streamed_result: RunResultStreaming,
|
||||||
|
|
@ -1008,3 +1116,6 @@ class Runner:
|
||||||
return agent.model
|
return agent.model
|
||||||
|
|
||||||
return run_config.model_provider.get_model(agent.model)
|
return run_config.model_provider.get_model(agent.model)
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_AGENT_RUNNER = AgentRunner()
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
import atexit
|
import atexit
|
||||||
|
|
||||||
|
from agents.tracing.provider import DefaultTraceProvider, TraceProvider
|
||||||
|
|
||||||
from .create import (
|
from .create import (
|
||||||
agent_span,
|
agent_span,
|
||||||
custom_span,
|
custom_span,
|
||||||
|
|
@ -18,7 +20,7 @@ from .create import (
|
||||||
)
|
)
|
||||||
from .processor_interface import TracingProcessor
|
from .processor_interface import TracingProcessor
|
||||||
from .processors import default_exporter, default_processor
|
from .processors import default_exporter, default_processor
|
||||||
from .setup import GLOBAL_TRACE_PROVIDER
|
from .setup import get_trace_provider, set_trace_provider
|
||||||
from .span_data import (
|
from .span_data import (
|
||||||
AgentSpanData,
|
AgentSpanData,
|
||||||
CustomSpanData,
|
CustomSpanData,
|
||||||
|
|
@ -45,10 +47,12 @@ __all__ = [
|
||||||
"generation_span",
|
"generation_span",
|
||||||
"get_current_span",
|
"get_current_span",
|
||||||
"get_current_trace",
|
"get_current_trace",
|
||||||
|
"get_trace_provider",
|
||||||
"guardrail_span",
|
"guardrail_span",
|
||||||
"handoff_span",
|
"handoff_span",
|
||||||
"response_span",
|
"response_span",
|
||||||
"set_trace_processors",
|
"set_trace_processors",
|
||||||
|
"set_trace_provider",
|
||||||
"set_tracing_disabled",
|
"set_tracing_disabled",
|
||||||
"trace",
|
"trace",
|
||||||
"Trace",
|
"Trace",
|
||||||
|
|
@ -67,6 +71,7 @@ __all__ = [
|
||||||
"SpeechSpanData",
|
"SpeechSpanData",
|
||||||
"TranscriptionSpanData",
|
"TranscriptionSpanData",
|
||||||
"TracingProcessor",
|
"TracingProcessor",
|
||||||
|
"TraceProvider",
|
||||||
"gen_trace_id",
|
"gen_trace_id",
|
||||||
"gen_span_id",
|
"gen_span_id",
|
||||||
"speech_group_span",
|
"speech_group_span",
|
||||||
|
|
@ -80,21 +85,21 @@ def add_trace_processor(span_processor: TracingProcessor) -> None:
|
||||||
"""
|
"""
|
||||||
Adds a new trace processor. This processor will receive all traces/spans.
|
Adds a new trace processor. This processor will receive all traces/spans.
|
||||||
"""
|
"""
|
||||||
GLOBAL_TRACE_PROVIDER.register_processor(span_processor)
|
get_trace_provider().register_processor(span_processor)
|
||||||
|
|
||||||
|
|
||||||
def set_trace_processors(processors: list[TracingProcessor]) -> None:
|
def set_trace_processors(processors: list[TracingProcessor]) -> None:
|
||||||
"""
|
"""
|
||||||
Set the list of trace processors. This will replace the current list of processors.
|
Set the list of trace processors. This will replace the current list of processors.
|
||||||
"""
|
"""
|
||||||
GLOBAL_TRACE_PROVIDER.set_processors(processors)
|
get_trace_provider().set_processors(processors)
|
||||||
|
|
||||||
|
|
||||||
def set_tracing_disabled(disabled: bool) -> None:
|
def set_tracing_disabled(disabled: bool) -> None:
|
||||||
"""
|
"""
|
||||||
Set whether tracing is globally disabled.
|
Set whether tracing is globally disabled.
|
||||||
"""
|
"""
|
||||||
GLOBAL_TRACE_PROVIDER.set_disabled(disabled)
|
get_trace_provider().set_disabled(disabled)
|
||||||
|
|
||||||
|
|
||||||
def set_tracing_export_api_key(api_key: str) -> None:
|
def set_tracing_export_api_key(api_key: str) -> None:
|
||||||
|
|
@ -104,10 +109,11 @@ def set_tracing_export_api_key(api_key: str) -> None:
|
||||||
default_exporter().set_api_key(api_key)
|
default_exporter().set_api_key(api_key)
|
||||||
|
|
||||||
|
|
||||||
|
set_trace_provider(DefaultTraceProvider())
|
||||||
# Add the default processor, which exports traces and spans to the backend in batches. You can
|
# Add the default processor, which exports traces and spans to the backend in batches. You can
|
||||||
# change the default behavior by either:
|
# change the default behavior by either:
|
||||||
# 1. calling add_trace_processor(), which adds additional processors, or
|
# 1. calling add_trace_processor(), which adds additional processors, or
|
||||||
# 2. calling set_trace_processors(), which replaces the default processor.
|
# 2. calling set_trace_processors(), which replaces the default processor.
|
||||||
add_trace_processor(default_processor())
|
add_trace_processor(default_processor())
|
||||||
|
|
||||||
atexit.register(GLOBAL_TRACE_PROVIDER.shutdown)
|
atexit.register(get_trace_provider().shutdown)
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from ..logger import logger
|
from ..logger import logger
|
||||||
from .setup import GLOBAL_TRACE_PROVIDER
|
from .setup import get_trace_provider
|
||||||
from .span_data import (
|
from .span_data import (
|
||||||
AgentSpanData,
|
AgentSpanData,
|
||||||
CustomSpanData,
|
CustomSpanData,
|
||||||
|
|
@ -56,13 +56,13 @@ def trace(
|
||||||
Returns:
|
Returns:
|
||||||
The newly created trace object.
|
The newly created trace object.
|
||||||
"""
|
"""
|
||||||
current_trace = GLOBAL_TRACE_PROVIDER.get_current_trace()
|
current_trace = get_trace_provider().get_current_trace()
|
||||||
if current_trace:
|
if current_trace:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Trace already exists. Creating a new trace, but this is probably a mistake."
|
"Trace already exists. Creating a new trace, but this is probably a mistake."
|
||||||
)
|
)
|
||||||
|
|
||||||
return GLOBAL_TRACE_PROVIDER.create_trace(
|
return get_trace_provider().create_trace(
|
||||||
name=workflow_name,
|
name=workflow_name,
|
||||||
trace_id=trace_id,
|
trace_id=trace_id,
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
|
|
@ -73,12 +73,12 @@ def trace(
|
||||||
|
|
||||||
def get_current_trace() -> Trace | None:
|
def get_current_trace() -> Trace | None:
|
||||||
"""Returns the currently active trace, if present."""
|
"""Returns the currently active trace, if present."""
|
||||||
return GLOBAL_TRACE_PROVIDER.get_current_trace()
|
return get_trace_provider().get_current_trace()
|
||||||
|
|
||||||
|
|
||||||
def get_current_span() -> Span[Any] | None:
|
def get_current_span() -> Span[Any] | None:
|
||||||
"""Returns the currently active span, if present."""
|
"""Returns the currently active span, if present."""
|
||||||
return GLOBAL_TRACE_PROVIDER.get_current_span()
|
return get_trace_provider().get_current_span()
|
||||||
|
|
||||||
|
|
||||||
def agent_span(
|
def agent_span(
|
||||||
|
|
@ -108,7 +108,7 @@ def agent_span(
|
||||||
Returns:
|
Returns:
|
||||||
The newly created agent span.
|
The newly created agent span.
|
||||||
"""
|
"""
|
||||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
return get_trace_provider().create_span(
|
||||||
span_data=AgentSpanData(name=name, handoffs=handoffs, tools=tools, output_type=output_type),
|
span_data=AgentSpanData(name=name, handoffs=handoffs, tools=tools, output_type=output_type),
|
||||||
span_id=span_id,
|
span_id=span_id,
|
||||||
parent=parent,
|
parent=parent,
|
||||||
|
|
@ -141,7 +141,7 @@ def function_span(
|
||||||
Returns:
|
Returns:
|
||||||
The newly created function span.
|
The newly created function span.
|
||||||
"""
|
"""
|
||||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
return get_trace_provider().create_span(
|
||||||
span_data=FunctionSpanData(name=name, input=input, output=output),
|
span_data=FunctionSpanData(name=name, input=input, output=output),
|
||||||
span_id=span_id,
|
span_id=span_id,
|
||||||
parent=parent,
|
parent=parent,
|
||||||
|
|
@ -183,7 +183,7 @@ def generation_span(
|
||||||
Returns:
|
Returns:
|
||||||
The newly created generation span.
|
The newly created generation span.
|
||||||
"""
|
"""
|
||||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
return get_trace_provider().create_span(
|
||||||
span_data=GenerationSpanData(
|
span_data=GenerationSpanData(
|
||||||
input=input,
|
input=input,
|
||||||
output=output,
|
output=output,
|
||||||
|
|
@ -215,7 +215,7 @@ def response_span(
|
||||||
trace/span as the parent.
|
trace/span as the parent.
|
||||||
disabled: If True, we will return a Span but the Span will not be recorded.
|
disabled: If True, we will return a Span but the Span will not be recorded.
|
||||||
"""
|
"""
|
||||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
return get_trace_provider().create_span(
|
||||||
span_data=ResponseSpanData(response=response),
|
span_data=ResponseSpanData(response=response),
|
||||||
span_id=span_id,
|
span_id=span_id,
|
||||||
parent=parent,
|
parent=parent,
|
||||||
|
|
@ -246,7 +246,7 @@ def handoff_span(
|
||||||
Returns:
|
Returns:
|
||||||
The newly created handoff span.
|
The newly created handoff span.
|
||||||
"""
|
"""
|
||||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
return get_trace_provider().create_span(
|
||||||
span_data=HandoffSpanData(from_agent=from_agent, to_agent=to_agent),
|
span_data=HandoffSpanData(from_agent=from_agent, to_agent=to_agent),
|
||||||
span_id=span_id,
|
span_id=span_id,
|
||||||
parent=parent,
|
parent=parent,
|
||||||
|
|
@ -278,7 +278,7 @@ def custom_span(
|
||||||
Returns:
|
Returns:
|
||||||
The newly created custom span.
|
The newly created custom span.
|
||||||
"""
|
"""
|
||||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
return get_trace_provider().create_span(
|
||||||
span_data=CustomSpanData(name=name, data=data or {}),
|
span_data=CustomSpanData(name=name, data=data or {}),
|
||||||
span_id=span_id,
|
span_id=span_id,
|
||||||
parent=parent,
|
parent=parent,
|
||||||
|
|
@ -306,7 +306,7 @@ def guardrail_span(
|
||||||
trace/span as the parent.
|
trace/span as the parent.
|
||||||
disabled: If True, we will return a Span but the Span will not be recorded.
|
disabled: If True, we will return a Span but the Span will not be recorded.
|
||||||
"""
|
"""
|
||||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
return get_trace_provider().create_span(
|
||||||
span_data=GuardrailSpanData(name=name, triggered=triggered),
|
span_data=GuardrailSpanData(name=name, triggered=triggered),
|
||||||
span_id=span_id,
|
span_id=span_id,
|
||||||
parent=parent,
|
parent=parent,
|
||||||
|
|
@ -344,7 +344,7 @@ def transcription_span(
|
||||||
Returns:
|
Returns:
|
||||||
The newly created speech-to-text span.
|
The newly created speech-to-text span.
|
||||||
"""
|
"""
|
||||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
return get_trace_provider().create_span(
|
||||||
span_data=TranscriptionSpanData(
|
span_data=TranscriptionSpanData(
|
||||||
input=input,
|
input=input,
|
||||||
input_format=input_format,
|
input_format=input_format,
|
||||||
|
|
@ -386,7 +386,7 @@ def speech_span(
|
||||||
trace/span as the parent.
|
trace/span as the parent.
|
||||||
disabled: If True, we will return a Span but the Span will not be recorded.
|
disabled: If True, we will return a Span but the Span will not be recorded.
|
||||||
"""
|
"""
|
||||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
return get_trace_provider().create_span(
|
||||||
span_data=SpeechSpanData(
|
span_data=SpeechSpanData(
|
||||||
model=model,
|
model=model,
|
||||||
input=input,
|
input=input,
|
||||||
|
|
@ -419,7 +419,7 @@ def speech_group_span(
|
||||||
trace/span as the parent.
|
trace/span as the parent.
|
||||||
disabled: If True, we will return a Span but the Span will not be recorded.
|
disabled: If True, we will return a Span but the Span will not be recorded.
|
||||||
"""
|
"""
|
||||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
return get_trace_provider().create_span(
|
||||||
span_data=SpeechGroupSpanData(input=input),
|
span_data=SpeechGroupSpanData(input=input),
|
||||||
span_id=span_id,
|
span_id=span_id,
|
||||||
parent=parent,
|
parent=parent,
|
||||||
|
|
@ -447,7 +447,7 @@ def mcp_tools_span(
|
||||||
trace/span as the parent.
|
trace/span as the parent.
|
||||||
disabled: If True, we will return a Span but the Span will not be recorded.
|
disabled: If True, we will return a Span but the Span will not be recorded.
|
||||||
"""
|
"""
|
||||||
return GLOBAL_TRACE_PROVIDER.create_span(
|
return get_trace_provider().create_span(
|
||||||
span_data=MCPListToolsSpanData(server=server, result=result),
|
span_data=MCPListToolsSpanData(server=server, result=result),
|
||||||
span_id=span_id,
|
span_id=span_id,
|
||||||
parent=parent,
|
parent=parent,
|
||||||
|
|
|
||||||
294
src/agents/tracing/provider.py
Normal file
294
src/agents/tracing/provider.py
Normal file
|
|
@ -0,0 +1,294 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import uuid
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from ..logger import logger
|
||||||
|
from .processor_interface import TracingProcessor
|
||||||
|
from .scope import Scope
|
||||||
|
from .spans import NoOpSpan, Span, SpanImpl, TSpanData
|
||||||
|
from .traces import NoOpTrace, Trace, TraceImpl
|
||||||
|
|
||||||
|
|
||||||
|
class SynchronousMultiTracingProcessor(TracingProcessor):
|
||||||
|
"""
|
||||||
|
Forwards all calls to a list of TracingProcessors, in order of registration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Using a tuple to avoid race conditions when iterating over processors
|
||||||
|
self._processors: tuple[TracingProcessor, ...] = ()
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
def add_tracing_processor(self, tracing_processor: TracingProcessor):
|
||||||
|
"""
|
||||||
|
Add a processor to the list of processors. Each processor will receive all traces/spans.
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
self._processors += (tracing_processor,)
|
||||||
|
|
||||||
|
def set_processors(self, processors: list[TracingProcessor]):
|
||||||
|
"""
|
||||||
|
Set the list of processors. This will replace the current list of processors.
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
self._processors = tuple(processors)
|
||||||
|
|
||||||
|
def on_trace_start(self, trace: Trace) -> None:
|
||||||
|
"""
|
||||||
|
Called when a trace is started.
|
||||||
|
"""
|
||||||
|
for processor in self._processors:
|
||||||
|
processor.on_trace_start(trace)
|
||||||
|
|
||||||
|
def on_trace_end(self, trace: Trace) -> None:
|
||||||
|
"""
|
||||||
|
Called when a trace is finished.
|
||||||
|
"""
|
||||||
|
for processor in self._processors:
|
||||||
|
processor.on_trace_end(trace)
|
||||||
|
|
||||||
|
def on_span_start(self, span: Span[Any]) -> None:
|
||||||
|
"""
|
||||||
|
Called when a span is started.
|
||||||
|
"""
|
||||||
|
for processor in self._processors:
|
||||||
|
processor.on_span_start(span)
|
||||||
|
|
||||||
|
def on_span_end(self, span: Span[Any]) -> None:
|
||||||
|
"""
|
||||||
|
Called when a span is finished.
|
||||||
|
"""
|
||||||
|
for processor in self._processors:
|
||||||
|
processor.on_span_end(span)
|
||||||
|
|
||||||
|
def shutdown(self) -> None:
|
||||||
|
"""
|
||||||
|
Called when the application stops.
|
||||||
|
"""
|
||||||
|
for processor in self._processors:
|
||||||
|
logger.debug(f"Shutting down trace processor {processor}")
|
||||||
|
processor.shutdown()
|
||||||
|
|
||||||
|
def force_flush(self):
|
||||||
|
"""
|
||||||
|
Force the processors to flush their buffers.
|
||||||
|
"""
|
||||||
|
for processor in self._processors:
|
||||||
|
processor.force_flush()
|
||||||
|
|
||||||
|
|
||||||
|
class TraceProvider(ABC):
|
||||||
|
"""Interface for creating traces and spans."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def register_processor(self, processor: TracingProcessor) -> None:
|
||||||
|
"""Add a processor that will receive all traces and spans."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def set_processors(self, processors: list[TracingProcessor]) -> None:
|
||||||
|
"""Replace the list of processors with ``processors``."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_current_trace(self) -> Trace | None:
|
||||||
|
"""Return the currently active trace, if any."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_current_span(self) -> Span[Any] | None:
|
||||||
|
"""Return the currently active span, if any."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def set_disabled(self, disabled: bool) -> None:
|
||||||
|
"""Enable or disable tracing globally."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def time_iso(self) -> str:
|
||||||
|
"""Return the current time in ISO 8601 format."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def gen_trace_id(self) -> str:
|
||||||
|
"""Generate a new trace identifier."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def gen_span_id(self) -> str:
|
||||||
|
"""Generate a new span identifier."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def gen_group_id(self) -> str:
|
||||||
|
"""Generate a new group identifier."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_trace(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
trace_id: str | None = None,
|
||||||
|
group_id: str | None = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
disabled: bool = False,
|
||||||
|
) -> Trace:
|
||||||
|
"""Create a new trace."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_span(
|
||||||
|
self,
|
||||||
|
span_data: TSpanData,
|
||||||
|
span_id: str | None = None,
|
||||||
|
parent: Trace | Span[Any] | None = None,
|
||||||
|
disabled: bool = False,
|
||||||
|
) -> Span[TSpanData]:
|
||||||
|
"""Create a new span."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def shutdown(self) -> None:
|
||||||
|
"""Clean up any resources used by the provider."""
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultTraceProvider(TraceProvider):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._multi_processor = SynchronousMultiTracingProcessor()
|
||||||
|
self._disabled = os.environ.get("OPENAI_AGENTS_DISABLE_TRACING", "false").lower() in (
|
||||||
|
"true",
|
||||||
|
"1",
|
||||||
|
)
|
||||||
|
|
||||||
|
def register_processor(self, processor: TracingProcessor):
|
||||||
|
"""
|
||||||
|
Add a processor to the list of processors. Each processor will receive all traces/spans.
|
||||||
|
"""
|
||||||
|
self._multi_processor.add_tracing_processor(processor)
|
||||||
|
|
||||||
|
def set_processors(self, processors: list[TracingProcessor]):
|
||||||
|
"""
|
||||||
|
Set the list of processors. This will replace the current list of processors.
|
||||||
|
"""
|
||||||
|
self._multi_processor.set_processors(processors)
|
||||||
|
|
||||||
|
def get_current_trace(self) -> Trace | None:
|
||||||
|
"""
|
||||||
|
Returns the currently active trace, if any.
|
||||||
|
"""
|
||||||
|
return Scope.get_current_trace()
|
||||||
|
|
||||||
|
def get_current_span(self) -> Span[Any] | None:
|
||||||
|
"""
|
||||||
|
Returns the currently active span, if any.
|
||||||
|
"""
|
||||||
|
return Scope.get_current_span()
|
||||||
|
|
||||||
|
def set_disabled(self, disabled: bool) -> None:
|
||||||
|
"""
|
||||||
|
Set whether tracing is disabled.
|
||||||
|
"""
|
||||||
|
self._disabled = disabled
|
||||||
|
|
||||||
|
def time_iso(self) -> str:
|
||||||
|
"""Return the current time in ISO 8601 format."""
|
||||||
|
return datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
|
def gen_trace_id(self) -> str:
|
||||||
|
"""Generate a new trace ID."""
|
||||||
|
return f"trace_{uuid.uuid4().hex}"
|
||||||
|
|
||||||
|
def gen_span_id(self) -> str:
|
||||||
|
"""Generate a new span ID."""
|
||||||
|
return f"span_{uuid.uuid4().hex[:24]}"
|
||||||
|
|
||||||
|
def gen_group_id(self) -> str:
|
||||||
|
"""Generate a new group ID."""
|
||||||
|
return f"group_{uuid.uuid4().hex[:24]}"
|
||||||
|
|
||||||
|
def create_trace(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
trace_id: str | None = None,
|
||||||
|
group_id: str | None = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
disabled: bool = False,
|
||||||
|
) -> Trace:
|
||||||
|
"""
|
||||||
|
Create a new trace.
|
||||||
|
"""
|
||||||
|
if self._disabled or disabled:
|
||||||
|
logger.debug(f"Tracing is disabled. Not creating trace {name}")
|
||||||
|
return NoOpTrace()
|
||||||
|
|
||||||
|
trace_id = trace_id or self.gen_trace_id()
|
||||||
|
|
||||||
|
logger.debug(f"Creating trace {name} with id {trace_id}")
|
||||||
|
|
||||||
|
return TraceImpl(
|
||||||
|
name=name,
|
||||||
|
trace_id=trace_id,
|
||||||
|
group_id=group_id,
|
||||||
|
metadata=metadata,
|
||||||
|
processor=self._multi_processor,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_span(
|
||||||
|
self,
|
||||||
|
span_data: TSpanData,
|
||||||
|
span_id: str | None = None,
|
||||||
|
parent: Trace | Span[Any] | None = None,
|
||||||
|
disabled: bool = False,
|
||||||
|
) -> Span[TSpanData]:
|
||||||
|
"""
|
||||||
|
Create a new span.
|
||||||
|
"""
|
||||||
|
if self._disabled or disabled:
|
||||||
|
logger.debug(f"Tracing is disabled. Not creating span {span_data}")
|
||||||
|
return NoOpSpan(span_data)
|
||||||
|
|
||||||
|
if not parent:
|
||||||
|
current_span = Scope.get_current_span()
|
||||||
|
current_trace = Scope.get_current_trace()
|
||||||
|
if current_trace is None:
|
||||||
|
logger.error(
|
||||||
|
"No active trace. Make sure to start a trace with `trace()` first"
|
||||||
|
"Returning NoOpSpan."
|
||||||
|
)
|
||||||
|
return NoOpSpan(span_data)
|
||||||
|
elif isinstance(current_trace, NoOpTrace) or isinstance(current_span, NoOpSpan):
|
||||||
|
logger.debug(
|
||||||
|
f"Parent {current_span} or {current_trace} is no-op, returning NoOpSpan"
|
||||||
|
)
|
||||||
|
return NoOpSpan(span_data)
|
||||||
|
|
||||||
|
parent_id = current_span.span_id if current_span else None
|
||||||
|
trace_id = current_trace.trace_id
|
||||||
|
|
||||||
|
elif isinstance(parent, Trace):
|
||||||
|
if isinstance(parent, NoOpTrace):
|
||||||
|
logger.debug(f"Parent {parent} is no-op, returning NoOpSpan")
|
||||||
|
return NoOpSpan(span_data)
|
||||||
|
trace_id = parent.trace_id
|
||||||
|
parent_id = None
|
||||||
|
elif isinstance(parent, Span):
|
||||||
|
if isinstance(parent, NoOpSpan):
|
||||||
|
logger.debug(f"Parent {parent} is no-op, returning NoOpSpan")
|
||||||
|
return NoOpSpan(span_data)
|
||||||
|
parent_id = parent.span_id
|
||||||
|
trace_id = parent.trace_id
|
||||||
|
|
||||||
|
logger.debug(f"Creating span {span_data} with id {span_id}")
|
||||||
|
|
||||||
|
return SpanImpl(
|
||||||
|
trace_id=trace_id,
|
||||||
|
span_id=span_id or self.gen_span_id(),
|
||||||
|
parent_id=parent_id,
|
||||||
|
processor=self._multi_processor,
|
||||||
|
span_data=span_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
def shutdown(self) -> None:
|
||||||
|
if self._disabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.debug("Shutting down trace provider")
|
||||||
|
self._multi_processor.shutdown()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error shutting down trace provider: {e}")
|
||||||
|
|
@ -1,214 +1,21 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
from typing import TYPE_CHECKING
|
||||||
import threading
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from ..logger import logger
|
if TYPE_CHECKING:
|
||||||
from . import util
|
from .provider import TraceProvider
|
||||||
from .processor_interface import TracingProcessor
|
|
||||||
from .scope import Scope
|
GLOBAL_TRACE_PROVIDER: TraceProvider | None = None
|
||||||
from .spans import NoOpSpan, Span, SpanImpl, TSpanData
|
|
||||||
from .traces import NoOpTrace, Trace, TraceImpl
|
|
||||||
|
|
||||||
|
|
||||||
class SynchronousMultiTracingProcessor(TracingProcessor):
|
def set_trace_provider(provider: TraceProvider) -> None:
|
||||||
"""
|
"""Set the global trace provider used by tracing utilities."""
|
||||||
Forwards all calls to a list of TracingProcessors, in order of registration.
|
global GLOBAL_TRACE_PROVIDER
|
||||||
"""
|
GLOBAL_TRACE_PROVIDER = provider
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
# Using a tuple to avoid race conditions when iterating over processors
|
|
||||||
self._processors: tuple[TracingProcessor, ...] = ()
|
|
||||||
self._lock = threading.Lock()
|
|
||||||
|
|
||||||
def add_tracing_processor(self, tracing_processor: TracingProcessor):
|
|
||||||
"""
|
|
||||||
Add a processor to the list of processors. Each processor will receive all traces/spans.
|
|
||||||
"""
|
|
||||||
with self._lock:
|
|
||||||
self._processors += (tracing_processor,)
|
|
||||||
|
|
||||||
def set_processors(self, processors: list[TracingProcessor]):
|
|
||||||
"""
|
|
||||||
Set the list of processors. This will replace the current list of processors.
|
|
||||||
"""
|
|
||||||
with self._lock:
|
|
||||||
self._processors = tuple(processors)
|
|
||||||
|
|
||||||
def on_trace_start(self, trace: Trace) -> None:
|
|
||||||
"""
|
|
||||||
Called when a trace is started.
|
|
||||||
"""
|
|
||||||
for processor in self._processors:
|
|
||||||
processor.on_trace_start(trace)
|
|
||||||
|
|
||||||
def on_trace_end(self, trace: Trace) -> None:
|
|
||||||
"""
|
|
||||||
Called when a trace is finished.
|
|
||||||
"""
|
|
||||||
for processor in self._processors:
|
|
||||||
processor.on_trace_end(trace)
|
|
||||||
|
|
||||||
def on_span_start(self, span: Span[Any]) -> None:
|
|
||||||
"""
|
|
||||||
Called when a span is started.
|
|
||||||
"""
|
|
||||||
for processor in self._processors:
|
|
||||||
processor.on_span_start(span)
|
|
||||||
|
|
||||||
def on_span_end(self, span: Span[Any]) -> None:
|
|
||||||
"""
|
|
||||||
Called when a span is finished.
|
|
||||||
"""
|
|
||||||
for processor in self._processors:
|
|
||||||
processor.on_span_end(span)
|
|
||||||
|
|
||||||
def shutdown(self) -> None:
|
|
||||||
"""
|
|
||||||
Called when the application stops.
|
|
||||||
"""
|
|
||||||
for processor in self._processors:
|
|
||||||
logger.debug(f"Shutting down trace processor {processor}")
|
|
||||||
processor.shutdown()
|
|
||||||
|
|
||||||
def force_flush(self):
|
|
||||||
"""
|
|
||||||
Force the processors to flush their buffers.
|
|
||||||
"""
|
|
||||||
for processor in self._processors:
|
|
||||||
processor.force_flush()
|
|
||||||
|
|
||||||
|
|
||||||
class TraceProvider:
|
def get_trace_provider() -> TraceProvider:
|
||||||
def __init__(self):
|
"""Get the global trace provider used by tracing utilities."""
|
||||||
self._multi_processor = SynchronousMultiTracingProcessor()
|
if GLOBAL_TRACE_PROVIDER is None:
|
||||||
self._disabled = os.environ.get("OPENAI_AGENTS_DISABLE_TRACING", "false").lower() in (
|
raise RuntimeError("Trace provider not set")
|
||||||
"true",
|
return GLOBAL_TRACE_PROVIDER
|
||||||
"1",
|
|
||||||
)
|
|
||||||
|
|
||||||
def register_processor(self, processor: TracingProcessor):
|
|
||||||
"""
|
|
||||||
Add a processor to the list of processors. Each processor will receive all traces/spans.
|
|
||||||
"""
|
|
||||||
self._multi_processor.add_tracing_processor(processor)
|
|
||||||
|
|
||||||
def set_processors(self, processors: list[TracingProcessor]):
|
|
||||||
"""
|
|
||||||
Set the list of processors. This will replace the current list of processors.
|
|
||||||
"""
|
|
||||||
self._multi_processor.set_processors(processors)
|
|
||||||
|
|
||||||
def get_current_trace(self) -> Trace | None:
|
|
||||||
"""
|
|
||||||
Returns the currently active trace, if any.
|
|
||||||
"""
|
|
||||||
return Scope.get_current_trace()
|
|
||||||
|
|
||||||
def get_current_span(self) -> Span[Any] | None:
|
|
||||||
"""
|
|
||||||
Returns the currently active span, if any.
|
|
||||||
"""
|
|
||||||
return Scope.get_current_span()
|
|
||||||
|
|
||||||
def set_disabled(self, disabled: bool) -> None:
|
|
||||||
"""
|
|
||||||
Set whether tracing is disabled.
|
|
||||||
"""
|
|
||||||
self._disabled = disabled
|
|
||||||
|
|
||||||
def create_trace(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
trace_id: str | None = None,
|
|
||||||
group_id: str | None = None,
|
|
||||||
metadata: dict[str, Any] | None = None,
|
|
||||||
disabled: bool = False,
|
|
||||||
) -> Trace:
|
|
||||||
"""
|
|
||||||
Create a new trace.
|
|
||||||
"""
|
|
||||||
if self._disabled or disabled:
|
|
||||||
logger.debug(f"Tracing is disabled. Not creating trace {name}")
|
|
||||||
return NoOpTrace()
|
|
||||||
|
|
||||||
trace_id = trace_id or util.gen_trace_id()
|
|
||||||
|
|
||||||
logger.debug(f"Creating trace {name} with id {trace_id}")
|
|
||||||
|
|
||||||
return TraceImpl(
|
|
||||||
name=name,
|
|
||||||
trace_id=trace_id,
|
|
||||||
group_id=group_id,
|
|
||||||
metadata=metadata,
|
|
||||||
processor=self._multi_processor,
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_span(
|
|
||||||
self,
|
|
||||||
span_data: TSpanData,
|
|
||||||
span_id: str | None = None,
|
|
||||||
parent: Trace | Span[Any] | None = None,
|
|
||||||
disabled: bool = False,
|
|
||||||
) -> Span[TSpanData]:
|
|
||||||
"""
|
|
||||||
Create a new span.
|
|
||||||
"""
|
|
||||||
if self._disabled or disabled:
|
|
||||||
logger.debug(f"Tracing is disabled. Not creating span {span_data}")
|
|
||||||
return NoOpSpan(span_data)
|
|
||||||
|
|
||||||
if not parent:
|
|
||||||
current_span = Scope.get_current_span()
|
|
||||||
current_trace = Scope.get_current_trace()
|
|
||||||
if current_trace is None:
|
|
||||||
logger.error(
|
|
||||||
"No active trace. Make sure to start a trace with `trace()` first"
|
|
||||||
"Returning NoOpSpan."
|
|
||||||
)
|
|
||||||
return NoOpSpan(span_data)
|
|
||||||
elif isinstance(current_trace, NoOpTrace) or isinstance(current_span, NoOpSpan):
|
|
||||||
logger.debug(
|
|
||||||
f"Parent {current_span} or {current_trace} is no-op, returning NoOpSpan"
|
|
||||||
)
|
|
||||||
return NoOpSpan(span_data)
|
|
||||||
|
|
||||||
parent_id = current_span.span_id if current_span else None
|
|
||||||
trace_id = current_trace.trace_id
|
|
||||||
|
|
||||||
elif isinstance(parent, Trace):
|
|
||||||
if isinstance(parent, NoOpTrace):
|
|
||||||
logger.debug(f"Parent {parent} is no-op, returning NoOpSpan")
|
|
||||||
return NoOpSpan(span_data)
|
|
||||||
trace_id = parent.trace_id
|
|
||||||
parent_id = None
|
|
||||||
elif isinstance(parent, Span):
|
|
||||||
if isinstance(parent, NoOpSpan):
|
|
||||||
logger.debug(f"Parent {parent} is no-op, returning NoOpSpan")
|
|
||||||
return NoOpSpan(span_data)
|
|
||||||
parent_id = parent.span_id
|
|
||||||
trace_id = parent.trace_id
|
|
||||||
|
|
||||||
logger.debug(f"Creating span {span_data} with id {span_id}")
|
|
||||||
|
|
||||||
return SpanImpl(
|
|
||||||
trace_id=trace_id,
|
|
||||||
span_id=span_id,
|
|
||||||
parent_id=parent_id,
|
|
||||||
processor=self._multi_processor,
|
|
||||||
span_data=span_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
def shutdown(self) -> None:
|
|
||||||
if self._disabled:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.debug("Shutting down trace provider")
|
|
||||||
self._multi_processor.shutdown()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error shutting down trace provider: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
GLOBAL_TRACE_PROVIDER = TraceProvider()
|
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,21 @@
|
||||||
import uuid
|
from .setup import get_trace_provider
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
|
|
||||||
def time_iso() -> str:
|
def time_iso() -> str:
|
||||||
"""Returns the current time in ISO 8601 format."""
|
"""Return the current time in ISO 8601 format."""
|
||||||
return datetime.now(timezone.utc).isoformat()
|
return get_trace_provider().time_iso()
|
||||||
|
|
||||||
|
|
||||||
def gen_trace_id() -> str:
|
def gen_trace_id() -> str:
|
||||||
"""Generates a new trace ID."""
|
"""Generate a new trace ID."""
|
||||||
return f"trace_{uuid.uuid4().hex}"
|
return get_trace_provider().gen_trace_id()
|
||||||
|
|
||||||
|
|
||||||
def gen_span_id() -> str:
|
def gen_span_id() -> str:
|
||||||
"""Generates a new span ID."""
|
"""Generate a new span ID."""
|
||||||
return f"span_{uuid.uuid4().hex[:24]}"
|
return get_trace_provider().gen_span_id()
|
||||||
|
|
||||||
|
|
||||||
def gen_group_id() -> str:
|
def gen_group_id() -> str:
|
||||||
"""Generates a new group ID."""
|
"""Generate a new group ID."""
|
||||||
return f"group_{uuid.uuid4().hex[:24]}"
|
return get_trace_provider().gen_group_id()
|
||||||
|
|
|
||||||
|
|
@ -5,8 +5,9 @@ import pytest
|
||||||
from agents.models import _openai_shared
|
from agents.models import _openai_shared
|
||||||
from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel
|
from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel
|
||||||
from agents.models.openai_responses import OpenAIResponsesModel
|
from agents.models.openai_responses import OpenAIResponsesModel
|
||||||
|
from agents.run import set_default_agent_runner
|
||||||
from agents.tracing import set_trace_processors
|
from agents.tracing import set_trace_processors
|
||||||
from agents.tracing.setup import GLOBAL_TRACE_PROVIDER
|
from agents.tracing.setup import get_trace_provider
|
||||||
|
|
||||||
from .testing_processor import SPAN_PROCESSOR_TESTING
|
from .testing_processor import SPAN_PROCESSOR_TESTING
|
||||||
|
|
||||||
|
|
@ -33,11 +34,16 @@ def clear_openai_settings():
|
||||||
_openai_shared._use_responses_by_default = True
|
_openai_shared._use_responses_by_default = True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_default_runner():
|
||||||
|
set_default_agent_runner(None)
|
||||||
|
|
||||||
|
|
||||||
# This fixture will run after all tests end
|
# This fixture will run after all tests end
|
||||||
@pytest.fixture(autouse=True, scope="session")
|
@pytest.fixture(autouse=True, scope="session")
|
||||||
def shutdown_trace_provider():
|
def shutdown_trace_provider():
|
||||||
yield
|
yield
|
||||||
GLOBAL_TRACE_PROVIDER.shutdown()
|
get_trace_provider().shutdown()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,21 @@
|
||||||
from agents import Agent, OpenAIResponsesModel, RunConfig, Runner
|
from agents import Agent, OpenAIResponsesModel, RunConfig
|
||||||
from agents.extensions.models.litellm_model import LitellmModel
|
from agents.extensions.models.litellm_model import LitellmModel
|
||||||
|
from agents.run import AgentRunner
|
||||||
|
|
||||||
|
|
||||||
def test_no_prefix_is_openai():
|
def test_no_prefix_is_openai():
|
||||||
agent = Agent(model="gpt-4o", instructions="", name="test")
|
agent = Agent(model="gpt-4o", instructions="", name="test")
|
||||||
model = Runner._get_model(agent, RunConfig())
|
model = AgentRunner._get_model(agent, RunConfig())
|
||||||
assert isinstance(model, OpenAIResponsesModel)
|
assert isinstance(model, OpenAIResponsesModel)
|
||||||
|
|
||||||
|
|
||||||
def openai_prefix_is_openai():
|
def openai_prefix_is_openai():
|
||||||
agent = Agent(model="openai/gpt-4o", instructions="", name="test")
|
agent = Agent(model="openai/gpt-4o", instructions="", name="test")
|
||||||
model = Runner._get_model(agent, RunConfig())
|
model = AgentRunner._get_model(agent, RunConfig())
|
||||||
assert isinstance(model, OpenAIResponsesModel)
|
assert isinstance(model, OpenAIResponsesModel)
|
||||||
|
|
||||||
|
|
||||||
def test_litellm_prefix_is_litellm():
|
def test_litellm_prefix_is_litellm():
|
||||||
agent = Agent(model="litellm/foo/bar", instructions="", name="test")
|
agent = Agent(model="litellm/foo/bar", instructions="", name="test")
|
||||||
model = Runner._get_model(agent, RunConfig())
|
model = AgentRunner._get_model(agent, RunConfig())
|
||||||
assert isinstance(model, LitellmModel)
|
assert isinstance(model, LitellmModel)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from agents import Agent, AgentOutputSchema, Handoff, RunContextWrapper, Runner, handoff
|
from agents import Agent, AgentOutputSchema, Handoff, RunContextWrapper, handoff
|
||||||
|
from agents.run import AgentRunner
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -42,7 +43,7 @@ async def test_handoff_with_agents():
|
||||||
handoffs=[agent_1, agent_2],
|
handoffs=[agent_1, agent_2],
|
||||||
)
|
)
|
||||||
|
|
||||||
handoffs = Runner._get_handoffs(agent_3)
|
handoffs = AgentRunner._get_handoffs(agent_3)
|
||||||
assert len(handoffs) == 2
|
assert len(handoffs) == 2
|
||||||
|
|
||||||
assert handoffs[0].agent_name == "agent_1"
|
assert handoffs[0].agent_name == "agent_1"
|
||||||
|
|
@ -77,7 +78,7 @@ async def test_handoff_with_handoff_obj():
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
handoffs = Runner._get_handoffs(agent_3)
|
handoffs = AgentRunner._get_handoffs(agent_3)
|
||||||
assert len(handoffs) == 2
|
assert len(handoffs) == 2
|
||||||
|
|
||||||
assert handoffs[0].agent_name == "agent_1"
|
assert handoffs[0].agent_name == "agent_1"
|
||||||
|
|
@ -111,7 +112,7 @@ async def test_handoff_with_handoff_obj_and_agent():
|
||||||
handoffs=[handoff(agent_1), agent_2],
|
handoffs=[handoff(agent_1), agent_2],
|
||||||
)
|
)
|
||||||
|
|
||||||
handoffs = Runner._get_handoffs(agent_3)
|
handoffs = AgentRunner._get_handoffs(agent_3)
|
||||||
assert len(handoffs) == 2
|
assert len(handoffs) == 2
|
||||||
|
|
||||||
assert handoffs[0].agent_name == "agent_1"
|
assert handoffs[0].agent_name == "agent_1"
|
||||||
|
|
@ -159,7 +160,7 @@ async def test_agent_final_output():
|
||||||
output_type=Foo,
|
output_type=Foo,
|
||||||
)
|
)
|
||||||
|
|
||||||
schema = Runner._get_output_schema(agent)
|
schema = AgentRunner._get_output_schema(agent)
|
||||||
assert isinstance(schema, AgentOutputSchema)
|
assert isinstance(schema, AgentOutputSchema)
|
||||||
assert schema is not None
|
assert schema is not None
|
||||||
assert schema.output_type == Foo
|
assert schema.output_type == Foo
|
||||||
|
|
|
||||||
|
|
@ -12,10 +12,10 @@ from agents import (
|
||||||
MessageOutputItem,
|
MessageOutputItem,
|
||||||
ModelBehaviorError,
|
ModelBehaviorError,
|
||||||
RunContextWrapper,
|
RunContextWrapper,
|
||||||
Runner,
|
|
||||||
UserError,
|
UserError,
|
||||||
handoff,
|
handoff,
|
||||||
)
|
)
|
||||||
|
from agents.run import AgentRunner
|
||||||
|
|
||||||
|
|
||||||
def message_item(content: str, agent: Agent[Any]) -> MessageOutputItem:
|
def message_item(content: str, agent: Agent[Any]) -> MessageOutputItem:
|
||||||
|
|
@ -45,9 +45,9 @@ def test_single_handoff_setup():
|
||||||
assert not agent_1.handoffs
|
assert not agent_1.handoffs
|
||||||
assert agent_2.handoffs == [agent_1]
|
assert agent_2.handoffs == [agent_1]
|
||||||
|
|
||||||
assert not Runner._get_handoffs(agent_1)
|
assert not AgentRunner._get_handoffs(agent_1)
|
||||||
|
|
||||||
handoff_objects = Runner._get_handoffs(agent_2)
|
handoff_objects = AgentRunner._get_handoffs(agent_2)
|
||||||
assert len(handoff_objects) == 1
|
assert len(handoff_objects) == 1
|
||||||
obj = handoff_objects[0]
|
obj = handoff_objects[0]
|
||||||
assert obj.tool_name == Handoff.default_tool_name(agent_1)
|
assert obj.tool_name == Handoff.default_tool_name(agent_1)
|
||||||
|
|
@ -64,7 +64,7 @@ def test_multiple_handoffs_setup():
|
||||||
assert not agent_1.handoffs
|
assert not agent_1.handoffs
|
||||||
assert not agent_2.handoffs
|
assert not agent_2.handoffs
|
||||||
|
|
||||||
handoff_objects = Runner._get_handoffs(agent_3)
|
handoff_objects = AgentRunner._get_handoffs(agent_3)
|
||||||
assert len(handoff_objects) == 2
|
assert len(handoff_objects) == 2
|
||||||
assert handoff_objects[0].tool_name == Handoff.default_tool_name(agent_1)
|
assert handoff_objects[0].tool_name == Handoff.default_tool_name(agent_1)
|
||||||
assert handoff_objects[1].tool_name == Handoff.default_tool_name(agent_2)
|
assert handoff_objects[1].tool_name == Handoff.default_tool_name(agent_2)
|
||||||
|
|
@ -95,7 +95,7 @@ def test_custom_handoff_setup():
|
||||||
assert not agent_1.handoffs
|
assert not agent_1.handoffs
|
||||||
assert not agent_2.handoffs
|
assert not agent_2.handoffs
|
||||||
|
|
||||||
handoff_objects = Runner._get_handoffs(agent_3)
|
handoff_objects = AgentRunner._get_handoffs(agent_3)
|
||||||
assert len(handoff_objects) == 2
|
assert len(handoff_objects) == 2
|
||||||
|
|
||||||
first_handoff = handoff_objects[0]
|
first_handoff = handoff_objects[0]
|
||||||
|
|
|
||||||
|
|
@ -10,16 +10,16 @@ from agents import (
|
||||||
AgentOutputSchema,
|
AgentOutputSchema,
|
||||||
AgentOutputSchemaBase,
|
AgentOutputSchemaBase,
|
||||||
ModelBehaviorError,
|
ModelBehaviorError,
|
||||||
Runner,
|
|
||||||
UserError,
|
UserError,
|
||||||
)
|
)
|
||||||
from agents.agent_output import _WRAPPER_DICT_KEY
|
from agents.agent_output import _WRAPPER_DICT_KEY
|
||||||
|
from agents.run import AgentRunner
|
||||||
from agents.util import _json
|
from agents.util import _json
|
||||||
|
|
||||||
|
|
||||||
def test_plain_text_output():
|
def test_plain_text_output():
|
||||||
agent = Agent(name="test")
|
agent = Agent(name="test")
|
||||||
output_schema = Runner._get_output_schema(agent)
|
output_schema = AgentRunner._get_output_schema(agent)
|
||||||
assert not output_schema, "Shouldn't have an output tool config without an output type"
|
assert not output_schema, "Shouldn't have an output tool config without an output type"
|
||||||
|
|
||||||
agent = Agent(name="test", output_type=str)
|
agent = Agent(name="test", output_type=str)
|
||||||
|
|
@ -32,7 +32,7 @@ class Foo(BaseModel):
|
||||||
|
|
||||||
def test_structured_output_pydantic():
|
def test_structured_output_pydantic():
|
||||||
agent = Agent(name="test", output_type=Foo)
|
agent = Agent(name="test", output_type=Foo)
|
||||||
output_schema = Runner._get_output_schema(agent)
|
output_schema = AgentRunner._get_output_schema(agent)
|
||||||
assert output_schema, "Should have an output tool config with a structured output type"
|
assert output_schema, "Should have an output tool config with a structured output type"
|
||||||
|
|
||||||
assert isinstance(output_schema, AgentOutputSchema)
|
assert isinstance(output_schema, AgentOutputSchema)
|
||||||
|
|
@ -52,7 +52,7 @@ class Bar(TypedDict):
|
||||||
|
|
||||||
def test_structured_output_typed_dict():
|
def test_structured_output_typed_dict():
|
||||||
agent = Agent(name="test", output_type=Bar)
|
agent = Agent(name="test", output_type=Bar)
|
||||||
output_schema = Runner._get_output_schema(agent)
|
output_schema = AgentRunner._get_output_schema(agent)
|
||||||
assert output_schema, "Should have an output tool config with a structured output type"
|
assert output_schema, "Should have an output tool config with a structured output type"
|
||||||
assert isinstance(output_schema, AgentOutputSchema)
|
assert isinstance(output_schema, AgentOutputSchema)
|
||||||
assert output_schema.output_type == Bar, "Should have the correct output type"
|
assert output_schema.output_type == Bar, "Should have the correct output type"
|
||||||
|
|
@ -65,7 +65,7 @@ def test_structured_output_typed_dict():
|
||||||
|
|
||||||
def test_structured_output_list():
|
def test_structured_output_list():
|
||||||
agent = Agent(name="test", output_type=list[str])
|
agent = Agent(name="test", output_type=list[str])
|
||||||
output_schema = Runner._get_output_schema(agent)
|
output_schema = AgentRunner._get_output_schema(agent)
|
||||||
assert output_schema, "Should have an output tool config with a structured output type"
|
assert output_schema, "Should have an output tool config with a structured output type"
|
||||||
assert isinstance(output_schema, AgentOutputSchema)
|
assert isinstance(output_schema, AgentOutputSchema)
|
||||||
assert output_schema.output_type == list[str], "Should have the correct output type"
|
assert output_schema.output_type == list[str], "Should have the correct output type"
|
||||||
|
|
@ -79,14 +79,14 @@ def test_structured_output_list():
|
||||||
|
|
||||||
def test_bad_json_raises_error(mocker):
|
def test_bad_json_raises_error(mocker):
|
||||||
agent = Agent(name="test", output_type=Foo)
|
agent = Agent(name="test", output_type=Foo)
|
||||||
output_schema = Runner._get_output_schema(agent)
|
output_schema = AgentRunner._get_output_schema(agent)
|
||||||
assert output_schema, "Should have an output tool config with a structured output type"
|
assert output_schema, "Should have an output tool config with a structured output type"
|
||||||
|
|
||||||
with pytest.raises(ModelBehaviorError):
|
with pytest.raises(ModelBehaviorError):
|
||||||
output_schema.validate_json("not valid json")
|
output_schema.validate_json("not valid json")
|
||||||
|
|
||||||
agent = Agent(name="test", output_type=list[str])
|
agent = Agent(name="test", output_type=list[str])
|
||||||
output_schema = Runner._get_output_schema(agent)
|
output_schema = AgentRunner._get_output_schema(agent)
|
||||||
assert output_schema, "Should have an output tool config with a structured output type"
|
assert output_schema, "Should have an output tool config with a structured output type"
|
||||||
|
|
||||||
mock_validate_json = mocker.patch.object(_json, "validate_json")
|
mock_validate_json = mocker.patch.object(_json, "validate_json")
|
||||||
|
|
@ -155,7 +155,7 @@ class CustomOutputSchema(AgentOutputSchemaBase):
|
||||||
def test_custom_output_schema():
|
def test_custom_output_schema():
|
||||||
custom_output_schema = CustomOutputSchema()
|
custom_output_schema = CustomOutputSchema()
|
||||||
agent = Agent(name="test", output_type=custom_output_schema)
|
agent = Agent(name="test", output_type=custom_output_schema)
|
||||||
output_schema = Runner._get_output_schema(agent)
|
output_schema = AgentRunner._get_output_schema(agent)
|
||||||
|
|
||||||
assert output_schema, "Should have an output tool config with a structured output type"
|
assert output_schema, "Should have an output tool config with a structured output type"
|
||||||
assert isinstance(output_schema, CustomOutputSchema)
|
assert isinstance(output_schema, CustomOutputSchema)
|
||||||
|
|
|
||||||
26
tests/test_run.py
Normal file
26
tests/test_run.py
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agents import Agent, Runner
|
||||||
|
from agents.run import AgentRunner, set_default_agent_runner
|
||||||
|
|
||||||
|
from .fake_model import FakeModel
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_static_run_methods_call_into_default_runner() -> None:
|
||||||
|
runner = mock.Mock(spec=AgentRunner)
|
||||||
|
set_default_agent_runner(runner)
|
||||||
|
|
||||||
|
agent = Agent(name="test", model=FakeModel())
|
||||||
|
await Runner.run(agent, input="test")
|
||||||
|
runner.run.assert_called_once()
|
||||||
|
|
||||||
|
Runner.run_streamed(agent, input="test")
|
||||||
|
runner.run_streamed.assert_called_once()
|
||||||
|
|
||||||
|
Runner.run_sync(agent, input="test")
|
||||||
|
runner.run_sync.assert_called_once()
|
||||||
|
|
@ -60,7 +60,7 @@ async def test_run_config_model_name_override_takes_precedence() -> None:
|
||||||
async def test_run_config_model_override_object_takes_precedence() -> None:
|
async def test_run_config_model_override_object_takes_precedence() -> None:
|
||||||
"""
|
"""
|
||||||
When a concrete Model instance is set on the RunConfig, then that instance should be
|
When a concrete Model instance is set on the RunConfig, then that instance should be
|
||||||
returned by Runner._get_model regardless of the agent's model.
|
returned by AgentRunner._get_model regardless of the agent's model.
|
||||||
"""
|
"""
|
||||||
fake_model = FakeModel(initial_output=[get_text_message("override-object")])
|
fake_model = FakeModel(initial_output=[get_text_message("override-object")])
|
||||||
agent = Agent(name="test", model="agent-model")
|
agent = Agent(name="test", model="agent-model")
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,6 @@ from agents import (
|
||||||
RunContextWrapper,
|
RunContextWrapper,
|
||||||
RunHooks,
|
RunHooks,
|
||||||
RunItem,
|
RunItem,
|
||||||
Runner,
|
|
||||||
ToolCallItem,
|
ToolCallItem,
|
||||||
ToolCallOutputItem,
|
ToolCallOutputItem,
|
||||||
TResponseInputItem,
|
TResponseInputItem,
|
||||||
|
|
@ -27,6 +26,7 @@ from agents._run_impl import (
|
||||||
RunImpl,
|
RunImpl,
|
||||||
SingleStepResult,
|
SingleStepResult,
|
||||||
)
|
)
|
||||||
|
from agents.run import AgentRunner
|
||||||
from agents.tool import function_tool
|
from agents.tool import function_tool
|
||||||
from agents.tool_context import ToolContext
|
from agents.tool_context import ToolContext
|
||||||
|
|
||||||
|
|
@ -324,8 +324,8 @@ async def get_execute_result(
|
||||||
context_wrapper: RunContextWrapper[Any] | None = None,
|
context_wrapper: RunContextWrapper[Any] | None = None,
|
||||||
run_config: RunConfig | None = None,
|
run_config: RunConfig | None = None,
|
||||||
) -> SingleStepResult:
|
) -> SingleStepResult:
|
||||||
output_schema = Runner._get_output_schema(agent)
|
output_schema = AgentRunner._get_output_schema(agent)
|
||||||
handoffs = Runner._get_handoffs(agent)
|
handoffs = AgentRunner._get_handoffs(agent)
|
||||||
|
|
||||||
processed_response = RunImpl.process_model_response(
|
processed_response = RunImpl.process_model_response(
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
|
|
||||||
|
|
@ -19,11 +19,11 @@ from agents import (
|
||||||
ModelResponse,
|
ModelResponse,
|
||||||
ReasoningItem,
|
ReasoningItem,
|
||||||
RunContextWrapper,
|
RunContextWrapper,
|
||||||
Runner,
|
|
||||||
ToolCallItem,
|
ToolCallItem,
|
||||||
Usage,
|
Usage,
|
||||||
)
|
)
|
||||||
from agents._run_impl import RunImpl
|
from agents._run_impl import RunImpl
|
||||||
|
from agents.run import AgentRunner
|
||||||
|
|
||||||
from .test_responses import (
|
from .test_responses import (
|
||||||
get_final_output_message,
|
get_final_output_message,
|
||||||
|
|
@ -186,7 +186,7 @@ async def test_handoffs_parsed_correctly():
|
||||||
agent=agent_3,
|
agent=agent_3,
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=None,
|
output_schema=None,
|
||||||
handoffs=Runner._get_handoffs(agent_3),
|
handoffs=AgentRunner._get_handoffs(agent_3),
|
||||||
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
|
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
|
||||||
)
|
)
|
||||||
assert len(result.handoffs) == 1, "Should have a handoff here"
|
assert len(result.handoffs) == 1, "Should have a handoff here"
|
||||||
|
|
@ -216,7 +216,7 @@ async def test_missing_handoff_fails():
|
||||||
agent=agent_3,
|
agent=agent_3,
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=None,
|
output_schema=None,
|
||||||
handoffs=Runner._get_handoffs(agent_3),
|
handoffs=AgentRunner._get_handoffs(agent_3),
|
||||||
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
|
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -239,7 +239,7 @@ async def test_multiple_handoffs_doesnt_error():
|
||||||
agent=agent_3,
|
agent=agent_3,
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=None,
|
output_schema=None,
|
||||||
handoffs=Runner._get_handoffs(agent_3),
|
handoffs=AgentRunner._get_handoffs(agent_3),
|
||||||
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
|
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
|
||||||
)
|
)
|
||||||
assert len(result.handoffs) == 2, "Should have multiple handoffs here"
|
assert len(result.handoffs) == 2, "Should have multiple handoffs here"
|
||||||
|
|
@ -264,7 +264,7 @@ async def test_final_output_parsed_correctly():
|
||||||
RunImpl.process_model_response(
|
RunImpl.process_model_response(
|
||||||
agent=agent,
|
agent=agent,
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=Runner._get_output_schema(agent),
|
output_schema=AgentRunner._get_output_schema(agent),
|
||||||
handoffs=[],
|
handoffs=[],
|
||||||
all_tools=await agent.get_all_tools(_dummy_ctx()),
|
all_tools=await agent.get_all_tools(_dummy_ctx()),
|
||||||
)
|
)
|
||||||
|
|
@ -471,7 +471,7 @@ async def test_tool_and_handoff_parsed_correctly():
|
||||||
agent=agent_3,
|
agent=agent_3,
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=None,
|
output_schema=None,
|
||||||
handoffs=Runner._get_handoffs(agent_3),
|
handoffs=AgentRunner._get_handoffs(agent_3),
|
||||||
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
|
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
|
||||||
)
|
)
|
||||||
assert result.functions and len(result.functions) == 1
|
assert result.functions and len(result.functions) == 1
|
||||||
|
|
|
||||||
|
|
@ -19,11 +19,12 @@ from agents.items import (
|
||||||
TResponseStreamEvent,
|
TResponseStreamEvent,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from ..fake_model import get_response_obj
|
||||||
|
from ..test_responses import get_function_tool, get_function_tool_call, get_text_message
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from agents.voice import SingleAgentVoiceWorkflow
|
from agents.voice import SingleAgentVoiceWorkflow
|
||||||
|
|
||||||
from ..fake_model import get_response_obj
|
|
||||||
from ..test_responses import get_function_tool, get_function_tool_call, get_text_message
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue