Enable non-strict output types (#539)

See #528, some folks are having issues because their output types are
not strict-compatible.

My approach was:
1. Create `AgentOutputSchemaBase`, which represents the base methods for
an output type - the json schema + validation
2. Make the existing `AgentOutputSchema` subclass
`AgentOutputSchemaBase`
3. Allow users to pass a `AgentOutputSchemaBase` to
`Agent(output_type=...)`
This commit is contained in:
Rohan Mehta 2025-04-21 11:58:36 -04:00 committed by GitHub
parent 4b8472da7e
commit e3698f32b1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 256 additions and 61 deletions

View file

@ -0,0 +1,81 @@
import asyncio
import json
from dataclasses import dataclass
from typing import Any
from agents import Agent, AgentOutputSchema, AgentOutputSchemaBase, Runner
"""This example demonstrates how to use an output type that is not in strict mode. Strict mode
allows us to guarantee valid JSON output, but some schemas are not strict-compatible.
In this example, we define an output type that is not strict-compatible, and then we run the
agent with strict_json_schema=False.
We also demonstrate a custom output type.
To understand which schemas are strict-compatible, see:
https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas
"""
@dataclass
class OutputType:
jokes: dict[int, str]
"""A list of jokes, indexed by joke number."""
class CustomOutputSchema(AgentOutputSchemaBase):
"""A demonstration of a custom output schema."""
def is_plain_text(self) -> bool:
return False
def name(self) -> str:
return "CustomOutputSchema"
def json_schema(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {"jokes": {"type": "object", "properties": {"joke": {"type": "string"}}}},
}
def is_strict_json_schema(self) -> bool:
return False
def validate_json(self, json_str: str) -> Any:
json_obj = json.loads(json_str)
# Just for demonstration, we'll return a list.
return list(json_obj["jokes"].values())
async def main():
agent = Agent(
name="Assistant",
instructions="You are a helpful assistant.",
output_type=OutputType,
)
input = "Tell me 3 short jokes."
# First, let's try with a strict output type. This should raise an exception.
try:
result = await Runner.run(agent, input)
raise AssertionError("Should have raised an exception")
except Exception as e:
print(f"Error (expected): {e}")
# Now let's try again with a non-strict output type. This should work.
# In some cases, it will raise an error - the schema isn't strict, so the model may
# produce an invalid JSON object.
agent.output_type = AgentOutputSchema(OutputType, strict_json_schema=False)
result = await Runner.run(agent, input)
print(result.final_output)
# Finally, let's try a custom output type.
agent.output_type = CustomOutputSchema()
result = await Runner.run(agent, input)
print(result.final_output)
if __name__ == "__main__":
asyncio.run(main())

View file

@ -6,7 +6,7 @@ from openai import AsyncOpenAI
from . import _config
from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
from .agent_output import AgentOutputSchema
from .agent_output import AgentOutputSchema, AgentOutputSchemaBase
from .computer import AsyncComputer, Button, Computer, Environment
from .exceptions import (
AgentsException,
@ -158,6 +158,7 @@ __all__ = [
"OpenAIProvider",
"OpenAIResponsesModel",
"AgentOutputSchema",
"AgentOutputSchemaBase",
"Computer",
"AsyncComputer",
"Environment",

View file

@ -29,7 +29,7 @@ from openai.types.responses.response_input_param import ComputerCallOutput
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from .agent import Agent, ToolsToFinalOutputResult
from .agent_output import AgentOutputSchema
from .agent_output import AgentOutputSchemaBase
from .computer import AsyncComputer, Computer
from .exceptions import AgentsException, ModelBehaviorError, UserError
from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
@ -195,7 +195,7 @@ class RunImpl:
pre_step_items: list[RunItem],
new_response: ModelResponse,
processed_response: ProcessedResponse,
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
hooks: RunHooks[TContext],
context_wrapper: RunContextWrapper[TContext],
run_config: RunConfig,
@ -335,7 +335,7 @@ class RunImpl:
agent: Agent[Any],
all_tools: list[Tool],
response: ModelResponse,
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
) -> ProcessedResponse:
items: list[RunItem] = []

View file

@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
from typing_extensions import NotRequired, TypeAlias, TypedDict
from .agent_output import AgentOutputSchemaBase
from .guardrail import InputGuardrail, OutputGuardrail
from .handoffs import Handoff
from .items import ItemHelpers
@ -141,8 +142,14 @@ class Agent(Generic[TContext]):
Runs only if the agent produces a final output.
"""
output_type: type[Any] | None = None
"""The type of the output object. If not provided, the output will be `str`."""
output_type: type[Any] | AgentOutputSchemaBase | None = None
"""The type of the output object. If not provided, the output will be `str`. In most cases,
you should pass a regular Python type (e.g. a dataclass, Pydantic model, TypedDict, etc).
You can customize this in two ways:
1. If you want non-strict schemas, pass `AgentOutputSchema(MyClass, strict_json_schema=False)`.
2. If you want to use a custom JSON schema (i.e. without using the SDK's automatic schema)
creation, subclass and pass an `AgentOutputSchemaBase` subclass.
"""
hooks: AgentHooks[TContext] | None = None
"""A class that receives callbacks on various lifecycle events for this agent.

View file

@ -1,3 +1,4 @@
import abc
from dataclasses import dataclass
from typing import Any
@ -12,8 +13,46 @@ from .util import _error_tracing, _json
_WRAPPER_DICT_KEY = "response"
class AgentOutputSchemaBase(abc.ABC):
"""An object that captures the JSON schema of the output, as well as validating/parsing JSON
produced by the LLM into the output type.
"""
@abc.abstractmethod
def is_plain_text(self) -> bool:
"""Whether the output type is plain text (versus a JSON object)."""
pass
@abc.abstractmethod
def name(self) -> str:
"""The name of the output type."""
pass
@abc.abstractmethod
def json_schema(self) -> dict[str, Any]:
"""Returns the JSON schema of the output. Will only be called if the output type is not
plain text.
"""
pass
@abc.abstractmethod
def is_strict_json_schema(self) -> bool:
"""Whether the JSON schema is in strict mode. Strict mode constrains the JSON schema
features, but guarantees valis JSON. See here for details:
https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
"""
pass
@abc.abstractmethod
def validate_json(self, json_str: str) -> Any:
"""Validate a JSON string against the output type. You must return the validated object,
or raise a `ModelBehaviorError` if the JSON is invalid.
"""
pass
@dataclass(init=False)
class AgentOutputSchema:
class AgentOutputSchema(AgentOutputSchemaBase):
"""An object that captures the JSON schema of the output, as well as validating/parsing JSON
produced by the LLM into the output type.
"""
@ -32,7 +71,7 @@ class AgentOutputSchema:
_output_schema: dict[str, Any]
"""The JSON schema of the output."""
strict_json_schema: bool
_strict_json_schema: bool
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
as it increases the likelihood of correct JSON input.
"""
@ -45,7 +84,7 @@ class AgentOutputSchema:
setting this to True, as it increases the likelihood of correct JSON input.
"""
self.output_type = output_type
self.strict_json_schema = strict_json_schema
self._strict_json_schema = strict_json_schema
if output_type is None or output_type is str:
self._is_wrapped = False
@ -70,24 +109,35 @@ class AgentOutputSchema:
self._type_adapter = TypeAdapter(output_type)
self._output_schema = self._type_adapter.json_schema()
if self.strict_json_schema:
self._output_schema = ensure_strict_json_schema(self._output_schema)
if self._strict_json_schema:
try:
self._output_schema = ensure_strict_json_schema(self._output_schema)
except UserError as e:
raise UserError(
"Strict JSON schema is enabled, but the output type is not valid. "
"Either make the output type strict, or pass output_schema_strict=False to "
"your Agent()"
) from e
def is_plain_text(self) -> bool:
"""Whether the output type is plain text (versus a JSON object)."""
return self.output_type is None or self.output_type is str
def is_strict_json_schema(self) -> bool:
"""Whether the JSON schema is in strict mode."""
return self._strict_json_schema
def json_schema(self) -> dict[str, Any]:
"""The JSON schema of the output type."""
if self.is_plain_text():
raise UserError("Output type is plain text, so no JSON schema is available")
return self._output_schema
def validate_json(self, json_str: str, partial: bool = False) -> Any:
def validate_json(self, json_str: str) -> Any:
"""Validate a JSON string against the output type. Returns the validated object, or raises
a `ModelBehaviorError` if the JSON is invalid.
"""
validated = _json.validate_json(json_str, self._type_adapter, partial)
validated = _json.validate_json(json_str, self._type_adapter, partial=False)
if self._is_wrapped:
if not isinstance(validated, dict):
_error_tracing.attach_error_to_current_span(
@ -113,7 +163,7 @@ class AgentOutputSchema:
return validated[_WRAPPER_DICT_KEY]
return validated
def output_type_name(self) -> str:
def name(self) -> str:
"""The name of the output type."""
return _type_to_str(self.output_type)

View file

@ -29,7 +29,7 @@ from openai.types.chat.chat_completion_message_tool_call import Function
from openai.types.responses import Response
from ... import _debug
from ...agent_output import AgentOutputSchema
from ...agent_output import AgentOutputSchemaBase
from ...handoffs import Handoff
from ...items import ModelResponse, TResponseInputItem, TResponseStreamEvent
from ...logger import logger
@ -68,7 +68,7 @@ class LitellmModel(Model):
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
tracing: ModelTracing,
previous_response_id: str | None,
@ -139,7 +139,7 @@ class LitellmModel(Model):
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
tracing: ModelTracing,
*,
@ -186,7 +186,7 @@ class LitellmModel(Model):
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
span: Span[GenerationSpanData],
tracing: ModelTracing,
@ -200,7 +200,7 @@ class LitellmModel(Model):
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
span: Span[GenerationSpanData],
tracing: ModelTracing,
@ -213,7 +213,7 @@ class LitellmModel(Model):
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
span: Span[GenerationSpanData],
tracing: ModelTracing,

View file

@ -36,7 +36,7 @@ from openai.types.responses import (
)
from openai.types.responses.response_input_param import FunctionCallOutput, ItemReference, Message
from ..agent_output import AgentOutputSchema
from ..agent_output import AgentOutputSchemaBase
from ..exceptions import AgentsException, UserError
from ..handoffs import Handoff
from ..items import TResponseInputItem, TResponseOutputItem
@ -67,7 +67,7 @@ class Converter:
@classmethod
def convert_response_format(
cls, final_output_schema: AgentOutputSchema | None
cls, final_output_schema: AgentOutputSchemaBase | None
) -> ResponseFormat | NotGiven:
if not final_output_schema or final_output_schema.is_plain_text():
return NOT_GIVEN
@ -76,7 +76,7 @@ class Converter:
"type": "json_schema",
"json_schema": {
"name": "final_output",
"strict": final_output_schema.strict_json_schema,
"strict": final_output_schema.is_strict_json_schema(),
"schema": final_output_schema.json_schema(),
},
}

View file

@ -5,7 +5,7 @@ import enum
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING
from ..agent_output import AgentOutputSchema
from ..agent_output import AgentOutputSchemaBase
from ..handoffs import Handoff
from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent
from ..tool import Tool
@ -41,7 +41,7 @@ class Model(abc.ABC):
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
tracing: ModelTracing,
*,
@ -72,7 +72,7 @@ class Model(abc.ABC):
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
tracing: ModelTracing,
*,

View file

@ -12,7 +12,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.responses import Response
from .. import _debug
from ..agent_output import AgentOutputSchema
from ..agent_output import AgentOutputSchemaBase
from ..handoffs import Handoff
from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent
from ..logger import logger
@ -49,7 +49,7 @@ class OpenAIChatCompletionsModel(Model):
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
tracing: ModelTracing,
previous_response_id: str | None,
@ -110,7 +110,7 @@ class OpenAIChatCompletionsModel(Model):
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
tracing: ModelTracing,
*,
@ -160,7 +160,7 @@ class OpenAIChatCompletionsModel(Model):
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
span: Span[GenerationSpanData],
tracing: ModelTracing,
@ -174,7 +174,7 @@ class OpenAIChatCompletionsModel(Model):
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
span: Span[GenerationSpanData],
tracing: ModelTracing,
@ -187,7 +187,7 @@ class OpenAIChatCompletionsModel(Model):
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
span: Span[GenerationSpanData],
tracing: ModelTracing,

View file

@ -18,7 +18,7 @@ from openai.types.responses import (
)
from .. import _debug
from ..agent_output import AgentOutputSchema
from ..agent_output import AgentOutputSchemaBase
from ..exceptions import UserError
from ..handoffs import Handoff
from ..items import ItemHelpers, ModelResponse, TResponseInputItem
@ -66,7 +66,7 @@ class OpenAIResponsesModel(Model):
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
tracing: ModelTracing,
previous_response_id: str | None,
@ -131,7 +131,7 @@ class OpenAIResponsesModel(Model):
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
tracing: ModelTracing,
previous_response_id: str | None,
@ -182,7 +182,7 @@ class OpenAIResponsesModel(Model):
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
previous_response_id: str | None,
stream: Literal[True],
@ -195,7 +195,7 @@ class OpenAIResponsesModel(Model):
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
previous_response_id: str | None,
stream: Literal[False],
@ -207,7 +207,7 @@ class OpenAIResponsesModel(Model):
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
previous_response_id: str | None,
stream: Literal[True] | Literal[False] = False,
@ -307,7 +307,7 @@ class Converter:
@classmethod
def get_response_format(
cls, output_schema: AgentOutputSchema | None
cls, output_schema: AgentOutputSchemaBase | None
) -> ResponseTextConfigParam | NotGiven:
if output_schema is None or output_schema.is_plain_text():
return NOT_GIVEN
@ -317,7 +317,7 @@ class Converter:
"type": "json_schema",
"name": "final_output",
"schema": output_schema.json_schema(),
"strict": output_schema.strict_json_schema,
"strict": output_schema.is_strict_json_schema(),
}
}

View file

@ -10,7 +10,7 @@ from typing_extensions import TypeVar
from ._run_impl import QueueCompleteSentinel
from .agent import Agent
from .agent_output import AgentOutputSchema
from .agent_output import AgentOutputSchemaBase
from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded
from .guardrail import InputGuardrailResult, OutputGuardrailResult
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
@ -124,7 +124,7 @@ class RunResultStreaming(RunResultBase):
final_output: Any
"""The final output of the agent. This is None until the agent has finished running."""
_current_agent_output_schema: AgentOutputSchema | None = field(repr=False)
_current_agent_output_schema: AgentOutputSchemaBase | None = field(repr=False)
_trace: Trace | None = field(repr=False)

View file

@ -19,7 +19,7 @@ from ._run_impl import (
get_model_tracing_impl,
)
from .agent import Agent
from .agent_output import AgentOutputSchema
from .agent_output import AgentOutputSchema, AgentOutputSchemaBase
from .exceptions import (
AgentsException,
InputGuardrailTripwireTriggered,
@ -185,7 +185,7 @@ class Runner:
if current_span is None:
handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)]
if output_schema := cls._get_output_schema(current_agent):
output_type_name = output_schema.output_type_name()
output_type_name = output_schema.name()
else:
output_type_name = "str"
@ -517,7 +517,7 @@ class Runner:
if current_span is None:
handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)]
if output_schema := cls._get_output_schema(current_agent):
output_type_name = output_schema.output_type_name()
output_type_name = output_schema.name()
else:
output_type_name = "str"
@ -789,7 +789,7 @@ class Runner:
original_input: str | list[TResponseInputItem],
pre_step_items: list[RunItem],
new_response: ModelResponse,
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
hooks: RunHooks[TContext],
context_wrapper: RunContextWrapper[TContext],
@ -900,7 +900,7 @@ class Runner:
agent: Agent[TContext],
system_prompt: str | None,
input: list[TResponseInputItem],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
all_tools: list[Tool],
handoffs: list[Handoff],
context_wrapper: RunContextWrapper[TContext],
@ -930,9 +930,11 @@ class Runner:
return new_response
@classmethod
def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchema | None:
def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None:
if agent.output_type is None or agent.output_type is str:
return None
elif isinstance(agent.output_type, AgentOutputSchemaBase):
return agent.output_type
return AgentOutputSchema(agent.output_type)

View file

@ -5,7 +5,7 @@ from typing import Any
from openai.types.responses import Response, ResponseCompletedEvent
from agents.agent_output import AgentOutputSchema
from agents.agent_output import AgentOutputSchemaBase
from agents.handoffs import Handoff
from agents.items import (
ModelResponse,
@ -51,7 +51,7 @@ class FakeModel(Model):
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
tracing: ModelTracing,
*,
@ -93,7 +93,7 @@ class FakeModel(Model):
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
tracing: ModelTracing,
*,

View file

@ -1,7 +1,7 @@
import pytest
from pydantic import BaseModel
from agents import Agent, Handoff, RunContextWrapper, Runner, handoff
from agents import Agent, AgentOutputSchema, Handoff, RunContextWrapper, Runner, handoff
@pytest.mark.asyncio
@ -160,8 +160,9 @@ async def test_agent_final_output():
)
schema = Runner._get_output_schema(agent)
assert isinstance(schema, AgentOutputSchema)
assert schema is not None
assert schema.output_type == Foo
assert schema.strict_json_schema is True
assert schema.is_strict_json_schema() is True
assert schema.json_schema() is not None
assert not schema.is_plain_text()

View file

@ -232,7 +232,7 @@ def test_convert_response_format_returns_not_given_for_plain_text_and_dict_for_s
assert resp_format["type"] == "json_schema"
assert resp_format["json_schema"]["name"] == "final_output"
assert "strict" in resp_format["json_schema"]
assert resp_format["json_schema"]["strict"] == schema.strict_json_schema
assert resp_format["json_schema"]["strict"] == schema.is_strict_json_schema()
assert "schema" in resp_format["json_schema"]
assert resp_format["json_schema"]["schema"] == schema.json_schema()

View file

@ -92,7 +92,7 @@ def test_get_response_format_plain_text_and_json_schema():
assert inner.get("name") == "final_output"
assert isinstance(inner.get("schema"), dict)
# Should include a strict flag matching the schema's strictness setting.
assert inner.get("strict") == out_schema.strict_json_schema
assert inner.get("strict") == out_schema.is_strict_json_schema()
def test_convert_tools_basic_types_and_includes():

View file

@ -1,10 +1,18 @@
import json
from typing import Any
import pytest
from pydantic import BaseModel
from typing_extensions import TypedDict
from agents import Agent, AgentOutputSchema, ModelBehaviorError, Runner, UserError
from agents import (
Agent,
AgentOutputSchema,
AgentOutputSchemaBase,
ModelBehaviorError,
Runner,
UserError,
)
from agents.agent_output import _WRAPPER_DICT_KEY
from agents.util import _json
@ -27,6 +35,7 @@ def test_structured_output_pydantic():
output_schema = Runner._get_output_schema(agent)
assert output_schema, "Should have an output tool config with a structured output type"
assert isinstance(output_schema, AgentOutputSchema)
assert output_schema.output_type == Foo, "Should have the correct output type"
assert not output_schema._is_wrapped, "Pydantic objects should not be wrapped"
for key, value in Foo.model_json_schema().items():
@ -45,6 +54,7 @@ def test_structured_output_typed_dict():
agent = Agent(name="test", output_type=Bar)
output_schema = Runner._get_output_schema(agent)
assert output_schema, "Should have an output tool config with a structured output type"
assert isinstance(output_schema, AgentOutputSchema)
assert output_schema.output_type == Bar, "Should have the correct output type"
assert not output_schema._is_wrapped, "TypedDicts should not be wrapped"
@ -57,6 +67,7 @@ def test_structured_output_list():
agent = Agent(name="test", output_type=list[str])
output_schema = Runner._get_output_schema(agent)
assert output_schema, "Should have an output tool config with a structured output type"
assert isinstance(output_schema, AgentOutputSchema)
assert output_schema.output_type == list[str], "Should have the correct output type"
assert output_schema._is_wrapped, "Lists should be wrapped"
@ -98,7 +109,7 @@ def test_plain_text_obj_doesnt_produce_schema():
def test_structured_output_is_strict():
output_wrapper = AgentOutputSchema(output_type=Foo)
assert output_wrapper.strict_json_schema
assert output_wrapper.is_strict_json_schema()
for key, value in Foo.model_json_schema().items():
assert output_wrapper.json_schema()[key] == value
@ -110,6 +121,48 @@ def test_structured_output_is_strict():
def test_setting_strict_false_works():
output_wrapper = AgentOutputSchema(output_type=Foo, strict_json_schema=False)
assert not output_wrapper.strict_json_schema
assert not output_wrapper.is_strict_json_schema()
assert output_wrapper.json_schema() == Foo.model_json_schema()
assert output_wrapper.json_schema() == Foo.model_json_schema()
_CUSTOM_OUTPUT_SCHEMA_JSON_SCHEMA = {
"type": "object",
"properties": {
"foo": {"type": "string"},
},
"required": ["foo"],
}
class CustomOutputSchema(AgentOutputSchemaBase):
def is_plain_text(self) -> bool:
return False
def name(self) -> str:
return "FooBarBaz"
def json_schema(self) -> dict[str, Any]:
return _CUSTOM_OUTPUT_SCHEMA_JSON_SCHEMA
def is_strict_json_schema(self) -> bool:
return False
def validate_json(self, json_str: str) -> Any:
return ["some", "output"]
def test_custom_output_schema():
custom_output_schema = CustomOutputSchema()
agent = Agent(name="test", output_type=custom_output_schema)
output_schema = Runner._get_output_schema(agent)
assert output_schema, "Should have an output tool config with a structured output type"
assert isinstance(output_schema, CustomOutputSchema)
assert output_schema.json_schema() == _CUSTOM_OUTPUT_SCHEMA_JSON_SCHEMA
assert not output_schema.is_strict_json_schema()
assert not output_schema.is_plain_text()
json_str = json.dumps({"foo": "bar"})
validated = output_schema.validate_json(json_str)
assert validated == ["some", "output"]

View file

@ -9,7 +9,7 @@ from openai.types.responses import ResponseCompletedEvent
from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent
from agents import Agent, Model, ModelSettings, ModelTracing, Tool
from agents.agent_output import AgentOutputSchema
from agents.agent_output import AgentOutputSchemaBase
from agents.handoffs import Handoff
from agents.items import (
ModelResponse,
@ -48,7 +48,7 @@ class FakeStreamingModel(Model):
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
tracing: ModelTracing,
*,
@ -62,7 +62,7 @@ class FakeStreamingModel(Model):
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
tracing: ModelTracing,
*,