904 lines
36 KiB
Python
904 lines
36 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import copy
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, cast
|
|
|
|
from openai.types.responses import ResponseCompletedEvent
|
|
|
|
from . import Model, _utils
|
|
from ._run_impl import (
|
|
NextStepFinalOutput,
|
|
NextStepHandoff,
|
|
NextStepRunAgain,
|
|
QueueCompleteSentinel,
|
|
RunImpl,
|
|
SingleStepResult,
|
|
TraceCtxManager,
|
|
get_model_tracing_impl,
|
|
)
|
|
from .agent import Agent
|
|
from .agent_output import AgentOutputSchema
|
|
from .exceptions import (
|
|
AgentsException,
|
|
InputGuardrailTripwireTriggered,
|
|
MaxTurnsExceeded,
|
|
ModelBehaviorError,
|
|
OutputGuardrailTripwireTriggered,
|
|
)
|
|
from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
|
|
from .handoffs import Handoff, HandoffInputFilter, handoff
|
|
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
|
|
from .lifecycle import RunHooks
|
|
from .logger import logger
|
|
from .model_settings import ModelSettings
|
|
from .models.interface import ModelProvider
|
|
from .models.openai_provider import OpenAIProvider
|
|
from .result import RunResult, RunResultStreaming
|
|
from .run_context import RunContextWrapper, TContext
|
|
from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent
|
|
from .tracing import Span, SpanError, agent_span, get_current_trace, trace
|
|
from .tracing.span_data import AgentSpanData
|
|
from .usage import Usage
|
|
|
|
# Default value of the `max_turns` argument accepted by `Runner.run`,
# `Runner.run_sync` and `Runner.run_streamed`.
DEFAULT_MAX_TURNS = 10
|
|
|
|
|
|
@dataclass
class RunConfig:
    """Configures settings for the entire agent run.

    A single instance is passed to the `Runner` entry points and is consulted on every turn:
    it can override the model and model settings of each agent, contribute run-level input and
    output guardrails in addition to the agents' own, supply a fallback handoff input filter,
    and control how the run is traced. Every field has a default, so `RunConfig()` is valid.
    """

    model: str | Model | None = None
    """The model to use for the entire agent run. If set, will override the model set on every
    agent. The model_provider passed in below must be able to resolve this model name.
    """

    model_provider: ModelProvider = field(default_factory=OpenAIProvider)
    """The model provider to use when looking up string model names. Defaults to OpenAI."""

    model_settings: ModelSettings | None = None
    """Configure global model settings. Any non-null values will override the agent-specific model
    settings.
    """

    handoff_input_filter: HandoffInputFilter | None = None
    """A global input filter to apply to all handoffs. If `Handoff.input_filter` is set, then that
    will take precedence. The input filter allows you to edit the inputs that are sent to the new
    agent. See the documentation in `Handoff.input_filter` for more details.
    """

    input_guardrails: list[InputGuardrail[Any]] | None = None
    """A list of input guardrails to run on the initial run input."""

    output_guardrails: list[OutputGuardrail[Any]] | None = None
    """A list of output guardrails to run on the final output of the run."""

    tracing_disabled: bool = False
    """Whether tracing is disabled for the agent run. If disabled, we will not trace the agent run.
    """

    trace_include_sensitive_data: bool = True
    """Whether we include potentially sensitive data (for example: inputs/outputs of tool calls or
    LLM generations) in traces. If False, we'll still create spans for these events, but the
    sensitive data will not be included.
    """

    workflow_name: str = "Agent workflow"
    """The name of the run, used for tracing. Should be a logical name for the run, like
    "Code generation workflow" or "Customer support agent".
    """

    trace_id: str | None = None
    """A custom trace ID to use for tracing. If not provided, we will generate a new trace ID."""

    group_id: str | None = None
    """
    A grouping identifier to use for tracing, to link multiple traces from the same conversation
    or process. For example, you might use a chat thread ID.
    """

    trace_metadata: dict[str, Any] | None = None
    """
    An optional dictionary of additional metadata to include with the trace.
    """
|
|
|
|
|
|
class Runner:
|
|
@classmethod
async def run(
    cls,
    starting_agent: Agent[TContext],
    input: str | list[TResponseInputItem],
    *,
    context: TContext | None = None,
    max_turns: int = DEFAULT_MAX_TURNS,
    hooks: RunHooks[TContext] | None = None,
    run_config: RunConfig | None = None,
) -> RunResult:
    """Run a workflow starting at the given agent. The agent will run in a loop until a final
    output is generated. The loop runs like so:
    1. The agent is invoked with the given input.
    2. If there is a final output (i.e. the agent produces something of type
        `agent.output_type`), the loop terminates.
    3. If there's a handoff, we run the loop again, with the new agent.
    4. Else, we run tool calls (if any), and re-run the loop.

    In two cases, the agent may raise an exception:
    1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
    2. If a guardrail tripwire is triggered, an InputGuardrailTripwireTriggered or
        OutputGuardrailTripwireTriggered exception is raised.

    Note that only the first agent's input guardrails are run.

    Args:
        starting_agent: The starting agent to run.
        input: The initial input to the agent. You can pass a single string for a user message,
            or a list of input items.
        context: The context to run the agent with.
        max_turns: The maximum number of turns to run the agent for. A turn is defined as one
            AI invocation (including any tool calls that might occur).
        hooks: An object that receives callbacks on various lifecycle events.
        run_config: Global settings for the entire agent run.

    Returns:
        A run result containing all the inputs, guardrail results and the output of the last
        agent. Agents may perform handoffs, so we don't know the specific type of the output.
    """
    if hooks is None:
        hooks = RunHooks[Any]()
    if run_config is None:
        run_config = RunConfig()

    with TraceCtxManager(
        workflow_name=run_config.workflow_name,
        trace_id=run_config.trace_id,
        group_id=run_config.group_id,
        metadata=run_config.trace_metadata,
        disabled=run_config.tracing_disabled,
    ):
        current_turn = 0
        # Work on a deep copy of the caller's input rather than the object that was passed in.
        original_input: str | list[TResponseInputItem] = copy.deepcopy(input)
        generated_items: list[RunItem] = []
        model_responses: list[ModelResponse] = []

        context_wrapper: RunContextWrapper[TContext] = RunContextWrapper(
            context=context,  # type: ignore
        )

        input_guardrail_results: list[InputGuardrailResult] = []

        current_span: Span[AgentSpanData] | None = None
        current_agent = starting_agent
        # True on the first turn and again after every handoff, so the agent-start hooks fire
        # once per agent rather than once per turn.
        should_run_agent_start_hooks = True

        try:
            while True:
                # Start an agent span if we don't have one. This span is ended if the current
                # agent changes, or if the agent loop ends.
                if current_span is None:
                    handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)]
                    tool_names = [t.name for t in current_agent.tools]
                    if output_schema := cls._get_output_schema(current_agent):
                        output_type_name = output_schema.output_type_name()
                    else:
                        output_type_name = "str"

                    current_span = agent_span(
                        name=current_agent.name,
                        handoffs=handoff_names,
                        tools=tool_names,
                        output_type=output_type_name,
                    )
                    current_span.start(mark_as_current=True)

                current_turn += 1
                if current_turn > max_turns:
                    # Record the failure on the open span before raising.
                    _utils.attach_error_to_span(
                        current_span,
                        SpanError(
                            message="Max turns exceeded",
                            data={"max_turns": max_turns},
                        ),
                    )
                    raise MaxTurnsExceeded(f"Max turns ({max_turns}) exceeded")

                logger.debug(
                    f"Running agent {current_agent.name} (turn {current_turn})",
                )

                if current_turn == 1:
                    # First turn only: the input guardrails (the starting agent's plus the
                    # run-level ones) run concurrently with the first model turn.
                    input_guardrail_results, turn_result = await asyncio.gather(
                        cls._run_input_guardrails(
                            starting_agent,
                            starting_agent.input_guardrails
                            + (run_config.input_guardrails or []),
                            copy.deepcopy(input),
                            context_wrapper,
                        ),
                        cls._run_single_turn(
                            agent=current_agent,
                            original_input=original_input,
                            generated_items=generated_items,
                            hooks=hooks,
                            context_wrapper=context_wrapper,
                            run_config=run_config,
                            should_run_agent_start_hooks=should_run_agent_start_hooks,
                        ),
                    )
                else:
                    turn_result = await cls._run_single_turn(
                        agent=current_agent,
                        original_input=original_input,
                        generated_items=generated_items,
                        hooks=hooks,
                        context_wrapper=context_wrapper,
                        run_config=run_config,
                        should_run_agent_start_hooks=should_run_agent_start_hooks,
                    )
                should_run_agent_start_hooks = False

                # Carry the turn's results forward into the next iteration.
                model_responses.append(turn_result.model_response)
                original_input = turn_result.original_input
                generated_items = turn_result.generated_items

                if isinstance(turn_result.next_step, NextStepFinalOutput):
                    # Output guardrails of the agent that produced the final output, plus the
                    # run-level ones, are checked before the result is returned.
                    output_guardrail_results = await cls._run_output_guardrails(
                        current_agent.output_guardrails + (run_config.output_guardrails or []),
                        current_agent,
                        turn_result.next_step.output,
                        context_wrapper,
                    )
                    return RunResult(
                        input=original_input,
                        new_items=generated_items,
                        raw_responses=model_responses,
                        final_output=turn_result.next_step.output,
                        _last_agent=current_agent,
                        input_guardrail_results=input_guardrail_results,
                        output_guardrail_results=output_guardrail_results,
                    )
                elif isinstance(turn_result.next_step, NextStepHandoff):
                    current_agent = cast(Agent[TContext], turn_result.next_step.new_agent)
                    # Close the previous agent's span; a new one is opened at the top of the
                    # next iteration for the new agent.
                    current_span.finish(reset_current=True)
                    current_span = None
                    should_run_agent_start_hooks = True
                elif isinstance(turn_result.next_step, NextStepRunAgain):
                    pass
                else:
                    raise AgentsException(
                        f"Unknown next step type: {type(turn_result.next_step)}"
                    )
        finally:
            # Runs on the normal return above as well as on any exception.
            if current_span:
                current_span.finish(reset_current=True)
|
|
|
|
@classmethod
def run_sync(
    cls,
    starting_agent: Agent[TContext],
    input: str | list[TResponseInputItem],
    *,
    context: TContext | None = None,
    max_turns: int = DEFAULT_MAX_TURNS,
    hooks: RunHooks[TContext] | None = None,
    run_config: RunConfig | None = None,
) -> RunResult:
    """Blocking counterpart of `run`.

    The coroutine returned by `run` is driven to completion on the event loop obtained from
    `asyncio.get_event_loop()`. Because of that, this method will not work if there's already
    an event loop running (e.g. inside an async function, or in a Jupyter notebook or async
    context like FastAPI). For those cases, use the `run` method instead.

    The agent loop itself is the one implemented by `run`:
    1. The agent is invoked with the given input.
    2. If there is a final output (i.e. the agent produces something of type
        `agent.output_type`), the loop terminates.
    3. If there's a handoff, the loop runs again with the new agent.
    4. Otherwise tool calls (if any) are executed and the loop runs again.

    Exceptions raised by `run` (max turns exceeded, guardrail tripwires) propagate unchanged.
    Only the first agent's input guardrails are run.

    Args:
        starting_agent: The starting agent to run.
        input: The initial input to the agent. You can pass a single string for a user message,
            or a list of input items.
        context: The context to run the agent with.
        max_turns: The maximum number of turns to run the agent for. A turn is defined as one
            AI invocation (including any tool calls that might occur).
        hooks: An object that receives callbacks on various lifecycle events.
        run_config: Global settings for the entire agent run.

    Returns:
        A run result containing all the inputs, guardrail results and the output of the last
        agent. Agents may perform handoffs, so we don't know the specific type of the output.
    """
    # Fetch the loop first so that no coroutine object is created if this lookup fails.
    event_loop = asyncio.get_event_loop()
    pending_run = cls.run(
        starting_agent,
        input,
        context=context,
        max_turns=max_turns,
        hooks=hooks,
        run_config=run_config,
    )
    return event_loop.run_until_complete(pending_run)
|
|
|
|
@classmethod
def run_streamed(
    cls,
    starting_agent: Agent[TContext],
    input: str | list[TResponseInputItem],
    context: TContext | None = None,
    max_turns: int = DEFAULT_MAX_TURNS,
    hooks: RunHooks[TContext] | None = None,
    run_config: RunConfig | None = None,
) -> RunResultStreaming:
    """Start a streamed workflow at the given agent and return immediately.

    The agent loop is scheduled as a background asyncio task; the returned result object
    contains a method you can use to stream semantic events as they are generated.

    The loop follows the same steps as the non-streamed run:
    1. The agent is invoked with the given input.
    2. If there is a final output (i.e. the agent produces something of type
        `agent.output_type`), the loop terminates.
    3. If there's a handoff, the loop runs again with the new agent.
    4. Otherwise tool calls (if any) are executed and the loop runs again.

    Only the first agent's input guardrails are run.

    Args:
        starting_agent: The starting agent to run.
        input: The initial input to the agent. You can pass a single string for a user message,
            or a list of input items.
        context: The context to run the agent with.
        max_turns: The maximum number of turns to run the agent for. A turn is defined as one
            AI invocation (including any tool calls that might occur).
        hooks: An object that receives callbacks on various lifecycle events.
        run_config: Global settings for the entire agent run.

    Returns:
        A result object that contains data about the run, as well as a method to stream events.
    """
    hooks = RunHooks[Any]() if hooks is None else hooks
    run_config = RunConfig() if run_config is None else run_config

    # Reuse an already-active trace when there is one. A trace created here cannot be ended
    # here either: the actual work happens while events are streamed, after this method
    # has returned.
    if get_current_trace():
        new_trace = None
    else:
        new_trace = trace(
            workflow_name=run_config.workflow_name,
            trace_id=run_config.trace_id,
            group_id=run_config.group_id,
            metadata=run_config.trace_metadata,
            disabled=run_config.tracing_disabled,
        )

    # The trace has to be started before the task below is created, because the current
    # trace contextvar is captured at asyncio.create_task time.
    if new_trace:
        new_trace.start(mark_as_current=True)

    output_schema = cls._get_output_schema(starting_agent)
    context_wrapper: RunContextWrapper[TContext] = RunContextWrapper(
        context=context  # type: ignore
    )

    streamed_result = RunResultStreaming(
        input=copy.deepcopy(input),
        new_items=[],
        current_agent=starting_agent,
        raw_responses=[],
        final_output=None,
        is_complete=False,
        current_turn=0,
        max_turns=max_turns,
        input_guardrail_results=[],
        output_guardrail_results=[],
        _current_agent_output_schema=output_schema,
        _trace=new_trace,
    )

    # Launch the agent loop in the background; the caller consumes it via the result object.
    agent_loop = cls._run_streamed_impl(
        starting_input=input,
        streamed_result=streamed_result,
        starting_agent=starting_agent,
        max_turns=max_turns,
        hooks=hooks,
        context_wrapper=context_wrapper,
        run_config=run_config,
    )
    streamed_result._run_impl_task = asyncio.create_task(agent_loop)
    return streamed_result
|
|
|
|
@classmethod
async def _run_input_guardrails_with_queue(
    cls,
    agent: Agent[Any],
    guardrails: list[InputGuardrail[TContext]],
    input: str | list[TResponseInputItem],
    context: RunContextWrapper[TContext],
    streamed_result: RunResultStreaming,
    parent_span: Span[Any],
):
    """Run the input guardrails concurrently and publish each result as it finishes.

    Every finished result is pushed onto `streamed_result._input_guardrail_queue`. A tripwire
    is recorded as an error on `parent_span`; this method does not raise for it. Once all
    guardrails have finished, the collected results are stored on
    `streamed_result.input_guardrail_results`. If awaiting a guardrail raises, all guardrail
    tasks are cancelled and the exception is re-raised.
    """
    results_queue = streamed_result._input_guardrail_queue

    # One task per guardrail, so they all run concurrently.
    pending_tasks = [
        asyncio.create_task(
            RunImpl.run_single_input_guardrail(agent, single_guardrail, input, context)
        )
        for single_guardrail in guardrails
    ]
    finished_results = []
    try:
        for next_finished in asyncio.as_completed(pending_tasks):
            outcome = await next_finished
            if outcome.output.tripwire_triggered:
                tripwire_error = SpanError(
                    message="Guardrail tripwire triggered",
                    data={
                        "guardrail": outcome.guardrail.get_name(),
                        "type": "input_guardrail",
                    },
                )
                _utils.attach_error_to_span(parent_span, tripwire_error)
            # Tripped or not, the result is always published and recorded.
            results_queue.put_nowait(outcome)
            finished_results.append(outcome)
    except Exception:
        for task in pending_tasks:
            task.cancel()
        raise

    streamed_result.input_guardrail_results = finished_results
|
|
|
|
@classmethod
async def _run_streamed_impl(
    cls,
    starting_input: str | list[TResponseInputItem],
    streamed_result: RunResultStreaming,
    starting_agent: Agent[TContext],
    max_turns: int,
    hooks: RunHooks[TContext],
    context_wrapper: RunContextWrapper[TContext],
    run_config: RunConfig,
):
    """Drive the agent loop for a streamed run (scheduled as a task by `run_streamed`).

    Nothing is returned: progress is reported by mutating `streamed_result` (input, new items,
    raw responses, guardrail results, final output, `is_complete`) and by pushing events onto
    `streamed_result._event_queue`. A `QueueCompleteSentinel` is pushed when the loop ends
    because of a final output, because `max_turns` was exceeded, or because a turn raised.

    Args:
        starting_input: The input originally given to the run; used for the input guardrails.
        streamed_result: The result object shared with the caller of `run_streamed`.
        starting_agent: The first agent to run; its input guardrails are the ones executed.
        max_turns: Maximum number of turns before the loop stops.
        hooks: Run-level lifecycle hooks.
        context_wrapper: Wrapper around the user-supplied context.
        run_config: Global settings for the run.
    """
    current_span: Span[AgentSpanData] | None = None
    current_agent = starting_agent
    current_turn = 0
    should_run_agent_start_hooks = True

    # Announce the starting agent before any turn runs.
    streamed_result._event_queue.put_nowait(AgentUpdatedStreamEvent(new_agent=current_agent))

    try:
        while True:
            if streamed_result.is_complete:
                break

            # Start an agent span if we don't have one. This span is ended if the current
            # agent changes, or if the agent loop ends.
            if current_span is None:
                handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)]
                tool_names = [t.name for t in current_agent.tools]
                if output_schema := cls._get_output_schema(current_agent):
                    output_type_name = output_schema.output_type_name()
                else:
                    output_type_name = "str"

                current_span = agent_span(
                    name=current_agent.name,
                    handoffs=handoff_names,
                    tools=tool_names,
                    output_type=output_type_name,
                )
                current_span.start(mark_as_current=True)

            current_turn += 1
            streamed_result.current_turn = current_turn

            if current_turn > max_turns:
                # Unlike the non-streamed `run`, nothing is raised here: the error is recorded
                # on the span and the event queue is closed with the sentinel.
                # NOTE(review): presumably the queue consumer raises MaxTurnsExceeded based on
                # `current_turn`/`max_turns` -- confirm in RunResultStreaming.
                _utils.attach_error_to_span(
                    current_span,
                    SpanError(
                        message="Max turns exceeded",
                        data={"max_turns": max_turns},
                    ),
                )
                streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
                break

            if current_turn == 1:
                # Run the input guardrails in the background and put the results on the queue
                streamed_result._input_guardrails_task = asyncio.create_task(
                    cls._run_input_guardrails_with_queue(
                        starting_agent,
                        starting_agent.input_guardrails + (run_config.input_guardrails or []),
                        copy.deepcopy(ItemHelpers.input_to_new_input_list(starting_input)),
                        context_wrapper,
                        streamed_result,
                        current_span,
                    )
                )
            try:
                turn_result = await cls._run_single_turn_streamed(
                    streamed_result,
                    current_agent,
                    hooks,
                    context_wrapper,
                    run_config,
                    should_run_agent_start_hooks,
                )
                should_run_agent_start_hooks = False

                # A new list is assigned rather than appending to the existing one in place.
                streamed_result.raw_responses = streamed_result.raw_responses + [
                    turn_result.model_response
                ]
                streamed_result.input = turn_result.original_input
                streamed_result.new_items = turn_result.generated_items

                if isinstance(turn_result.next_step, NextStepHandoff):
                    current_agent = turn_result.next_step.new_agent
                    # Close the previous agent's span; a new one is opened at the top of the
                    # next iteration for the new agent.
                    current_span.finish(reset_current=True)
                    current_span = None
                    should_run_agent_start_hooks = True
                    streamed_result._event_queue.put_nowait(
                        AgentUpdatedStreamEvent(new_agent=current_agent)
                    )
                elif isinstance(turn_result.next_step, NextStepFinalOutput):
                    # The task is stored on the result object as well as awaited here.
                    streamed_result._output_guardrails_task = asyncio.create_task(
                        cls._run_output_guardrails(
                            current_agent.output_guardrails
                            + (run_config.output_guardrails or []),
                            current_agent,
                            turn_result.next_step.output,
                            context_wrapper,
                        )
                    )

                    try:
                        output_guardrail_results = await streamed_result._output_guardrails_task
                    except Exception:
                        # Exceptions will be checked in the stream_events loop
                        output_guardrail_results = []

                    streamed_result.output_guardrail_results = output_guardrail_results
                    streamed_result.final_output = turn_result.next_step.output
                    streamed_result.is_complete = True
                    streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
                elif isinstance(turn_result.next_step, NextStepRunAgain):
                    pass
            except Exception as e:
                # Record the failure on the span (if one is open), mark the run complete and
                # close the queue before re-raising.
                if current_span:
                    _utils.attach_error_to_span(
                        current_span,
                        SpanError(
                            message="Error in agent run",
                            data={"error": str(e)},
                        ),
                    )
                streamed_result.is_complete = True
                streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
                raise

        # Reached via either `break` above (already complete, or max turns exceeded).
        streamed_result.is_complete = True
    finally:
        if current_span:
            current_span.finish(reset_current=True)
|
|
|
|
@classmethod
async def _run_single_turn_streamed(
    cls,
    streamed_result: RunResultStreaming,
    agent: Agent[TContext],
    hooks: RunHooks[TContext],
    context_wrapper: RunContextWrapper[TContext],
    run_config: RunConfig,
    should_run_agent_start_hooks: bool,
) -> SingleStepResult:
    """Run one turn of the agent loop in streaming mode.

    The model is invoked through `stream_response`; every raw event is forwarded to
    `streamed_result._event_queue` as a `RawResponsesStreamEvent`. The final response is
    taken from the `ResponseCompletedEvent`, its token usage is added to
    `context_wrapper.usage` (as the non-streamed path does), and it is then processed exactly
    like a non-streamed response. The resulting step is also streamed to the event queue.

    Args:
        streamed_result: Shared result object; supplies the input and previously generated
            items, and receives the current agent, its output schema and the queued events.
        agent: The agent to run for this turn.
        hooks: Run-level lifecycle hooks.
        context_wrapper: Wrapper around the user-supplied context; accumulates usage.
        run_config: Global settings for the run.
        should_run_agent_start_hooks: Whether to fire the agent-start hooks before the turn.

    Returns:
        The processed result of this turn.

    Raises:
        ModelBehaviorError: If the stream ends without a `ResponseCompletedEvent`.
    """
    if should_run_agent_start_hooks:
        await asyncio.gather(
            hooks.on_agent_start(context_wrapper, agent),
            (
                agent.hooks.on_start(context_wrapper, agent)
                if agent.hooks
                else _utils.noop_coroutine()
            ),
        )

    output_schema = cls._get_output_schema(agent)

    streamed_result.current_agent = agent
    streamed_result._current_agent_output_schema = output_schema

    system_prompt = await agent.get_system_prompt(context_wrapper)

    handoffs = cls._get_handoffs(agent)

    model = cls._get_model(agent, run_config)
    model_settings = agent.model_settings.resolve(run_config.model_settings)
    final_response: ModelResponse | None = None

    # Model input = the run's input followed by everything generated in earlier turns.
    input = ItemHelpers.input_to_new_input_list(streamed_result.input)
    input.extend([item.to_input_item() for item in streamed_result.new_items])

    # 1. Stream the output events
    async for event in model.stream_response(
        system_prompt,
        input,
        model_settings,
        agent.tools,
        output_schema,
        handoffs,
        get_model_tracing_impl(
            run_config.tracing_disabled, run_config.trace_include_sensitive_data
        ),
    ):
        if isinstance(event, ResponseCompletedEvent):
            # The completed event carries the full output; usage may be absent.
            usage = (
                Usage(
                    requests=1,
                    input_tokens=event.response.usage.input_tokens,
                    output_tokens=event.response.usage.output_tokens,
                    total_tokens=event.response.usage.total_tokens,
                )
                if event.response.usage
                else Usage()
            )
            final_response = ModelResponse(
                output=event.response.output,
                usage=usage,
                referenceable_id=event.response.id,
            )

        streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))

    # 2. At this point, the streaming is complete for this turn of the agent loop.
    if not final_response:
        raise ModelBehaviorError("Model did not produce a final response!")

    # Account for this turn's tokens on the run context, mirroring `_get_new_response`.
    # Without this, streamed runs never accumulate usage.
    context_wrapper.usage.add(final_response.usage)

    # 3. Now, we can process the turn as we do in the non-streaming case
    single_step_result = await cls._get_single_step_result_from_response(
        agent=agent,
        original_input=streamed_result.input,
        pre_step_items=streamed_result.new_items,
        new_response=final_response,
        output_schema=output_schema,
        handoffs=handoffs,
        hooks=hooks,
        context_wrapper=context_wrapper,
        run_config=run_config,
    )

    RunImpl.stream_step_result_to_queue(single_step_result, streamed_result._event_queue)
    return single_step_result
|
|
|
|
@classmethod
async def _run_single_turn(
    cls,
    *,
    agent: Agent[TContext],
    original_input: str | list[TResponseInputItem],
    generated_items: list[RunItem],
    hooks: RunHooks[TContext],
    context_wrapper: RunContextWrapper[TContext],
    run_config: RunConfig,
    should_run_agent_start_hooks: bool,
) -> SingleStepResult:
    """Run one non-streamed turn: optional start hooks, one model call, then processing.

    Args:
        agent: The agent to run for this turn.
        original_input: The run's input.
        generated_items: Items produced by earlier turns; appended to the model input.
        hooks: Run-level lifecycle hooks.
        context_wrapper: Wrapper around the user-supplied context.
        run_config: Global settings for the run.
        should_run_agent_start_hooks: Whether to fire the agent-start hooks first.

    Returns:
        The processed result of this turn.
    """
    # Hooks go first, before any other work for this turn.
    if should_run_agent_start_hooks:
        run_level_start = hooks.on_agent_start(context_wrapper, agent)
        if agent.hooks:
            agent_level_start = agent.hooks.on_start(context_wrapper, agent)
        else:
            agent_level_start = _utils.noop_coroutine()
        await asyncio.gather(run_level_start, agent_level_start)

    system_prompt = await agent.get_system_prompt(context_wrapper)

    output_schema = cls._get_output_schema(agent)
    handoffs = cls._get_handoffs(agent)

    # Model input = the run's input followed by everything generated in earlier turns.
    model_input = ItemHelpers.input_to_new_input_list(original_input)
    earlier_items = [earlier.to_input_item() for earlier in generated_items]
    model_input.extend(earlier_items)

    new_response = await cls._get_new_response(
        agent,
        system_prompt,
        model_input,
        output_schema,
        handoffs,
        context_wrapper,
        run_config,
    )

    step_result = await cls._get_single_step_result_from_response(
        agent=agent,
        original_input=original_input,
        pre_step_items=generated_items,
        new_response=new_response,
        output_schema=output_schema,
        handoffs=handoffs,
        hooks=hooks,
        context_wrapper=context_wrapper,
        run_config=run_config,
    )
    return step_result
|
|
|
|
@classmethod
async def _get_single_step_result_from_response(
    cls,
    *,
    agent: Agent[TContext],
    original_input: str | list[TResponseInputItem],
    pre_step_items: list[RunItem],
    new_response: ModelResponse,
    output_schema: AgentOutputSchema | None,
    handoffs: list[Handoff],
    hooks: RunHooks[TContext],
    context_wrapper: RunContextWrapper[TContext],
    run_config: RunConfig,
) -> SingleStepResult:
    """Turn a raw model response into the result of one agent-loop step.

    The response is first handed to `RunImpl.process_model_response`; the processed form is
    then passed to `RunImpl.execute_tools_and_side_effects`, whose result is returned.
    """
    step_result = await RunImpl.execute_tools_and_side_effects(
        agent=agent,
        original_input=original_input,
        pre_step_items=pre_step_items,
        new_response=new_response,
        processed_response=RunImpl.process_model_response(
            agent=agent,
            response=new_response,
            output_schema=output_schema,
            handoffs=handoffs,
        ),
        output_schema=output_schema,
        hooks=hooks,
        context_wrapper=context_wrapper,
        run_config=run_config,
    )
    return step_result
|
|
|
|
@classmethod
async def _run_input_guardrails(
    cls,
    agent: Agent[Any],
    guardrails: list[InputGuardrail[TContext]],
    input: str | list[TResponseInputItem],
    context: RunContextWrapper[TContext],
) -> list[InputGuardrailResult]:
    """Run all input guardrails concurrently and collect their results.

    Returns:
        The results in completion order; an empty list when there are no guardrails.

    Raises:
        InputGuardrailTripwireTriggered: As soon as one guardrail reports a tripwire. All
            guardrail tasks are cancelled and the error is attached to the current span first.
    """
    if not guardrails:
        return []

    pending_tasks = [
        asyncio.create_task(
            RunImpl.run_single_input_guardrail(agent, single_guardrail, input, context)
        )
        for single_guardrail in guardrails
    ]

    passed_results = []

    for next_finished in asyncio.as_completed(pending_tasks):
        outcome = await next_finished
        if not outcome.output.tripwire_triggered:
            passed_results.append(outcome)
            continue

        # A tripwire fired: stop every guardrail task, record the error, and abort.
        for task in pending_tasks:
            task.cancel()
        tripwire_error = SpanError(
            message="Guardrail tripwire triggered",
            data={"guardrail": outcome.guardrail.get_name()},
        )
        _utils.attach_error_to_current_span(tripwire_error)
        raise InputGuardrailTripwireTriggered(outcome)

    return passed_results
|
|
|
|
@classmethod
async def _run_output_guardrails(
    cls,
    guardrails: list[OutputGuardrail[TContext]],
    agent: Agent[TContext],
    agent_output: Any,
    context: RunContextWrapper[TContext],
) -> list[OutputGuardrailResult]:
    """Run all output guardrails concurrently against the agent's final output.

    Returns:
        The results in completion order; an empty list when there are no guardrails.

    Raises:
        OutputGuardrailTripwireTriggered: As soon as one guardrail reports a tripwire. All
            guardrail tasks are cancelled and the error is attached to the current span first.
    """
    if not guardrails:
        return []

    pending_tasks = [
        asyncio.create_task(
            RunImpl.run_single_output_guardrail(single_guardrail, agent, agent_output, context)
        )
        for single_guardrail in guardrails
    ]

    passed_results = []

    for next_finished in asyncio.as_completed(pending_tasks):
        outcome = await next_finished
        if not outcome.output.tripwire_triggered:
            passed_results.append(outcome)
            continue

        # A tripwire fired: stop every guardrail task, record the error, and abort.
        for task in pending_tasks:
            task.cancel()
        tripwire_error = SpanError(
            message="Guardrail tripwire triggered",
            data={"guardrail": outcome.guardrail.get_name()},
        )
        _utils.attach_error_to_current_span(tripwire_error)
        raise OutputGuardrailTripwireTriggered(outcome)

    return passed_results
|
|
|
|
@classmethod
async def _get_new_response(
    cls,
    agent: Agent[TContext],
    system_prompt: str | None,
    input: list[TResponseInputItem],
    output_schema: AgentOutputSchema | None,
    handoffs: list[Handoff],
    context_wrapper: RunContextWrapper[TContext],
    run_config: RunConfig,
) -> ModelResponse:
    """Call the model once (non-streamed) and record its usage on the run context.

    The model is resolved via `_get_model`, and the agent's model settings are resolved
    against `run_config.model_settings`. The usage of the returned response is added to
    `context_wrapper.usage` before the response is returned.
    """
    resolved_model = cls._get_model(agent, run_config)
    effective_settings = agent.model_settings.resolve(run_config.model_settings)
    agent_tools = agent.tools
    tracing_mode = get_model_tracing_impl(
        run_config.tracing_disabled, run_config.trace_include_sensitive_data
    )

    response = await resolved_model.get_response(
        system_instructions=system_prompt,
        input=input,
        model_settings=effective_settings,
        tools=agent_tools,
        output_schema=output_schema,
        handoffs=handoffs,
        tracing=tracing_mode,
    )

    context_wrapper.usage.add(response.usage)
    return response
|
|
|
|
@classmethod
def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchema | None:
    """Return the output schema for `agent`, or None when its output is plain text.

    An `output_type` of `None` or `str` yields None; any other type is wrapped in an
    `AgentOutputSchema`.
    """
    declared_type = agent.output_type
    if declared_type is None or declared_type is str:
        return None
    return AgentOutputSchema(declared_type)
|
|
|
|
@classmethod
def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]:
    """Normalize `agent.handoffs` into a list of `Handoff` objects.

    `Handoff` entries are kept as they are, `Agent` entries are wrapped with `handoff(...)`,
    and entries of any other type are skipped. Order is preserved.
    """
    resolved = []
    for entry in agent.handoffs:
        if isinstance(entry, Handoff):
            resolved.append(entry)
            continue
        if isinstance(entry, Agent):
            resolved.append(handoff(entry))
    return resolved
|
|
|
|
@classmethod
def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
    """Pick the model for this agent; a model set on the run config takes precedence.

    Resolution order:
    1. `run_config.model` when it is a `Model` instance.
    2. `run_config.model` when it is a string, looked up through the run's model provider.
    3. `agent.model` when it is a `Model` instance.
    4. Otherwise `agent.model` is looked up through the run's model provider.
    """
    run_level_model = run_config.model
    if isinstance(run_level_model, Model):
        return run_level_model
    if isinstance(run_level_model, str):
        return run_config.model_provider.get_model(run_level_model)

    agent_level_model = agent.model
    if isinstance(agent_level_model, Model):
        return agent_level_model

    return run_config.model_provider.get_model(agent_level_model)
|