openai-agents-python/tests/src/agents/_run_impl.py
2025-03-11 09:42:28 -07:00

792 lines
28 KiB
Python

from __future__ import annotations
import asyncio
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from openai.types.responses import (
ResponseComputerToolCall,
ResponseFileSearchToolCall,
ResponseFunctionToolCall,
ResponseFunctionWebSearch,
ResponseOutputMessage,
)
from openai.types.responses.response_computer_tool_call import (
ActionClick,
ActionDoubleClick,
ActionDrag,
ActionKeypress,
ActionMove,
ActionScreenshot,
ActionScroll,
ActionType,
ActionWait,
)
from openai.types.responses.response_input_param import ComputerCallOutput
from openai.types.responses.response_output_item import Reasoning
from . import _utils
from .agent import Agent
from .agent_output import AgentOutputSchema
from .computer import AsyncComputer, Computer
from .exceptions import AgentsException, ModelBehaviorError, UserError
from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
from .handoffs import Handoff, HandoffInputData
from .items import (
HandoffCallItem,
HandoffOutputItem,
ItemHelpers,
MessageOutputItem,
ModelResponse,
ReasoningItem,
RunItem,
ToolCallItem,
ToolCallOutputItem,
TResponseInputItem,
)
from .lifecycle import RunHooks
from .logger import logger
from .models.interface import ModelTracing
from .run_context import RunContextWrapper, TContext
from .stream_events import RunItemStreamEvent, StreamEvent
from .tool import ComputerTool, FunctionTool
from .tracing import (
SpanError,
Trace,
function_span,
get_current_trace,
guardrail_span,
handoff_span,
trace,
)
if TYPE_CHECKING:
from .run import RunConfig
class QueueCompleteSentinel:
    """Marker type for the stream-event queue.

    The queue handled by `RunImpl.stream_step_result_to_queue` is annotated as carrying either
    a `StreamEvent` or an instance of this class. Presumably the marker tells the consumer
    that no further events will arrive; the code that enqueues it is not in this module.
    """


# Module-level instance of the marker.
QUEUE_COMPLETE_SENTINEL = QueueCompleteSentinel()
@dataclass
class ToolRunHandoff:
    """Pairs a handoff with the function tool call through which the model requested it."""

    # The handoff whose tool name matched the call.
    handoff: Handoff
    # The function tool call emitted by the model.
    tool_call: ResponseFunctionToolCall
@dataclass
class ToolRunFunction:
    """Pairs a function tool call from the model with the `FunctionTool` that will serve it."""

    # The function tool call emitted by the model.
    tool_call: ResponseFunctionToolCall
    # The agent's tool whose name matched the call.
    function_tool: FunctionTool
@dataclass
class ToolRunComputerAction:
    """Pairs a computer tool call from the model with the `ComputerTool` that will serve it."""

    # The computer tool call emitted by the model.
    tool_call: ResponseComputerToolCall
    # The agent's computer tool.
    computer_tool: ComputerTool
@dataclass
class ProcessedResponse:
    """A model response sorted into run items and the work that still has to run locally."""

    # Run items created from the model's outputs.
    new_items: list[RunItem]
    # Handoffs the model requested.
    handoffs: list[ToolRunHandoff]
    # Function tool calls waiting to be executed.
    functions: list[ToolRunFunction]
    # Computer tool calls waiting to be executed.
    computer_actions: list[ToolRunComputerAction]

    def has_tools_to_run(self) -> bool:
        """Return True when at least one handoff, function call or computer action is pending."""
        # Only handoffs, functions and computer actions need local processing.
        # Hosted tools have already run, so there is nothing left to do for them.
        return bool(self.handoffs or self.functions or self.computer_actions)
@dataclass
class NextStepHandoff:
    """Next-step marker: the run continues with a different agent."""

    # The agent that takes over.
    new_agent: Agent[Any]
@dataclass
class NextStepFinalOutput:
    """Next-step marker: the run has produced its final output."""

    # The final output value (validated structured output or plain text).
    output: Any
@dataclass
class NextStepRunAgain:
    """Next-step marker: no handoff and no final output, so the step loop should run again."""
@dataclass
class SingleStepResult:
    """Outcome of one step of an agent run: the items produced and what to do next."""

    original_input: str | list[TResponseInputItem]
    """The input items i.e. the items before run() was called. May be mutated by handoff input
    filters."""

    model_response: ModelResponse
    """The model response for the current step."""

    pre_step_items: list[RunItem]
    """Items generated before the current step."""

    new_step_items: list[RunItem]
    """Items generated during this current step."""

    next_step: NextStepHandoff | NextStepFinalOutput | NextStepRunAgain
    """The next step to take."""

    @property
    def generated_items(self) -> list[RunItem]:
        """Items generated during the agent run (i.e. everything generated after
        `original_input`), returned as a new list: earlier items first, then this step's."""
        return [*self.pre_step_items, *self.new_step_items]
def get_model_tracing_impl(
    tracing_disabled: bool, trace_include_sensitive_data: bool
) -> ModelTracing:
    """Translate the two tracing flags into a `ModelTracing` value.

    Args:
        tracing_disabled: When true, tracing is off and the other flag is ignored.
        trace_include_sensitive_data: When tracing is on, selects between tracing with data
            and tracing without data.

    Returns:
        `ModelTracing.DISABLED`, `ModelTracing.ENABLED` or `ModelTracing.ENABLED_WITHOUT_DATA`.
    """
    if tracing_disabled:
        return ModelTracing.DISABLED
    if trace_include_sensitive_data:
        return ModelTracing.ENABLED
    return ModelTracing.ENABLED_WITHOUT_DATA
class RunImpl:
    """Class-level helpers that process a model response and execute one step of an agent run."""

    @classmethod
    async def execute_tools_and_side_effects(
        cls,
        *,
        agent: Agent[TContext],
        # The original input to the Runner
        original_input: str | list[TResponseInputItem],
        # Everything generated by Runner since the original input, but before the current step
        pre_step_items: list[RunItem],
        new_response: ModelResponse,
        processed_response: ProcessedResponse,
        output_schema: AgentOutputSchema | None,
        hooks: RunHooks[TContext],
        context_wrapper: RunContextWrapper[TContext],
        run_config: RunConfig,
    ) -> SingleStepResult:
        """Run the local tools for a processed response and decide what the run does next.

        Function tools and computer actions are executed first. The step then resolves, in
        priority order, to a handoff, a final output, or another model turn.

        Returns:
            A `SingleStepResult` whose `next_step` is a `NextStepHandoff`,
            `NextStepFinalOutput` or `NextStepRunAgain`.
        """
        # Work on copies so neither the caller's list nor the processed response is mutated.
        pre_step_items = list(pre_step_items)
        new_step_items: list[RunItem] = list(processed_response.new_items)

        # Step 1: function tools and computer actions are awaited together.
        function_results, computer_results = await asyncio.gather(
            cls.execute_function_tool_calls(
                agent=agent,
                tool_runs=processed_response.functions,
                hooks=hooks,
                context_wrapper=context_wrapper,
                config=run_config,
            ),
            cls.execute_computer_actions(
                agent=agent,
                actions=processed_response.computer_actions,
                hooks=hooks,
                context_wrapper=context_wrapper,
                config=run_config,
            ),
        )
        new_step_items += function_results
        new_step_items += computer_results

        # Step 2: a requested handoff takes precedence over any final output.
        if processed_response.handoffs:
            return await cls.execute_handoffs(
                agent=agent,
                original_input=original_input,
                pre_step_items=pre_step_items,
                new_step_items=new_step_items,
                new_response=new_response,
                run_handoffs=processed_response.handoffs,
                hooks=hooks,
                context_wrapper=context_wrapper,
                run_config=run_config,
            )

        # Step 3: the text of the last message item, if there is one, is the candidate output.
        last_message = next(
            (item for item in reversed(new_step_items) if isinstance(item, MessageOutputItem)),
            None,
        )
        candidate_text = (
            ItemHelpers.extract_last_text(last_message.raw_item)
            if last_message is not None
            else None
        )

        # Two situations lead to a final output:
        #   1. Structured output schema => final output whenever the model produced text.
        #   2. Plain text (or no schema) => final output only if there were no tool calls.
        if output_schema and not output_schema.is_plain_text() and candidate_text:
            return await cls.execute_final_output(
                agent=agent,
                original_input=original_input,
                new_response=new_response,
                pre_step_items=pre_step_items,
                new_step_items=new_step_items,
                final_output=output_schema.validate_json(candidate_text),
                hooks=hooks,
                context_wrapper=context_wrapper,
            )

        plain_text_expected = not output_schema or output_schema.is_plain_text()
        if plain_text_expected and not processed_response.has_tools_to_run():
            return await cls.execute_final_output(
                agent=agent,
                original_input=original_input,
                new_response=new_response,
                pre_step_items=pre_step_items,
                new_step_items=new_step_items,
                final_output=candidate_text or "",
                hooks=hooks,
                context_wrapper=context_wrapper,
            )

        # No final output yet: ask for another turn.
        return SingleStepResult(
            original_input=original_input,
            model_response=new_response,
            pre_step_items=pre_step_items,
            new_step_items=new_step_items,
            next_step=NextStepRunAgain(),
        )

    @classmethod
    def process_model_response(
        cls,
        *,
        agent: Agent[Any],
        response: ModelResponse,
        output_schema: AgentOutputSchema | None,
        handoffs: list[Handoff],
    ) -> ProcessedResponse:
        """Sort the outputs of a model response into run items and locally runnable work.

        `output_schema` is accepted but not read by this method.

        Raises:
            ModelBehaviorError: If the model emits a computer call while the agent has no
                `ComputerTool`, or calls a function that is neither a handoff nor one of the
                agent's function tools.
        """
        new_items: list[RunItem] = []
        pending_handoffs = []
        pending_functions = []
        pending_computer_actions = []

        handoffs_by_tool_name = {handoff.tool_name: handoff for handoff in handoffs}
        function_tools_by_name = {
            tool.name: tool for tool in agent.tools if isinstance(tool, FunctionTool)
        }
        # Only the first computer tool on the agent is used.
        computer_tool = next((tool for tool in agent.tools if isinstance(tool, ComputerTool)), None)

        for output in response.output:
            if isinstance(output, ResponseOutputMessage):
                new_items.append(MessageOutputItem(raw_item=output, agent=agent))
            elif isinstance(output, (ResponseFileSearchToolCall, ResponseFunctionWebSearch)):
                # File search / web search calls are only recorded; nothing is run locally.
                new_items.append(ToolCallItem(raw_item=output, agent=agent))
            elif isinstance(output, Reasoning):
                new_items.append(ReasoningItem(raw_item=output, agent=agent))
            elif isinstance(output, ResponseComputerToolCall):
                new_items.append(ToolCallItem(raw_item=output, agent=agent))
                if not computer_tool:
                    _utils.attach_error_to_current_span(
                        SpanError(
                            message="Computer tool not found",
                            data={},
                        )
                    )
                    raise ModelBehaviorError(
                        "Model produced computer action without a computer tool."
                    )
                pending_computer_actions.append(
                    ToolRunComputerAction(tool_call=output, computer_tool=computer_tool)
                )
            elif isinstance(output, ResponseFunctionToolCall):
                if output.name in handoffs_by_tool_name:
                    # The call names a handoff tool.
                    new_items.append(HandoffCallItem(raw_item=output, agent=agent))
                    pending_handoffs.append(
                        ToolRunHandoff(
                            tool_call=output,
                            handoff=handoffs_by_tool_name[output.name],
                        )
                    )
                elif output.name not in function_tools_by_name:
                    _utils.attach_error_to_current_span(
                        SpanError(
                            message="Tool not found",
                            data={"tool_name": output.name},
                        )
                    )
                    raise ModelBehaviorError(f"Tool {output.name} not found in agent {agent.name}")
                else:
                    # Regular function tool call.
                    new_items.append(ToolCallItem(raw_item=output, agent=agent))
                    pending_functions.append(
                        ToolRunFunction(
                            tool_call=output,
                            function_tool=function_tools_by_name[output.name],
                        )
                    )
            else:
                logger.warning(f"Unexpected output type, ignoring: {type(output)}")

        return ProcessedResponse(
            new_items=new_items,
            handoffs=pending_handoffs,
            functions=pending_functions,
            computer_actions=pending_computer_actions,
        )

    @classmethod
    async def execute_function_tool_calls(
        cls,
        *,
        agent: Agent[TContext],
        tool_runs: list[ToolRunFunction],
        hooks: RunHooks[TContext],
        context_wrapper: RunContextWrapper[TContext],
        config: RunConfig,
    ) -> list[RunItem]:
        """Invoke every function tool call via `asyncio.gather` and wrap the results.

        Returns:
            One `ToolCallOutputItem` per entry of `tool_runs`, in the same order.

        Raises:
            AgentsException: Re-raised unchanged when a tool or hook raises one.
            UserError: Wraps any other exception raised by a tool or hook.
        """

        async def run_single_tool(
            func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
        ) -> str:
            with function_span(func_tool.name) as span_fn:
                # Tool input/output is only attached to the span when the config allows it.
                if config.trace_include_sensitive_data:
                    span_fn.span_data.input = tool_call.arguments
                try:
                    # The start hooks and the tool invocation are gathered together.
                    _, _, result = await asyncio.gather(
                        hooks.on_tool_start(context_wrapper, agent, func_tool),
                        (
                            agent.hooks.on_tool_start(context_wrapper, agent, func_tool)
                            if agent.hooks
                            else _utils.noop_coroutine()
                        ),
                        func_tool.on_invoke_tool(context_wrapper, tool_call.arguments),
                    )

                    await asyncio.gather(
                        hooks.on_tool_end(context_wrapper, agent, func_tool, result),
                        (
                            agent.hooks.on_tool_end(context_wrapper, agent, func_tool, result)
                            if agent.hooks
                            else _utils.noop_coroutine()
                        ),
                    )
                except Exception as e:
                    _utils.attach_error_to_current_span(
                        SpanError(
                            message="Error running tool",
                            data={"tool_name": func_tool.name, "error": str(e)},
                        )
                    )
                    if isinstance(e, AgentsException):
                        raise e
                    raise UserError(f"Error running tool {func_tool.name}: {e}") from e

                if config.trace_include_sensitive_data:
                    span_fn.span_data.output = result
            return result

        results = await asyncio.gather(
            *[run_single_tool(run.function_tool, run.tool_call) for run in tool_runs]
        )

        output_items: list[RunItem] = []
        for tool_run, result in zip(tool_runs, results):
            text = str(result)
            output_items.append(
                ToolCallOutputItem(
                    output=text,
                    raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, text),
                    agent=agent,
                )
            )
        return output_items

    @classmethod
    async def execute_computer_actions(
        cls,
        *,
        agent: Agent[TContext],
        actions: list[ToolRunComputerAction],
        hooks: RunHooks[TContext],
        context_wrapper: RunContextWrapper[TContext],
        config: RunConfig,
    ) -> list[RunItem]:
        """Execute the computer actions one after another and return their output items."""
        # Each action is awaited before the next starts, because an action can change the
        # computer state that later actions depend on.
        return [
            await ComputerAction.execute(
                agent=agent,
                action=action,
                hooks=hooks,
                context_wrapper=context_wrapper,
                config=config,
            )
            for action in actions
        ]

    @classmethod
    async def execute_handoffs(
        cls,
        *,
        agent: Agent[TContext],
        original_input: str | list[TResponseInputItem],
        pre_step_items: list[RunItem],
        new_step_items: list[RunItem],
        new_response: ModelResponse,
        run_handoffs: list[ToolRunHandoff],
        hooks: RunHooks[TContext],
        context_wrapper: RunContextWrapper[TContext],
        run_config: RunConfig,
    ) -> SingleStepResult:
        """Perform the first requested handoff and reject any additional ones.

        `new_step_items` is extended in place with the rejection outputs and the handoff
        output item. If an input filter is configured (on the handoff, else on the run
        config), the returned input and item lists are the filtered ones.

        Raises:
            UserError: If the input filter is not callable or does not return a
                `HandoffInputData`.
        """
        # Only the first handoff is honoured; every other one gets a rejecting tool output.
        output_message = "Multiple handoffs detected, ignoring this one."
        for ignored in run_handoffs[1:]:
            new_step_items.append(
                ToolCallOutputItem(
                    output=output_message,
                    raw_item=ItemHelpers.tool_call_output_item(ignored.tool_call, output_message),
                    agent=agent,
                )
            )

        actual_handoff = run_handoffs[0]
        with handoff_span(from_agent=agent.name) as span_handoff:
            handoff = actual_handoff.handoff
            new_agent: Agent[Any] = await handoff.on_invoke_handoff(
                context_wrapper, actual_handoff.tool_call.arguments
            )
            span_handoff.span_data.to_agent = new_agent.name

            # Record the tool output that answers the handoff call.
            new_step_items.append(
                HandoffOutputItem(
                    agent=agent,
                    raw_item=ItemHelpers.tool_call_output_item(
                        actual_handoff.tool_call,
                        handoff.get_transfer_message(new_agent),
                    ),
                    source_agent=agent,
                    target_agent=new_agent,
                )
            )

            # Notify the run-level hooks and the source agent's own hooks.
            await asyncio.gather(
                hooks.on_handoff(
                    context=context_wrapper,
                    from_agent=agent,
                    to_agent=new_agent,
                ),
                (
                    agent.hooks.on_handoff(
                        context_wrapper,
                        agent=new_agent,
                        source=agent,
                    )
                    if agent.hooks
                    else _utils.noop_coroutine()
                ),
            )

            # The handoff's own filter wins over the run-wide one.
            input_filter = handoff.input_filter or (
                run_config.handoff_input_filter if run_config else None
            )
            if input_filter:
                logger.debug("Filtering inputs for handoff")
                if isinstance(original_input, list):
                    input_history = tuple(original_input)
                else:
                    input_history = original_input
                handoff_input_data = HandoffInputData(
                    input_history=input_history,
                    pre_handoff_items=tuple(pre_step_items),
                    new_items=tuple(new_step_items),
                )

                if not callable(input_filter):
                    _utils.attach_error_to_span(
                        span_handoff,
                        SpanError(
                            message="Invalid input filter",
                            data={"details": "not callable()"},
                        ),
                    )
                    raise UserError(f"Invalid input filter: {input_filter}")

                filtered = input_filter(handoff_input_data)
                if not isinstance(filtered, HandoffInputData):
                    _utils.attach_error_to_span(
                        span_handoff,
                        SpanError(
                            message="Invalid input filter result",
                            data={"details": "not a HandoffInputData"},
                        ),
                    )
                    raise UserError(f"Invalid input filter result: {filtered}")

                # Replace the step's inputs and items with the filtered versions.
                if isinstance(filtered.input_history, str):
                    original_input = filtered.input_history
                else:
                    original_input = list(filtered.input_history)
                pre_step_items = list(filtered.pre_handoff_items)
                new_step_items = list(filtered.new_items)

        return SingleStepResult(
            original_input=original_input,
            model_response=new_response,
            pre_step_items=pre_step_items,
            new_step_items=new_step_items,
            next_step=NextStepHandoff(new_agent),
        )

    @classmethod
    async def execute_final_output(
        cls,
        *,
        agent: Agent[TContext],
        original_input: str | list[TResponseInputItem],
        new_response: ModelResponse,
        pre_step_items: list[RunItem],
        new_step_items: list[RunItem],
        final_output: Any,
        hooks: RunHooks[TContext],
        context_wrapper: RunContextWrapper[TContext],
    ) -> SingleStepResult:
        """Fire the end-of-agent hooks, then return a result carrying `final_output`."""
        await cls.run_final_output_hooks(agent, hooks, context_wrapper, final_output)

        return SingleStepResult(
            original_input=original_input,
            model_response=new_response,
            pre_step_items=pre_step_items,
            new_step_items=new_step_items,
            next_step=NextStepFinalOutput(final_output),
        )

    @classmethod
    async def run_final_output_hooks(
        cls,
        agent: Agent[TContext],
        hooks: RunHooks[TContext],
        context_wrapper: RunContextWrapper[TContext],
        final_output: Any,
    ):
        """Await the run-level `on_agent_end` hook together with the agent's `on_end` hook."""
        run_level_hook = hooks.on_agent_end(context_wrapper, agent, final_output)
        if agent.hooks:
            agent_level_hook = agent.hooks.on_end(context_wrapper, agent, final_output)
        else:
            # No agent hooks: gather a no-op coroutine in their place.
            agent_level_hook = _utils.noop_coroutine()
        await asyncio.gather(run_level_hook, agent_level_hook)

    @classmethod
    async def run_single_input_guardrail(
        cls,
        agent: Agent[Any],
        guardrail: InputGuardrail[TContext],
        input: str | list[TResponseInputItem],
        context: RunContextWrapper[TContext],
    ) -> InputGuardrailResult:
        """Run one input guardrail inside a guardrail span and return its result."""
        with guardrail_span(guardrail.get_name()) as span_guardrail:
            guardrail_result = await guardrail.run(agent, input, context)
            # Record on the span whether the tripwire fired.
            span_guardrail.span_data.triggered = guardrail_result.output.tripwire_triggered
            return guardrail_result

    @classmethod
    async def run_single_output_guardrail(
        cls,
        guardrail: OutputGuardrail[TContext],
        agent: Agent[Any],
        agent_output: Any,
        context: RunContextWrapper[TContext],
    ) -> OutputGuardrailResult:
        """Run one output guardrail inside a guardrail span and return its result."""
        with guardrail_span(guardrail.get_name()) as span_guardrail:
            guardrail_result = await guardrail.run(
                agent=agent, agent_output=agent_output, context=context
            )
            # Record on the span whether the tripwire fired.
            span_guardrail.span_data.triggered = guardrail_result.output.tripwire_triggered
            return guardrail_result

    @classmethod
    def stream_step_result_to_queue(
        cls,
        step_result: SingleStepResult,
        queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel],
    ):
        """Put one `RunItemStreamEvent` on `queue` for each new item of the step.

        Items of an unrecognised type are logged with a warning and skipped.
        """
        for item in step_result.new_step_items:
            if isinstance(item, MessageOutputItem):
                event = RunItemStreamEvent(item=item, name="message_output_created")
            elif isinstance(item, HandoffCallItem):
                event = RunItemStreamEvent(item=item, name="handoff_requested")
            elif isinstance(item, HandoffOutputItem):
                event = RunItemStreamEvent(item=item, name="handoff_occured")
            elif isinstance(item, ToolCallItem):
                event = RunItemStreamEvent(item=item, name="tool_called")
            elif isinstance(item, ToolCallOutputItem):
                event = RunItemStreamEvent(item=item, name="tool_output")
            elif isinstance(item, ReasoningItem):
                event = RunItemStreamEvent(item=item, name="reasoning_item_created")
            else:
                logger.warning(f"Unexpected item type: {type(item)}")
                continue

            queue.put_nowait(event)
class TraceCtxManager:
    """Creates a trace only if there is no current trace, and manages the trace lifecycle.

    When a trace is already current on entry, nothing is created and exiting is a no-op.
    """

    def __init__(
        self,
        workflow_name: str,
        trace_id: str | None,
        group_id: str | None,
        metadata: dict[str, Any] | None,
        disabled: bool,
    ):
        """Store the trace settings; no trace is created until the context is entered."""
        # Stays None unless `__enter__` creates a trace of its own.
        self.trace: Trace | None = None
        self.workflow_name = workflow_name
        self.trace_id = trace_id
        self.group_id = group_id
        self.metadata = metadata
        self.disabled = disabled

    def __enter__(self) -> TraceCtxManager:
        """Start a new trace and mark it as current, unless a trace is already current."""
        if not get_current_trace():
            self.trace = trace(
                workflow_name=self.workflow_name,
                trace_id=self.trace_id,
                group_id=self.group_id,
                metadata=self.metadata,
                disabled=self.disabled,
            )
            self.trace.start(mark_as_current=True)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Finish the trace if this manager created one; exceptions are not suppressed."""
        if self.trace:
            self.trace.finish(reset_current=True)
class ComputerAction:
    """Executes a single computer tool call and packages the resulting screenshot."""

    @classmethod
    async def execute(
        cls,
        *,
        agent: Agent[TContext],
        action: ToolRunComputerAction,
        hooks: RunHooks[TContext],
        context_wrapper: RunContextWrapper[TContext],
        config: RunConfig,
    ) -> RunItem:
        """Perform the action on the tool's computer and return the screenshot as an output item.

        The tool-start hooks are gathered together with the action itself; the tool-end hooks
        run afterwards with the screenshot as the result. `config` is accepted but not read.

        Returns:
            A `ToolCallOutputItem` whose output is a `data:image/png;base64,...` URL.
        """
        computer = action.computer_tool.computer
        # Pick the helper that matches the computer's flavour (async vs. sync API).
        if isinstance(computer, AsyncComputer):
            screenshot_coro = cls._get_screenshot_async(computer, action.tool_call)
        else:
            screenshot_coro = cls._get_screenshot_sync(computer, action.tool_call)

        _, _, output = await asyncio.gather(
            hooks.on_tool_start(context_wrapper, agent, action.computer_tool),
            (
                agent.hooks.on_tool_start(context_wrapper, agent, action.computer_tool)
                if agent.hooks
                else _utils.noop_coroutine()
            ),
            screenshot_coro,
        )

        await asyncio.gather(
            hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output),
            (
                agent.hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output)
                if agent.hooks
                else _utils.noop_coroutine()
            ),
        )

        # TODO: don't send a screenshot every single time, use references
        image_url = f"data:image/png;base64,{output}"
        raw_output = ComputerCallOutput(
            call_id=action.tool_call.call_id,
            output={
                "type": "computer_screenshot",
                "image_url": image_url,
            },
            type="computer_call_output",
        )
        return ToolCallOutputItem(agent=agent, output=image_url, raw_item=raw_output)

    @classmethod
    async def _get_screenshot_sync(
        cls,
        computer: Computer,
        tool_call: ResponseComputerToolCall,
    ) -> str:
        """Apply the tool call's action to a synchronous computer, then return a screenshot.

        Declared `async` so `execute` can gather it with the hooks; the computer calls
        themselves are plain synchronous calls. An action of an unrecognised type is skipped
        and only the screenshot is taken.
        """
        requested = tool_call.action
        if isinstance(requested, ActionClick):
            computer.click(requested.x, requested.y, requested.button)
        elif isinstance(requested, ActionDoubleClick):
            computer.double_click(requested.x, requested.y)
        elif isinstance(requested, ActionDrag):
            computer.drag([(point.x, point.y) for point in requested.path])
        elif isinstance(requested, ActionKeypress):
            computer.keypress(requested.keys)
        elif isinstance(requested, ActionMove):
            computer.move(requested.x, requested.y)
        elif isinstance(requested, ActionScreenshot):
            # The result is discarded here; the screenshot returned below is a second call.
            computer.screenshot()
        elif isinstance(requested, ActionScroll):
            computer.scroll(requested.x, requested.y, requested.scroll_x, requested.scroll_y)
        elif isinstance(requested, ActionType):
            computer.type(requested.text)
        elif isinstance(requested, ActionWait):
            computer.wait()

        return computer.screenshot()

    @classmethod
    async def _get_screenshot_async(
        cls,
        computer: AsyncComputer,
        tool_call: ResponseComputerToolCall,
    ) -> str:
        """Apply the tool call's action to an async computer, then return a screenshot.

        Mirrors `_get_screenshot_sync`, awaiting every computer call. An action of an
        unrecognised type is skipped and only the screenshot is taken.
        """
        requested = tool_call.action
        if isinstance(requested, ActionClick):
            await computer.click(requested.x, requested.y, requested.button)
        elif isinstance(requested, ActionDoubleClick):
            await computer.double_click(requested.x, requested.y)
        elif isinstance(requested, ActionDrag):
            await computer.drag([(point.x, point.y) for point in requested.path])
        elif isinstance(requested, ActionKeypress):
            await computer.keypress(requested.keys)
        elif isinstance(requested, ActionMove):
            await computer.move(requested.x, requested.y)
        elif isinstance(requested, ActionScreenshot):
            # The result is discarded here; the screenshot returned below is a second call.
            await computer.screenshot()
        elif isinstance(requested, ActionScroll):
            await computer.scroll(
                requested.x, requested.y, requested.scroll_x, requested.scroll_y
            )
        elif isinstance(requested, ActionType):
            await computer.type(requested.text)
        elif isinstance(requested, ActionWait):
            await computer.wait()

        return await computer.screenshot()