## Summary: #263 added tool-choice resetting, to prevent infinite loops when a tool choice is set. The key changes in this PR are: 1. Making the reset behavior configurable on the agent. 2. Tracking it with bookkeeping in the Runner, so agents are not mutated. 3. No longer resetting the global tool choice in `RunConfig`. ## Test Plan: Unit tests.
128 lines
4 KiB
Python
128 lines
4 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import AsyncIterator
|
|
from typing import Any
|
|
|
|
from openai.types.responses import Response, ResponseCompletedEvent
|
|
|
|
from agents.agent_output import AgentOutputSchema
|
|
from agents.handoffs import Handoff
|
|
from agents.items import (
|
|
ModelResponse,
|
|
TResponseInputItem,
|
|
TResponseOutputItem,
|
|
TResponseStreamEvent,
|
|
)
|
|
from agents.model_settings import ModelSettings
|
|
from agents.models.interface import Model, ModelTracing
|
|
from agents.tool import Tool
|
|
from agents.tracing import SpanError, generation_span
|
|
from agents.usage import Usage
|
|
|
|
|
|
class FakeModel(Model):
    """Scripted ``Model`` for tests that replays queued outputs, one per turn.

    Each call to ``get_response`` / ``stream_response`` consumes the next entry
    of ``turn_outputs``:

    * A list entry is used as the model output.
    * An ``Exception`` entry is attached to the generation span and then raised.
    * When the queue is empty, an empty output list is used.

    Both entry points record their arguments in ``last_turn_args`` so tests can
    inspect what the runner sent (for example ``model_settings``).
    """

    def __init__(
        self,
        tracing_enabled: bool = False,
        initial_output: list[TResponseOutputItem] | Exception | None = None,
    ):
        """Create the fake model.

        Args:
            tracing_enabled: When False, generation spans are created disabled.
            initial_output: Optional output (or exception) for the first turn.
                ``None`` and an empty list both mean "nothing queued".
        """
        if initial_output is None:
            initial_output = []
        # Queue of per-turn outputs, consumed front-first by get_next_output().
        self.turn_outputs: list[list[TResponseOutputItem] | Exception] = (
            [initial_output] if initial_output else []
        )
        self.tracing_enabled = tracing_enabled
        # Arguments of the most recent get_response/stream_response turn.
        self.last_turn_args: dict[str, Any] = {}

    def set_next_output(self, output: list[TResponseOutputItem] | Exception):
        """Append one turn's output (or exception) to the end of the queue."""
        self.turn_outputs.append(output)

    def add_multiple_turn_outputs(self, outputs: list[list[TResponseOutputItem] | Exception]):
        """Append several turns' outputs (or exceptions) to the queue, in order."""
        self.turn_outputs.extend(outputs)

    def get_next_output(self) -> list[TResponseOutputItem] | Exception:
        """Pop and return the oldest queued entry, or ``[]`` if the queue is empty."""
        if not self.turn_outputs:
            return []
        return self.turn_outputs.pop(0)

    def _record_turn_args(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchema | None,
    ) -> None:
        """Store the current turn's arguments in ``last_turn_args``."""
        self.last_turn_args = {
            "system_instructions": system_instructions,
            "input": input,
            "model_settings": model_settings,
            "tools": tools,
            "output_schema": output_schema,
        }

    @staticmethod
    def _fail_span(span: Any, error: Exception) -> None:
        """Attach *error* to the generation *span* as a ``SpanError``."""
        span.set_error(
            SpanError(
                message="Error",
                data={
                    "name": error.__class__.__name__,
                    "message": str(error),
                },
            )
        )

    async def get_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchema | None,
        handoffs: list[Handoff],
        tracing: ModelTracing,
    ) -> ModelResponse:
        """Return the next queued output wrapped in a ``ModelResponse``.

        ``handoffs`` and ``tracing`` are accepted for interface compatibility
        but are not used.

        Raises:
            Exception: the queued exception, if the next entry is one.
        """
        self._record_turn_args(system_instructions, input, model_settings, tools, output_schema)

        with generation_span(disabled=not self.tracing_enabled) as span:
            output = self.get_next_output()

            if isinstance(output, Exception):
                self._fail_span(span, output)
                raise output

            return ModelResponse(
                output=output,
                usage=Usage(),
                referenceable_id=None,
            )

    async def stream_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchema | None,
        handoffs: list[Handoff],
        tracing: ModelTracing,
    ) -> AsyncIterator[TResponseStreamEvent]:
        """Yield a single ``response.completed`` event for the next queued output.

        Like ``get_response``, this records the turn's arguments in
        ``last_turn_args``. Because this is an async generator, the recording
        happens when iteration starts, not when the method is called.

        Raises:
            Exception: the queued exception, if the next entry is one.
        """
        # Previously only get_response recorded arguments, so streamed turns
        # left last_turn_args stale or empty.
        self._record_turn_args(system_instructions, input, model_settings, tools, output_schema)

        with generation_span(disabled=not self.tracing_enabled) as span:
            output = self.get_next_output()
            if isinstance(output, Exception):
                self._fail_span(span, output)
                raise output

            yield ResponseCompletedEvent(
                type="response.completed",
                response=get_response_obj(output),
            )
|
|
|
|
|
|
def get_response_obj(output: list[TResponseOutputItem], response_id: str | None = None) -> Response:
    """Build a minimal ``Response`` that carries *output*.

    All metadata is fixed placeholder data: ``created_at`` is 123, ``model``
    is ``"test_model"``, ``tool_choice`` is ``"none"``, the tool list is empty,
    ``top_p`` is unset and parallel tool calls are disabled.

    Args:
        output: The output items to place in the response.
        response_id: Identifier to use. Any falsy value (``None`` or ``""``)
            falls back to ``"123"``.
    """
    resolved_id = response_id if response_id else "123"
    return Response(
        id=resolved_id,
        object="response",
        created_at=123,
        model="test_model",
        output=output,
        # Tool configuration placeholders: no tools, no forced tool choice.
        tools=[],
        tool_choice="none",
        parallel_tool_calls=False,
        top_p=None,
    )
|