from __future__ import annotations from collections.abc import AsyncIterator from openai.types.responses import Response, ResponseCompletedEvent from agents.agent_output import AgentOutputSchema from agents.handoffs import Handoff from agents.items import ( ModelResponse, TResponseInputItem, TResponseOutputItem, TResponseStreamEvent, ) from agents.model_settings import ModelSettings from agents.models.interface import Model, ModelTracing from agents.tool import Tool from agents.tracing import SpanError, generation_span from agents.usage import Usage class FakeModel(Model): def __init__( self, tracing_enabled: bool = False, initial_output: list[TResponseOutputItem] | Exception | None = None, ): if initial_output is None: initial_output = [] self.turn_outputs: list[list[TResponseOutputItem] | Exception] = ( [initial_output] if initial_output else [] ) self.tracing_enabled = tracing_enabled def set_next_output(self, output: list[TResponseOutputItem] | Exception): self.turn_outputs.append(output) def add_multiple_turn_outputs(self, outputs: list[list[TResponseOutputItem] | Exception]): self.turn_outputs.extend(outputs) def get_next_output(self) -> list[TResponseOutputItem] | Exception: if not self.turn_outputs: return [] return self.turn_outputs.pop(0) async def get_response( self, system_instructions: str | None, input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], output_schema: AgentOutputSchema | None, handoffs: list[Handoff], tracing: ModelTracing, ) -> ModelResponse: with generation_span(disabled=not self.tracing_enabled) as span: output = self.get_next_output() if isinstance(output, Exception): span.set_error( SpanError( message="Error", data={ "name": output.__class__.__name__, "message": str(output), }, ) ) raise output return ModelResponse( output=output, usage=Usage(), referenceable_id=None, ) async def stream_response( self, system_instructions: str | None, input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], output_schema: AgentOutputSchema | None, handoffs: list[Handoff], tracing: ModelTracing, ) -> AsyncIterator[TResponseStreamEvent]: with generation_span(disabled=not self.tracing_enabled) as span: output = self.get_next_output() if isinstance(output, Exception): span.set_error( SpanError( message="Error", data={ "name": output.__class__.__name__, "message": str(output), }, ) ) raise output yield ResponseCompletedEvent( type="response.completed", response=get_response_obj(output), ) def get_response_obj(output: list[TResponseOutputItem], response_id: str | None = None) -> Response: return Response( id=response_id or "123", created_at=123, model="test_model", object="response", output=output, tool_choice="none", tools=[], top_p=None, parallel_tool_calls=False, )