Enable non-strict output types (#539)
See #528, some folks are having issues because their output types are not strict-compatible. My approach was: 1. Create `AgentOutputSchemaBase`, which represents the base methods for an output type - the json schema + validation 2. Make the existing `AgentOutputSchema` subclass `AgentOutputSchemaBase` 3. Allow users to pass a `AgentOutputSchemaBase` to `Agent(output_type=...)`
This commit is contained in:
parent
4b8472da7e
commit
e3698f32b1
18 changed files with 256 additions and 61 deletions
81
examples/basic/non_strict_output_type.py
Normal file
81
examples/basic/non_strict_output_type.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from agents import Agent, AgentOutputSchema, AgentOutputSchemaBase, Runner
|
||||
|
||||
"""This example demonstrates how to use an output type that is not in strict mode. Strict mode
|
||||
allows us to guarantee valid JSON output, but some schemas are not strict-compatible.
|
||||
|
||||
In this example, we define an output type that is not strict-compatible, and then we run the
|
||||
agent with strict_json_schema=False.
|
||||
|
||||
We also demonstrate a custom output type.
|
||||
|
||||
To understand which schemas are strict-compatible, see:
|
||||
https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputType:
|
||||
jokes: dict[int, str]
|
||||
"""A list of jokes, indexed by joke number."""
|
||||
|
||||
|
||||
class CustomOutputSchema(AgentOutputSchemaBase):
|
||||
"""A demonstration of a custom output schema."""
|
||||
|
||||
def is_plain_text(self) -> bool:
|
||||
return False
|
||||
|
||||
def name(self) -> str:
|
||||
return "CustomOutputSchema"
|
||||
|
||||
def json_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {"jokes": {"type": "object", "properties": {"joke": {"type": "string"}}}},
|
||||
}
|
||||
|
||||
def is_strict_json_schema(self) -> bool:
|
||||
return False
|
||||
|
||||
def validate_json(self, json_str: str) -> Any:
|
||||
json_obj = json.loads(json_str)
|
||||
# Just for demonstration, we'll return a list.
|
||||
return list(json_obj["jokes"].values())
|
||||
|
||||
|
||||
async def main():
|
||||
agent = Agent(
|
||||
name="Assistant",
|
||||
instructions="You are a helpful assistant.",
|
||||
output_type=OutputType,
|
||||
)
|
||||
|
||||
input = "Tell me 3 short jokes."
|
||||
|
||||
# First, let's try with a strict output type. This should raise an exception.
|
||||
try:
|
||||
result = await Runner.run(agent, input)
|
||||
raise AssertionError("Should have raised an exception")
|
||||
except Exception as e:
|
||||
print(f"Error (expected): {e}")
|
||||
|
||||
# Now let's try again with a non-strict output type. This should work.
|
||||
# In some cases, it will raise an error - the schema isn't strict, so the model may
|
||||
# produce an invalid JSON object.
|
||||
agent.output_type = AgentOutputSchema(OutputType, strict_json_schema=False)
|
||||
result = await Runner.run(agent, input)
|
||||
print(result.final_output)
|
||||
|
||||
# Finally, let's try a custom output type.
|
||||
agent.output_type = CustomOutputSchema()
|
||||
result = await Runner.run(agent, input)
|
||||
print(result.final_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
@ -6,7 +6,7 @@ from openai import AsyncOpenAI
|
|||
|
||||
from . import _config
|
||||
from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
|
||||
from .agent_output import AgentOutputSchema
|
||||
from .agent_output import AgentOutputSchema, AgentOutputSchemaBase
|
||||
from .computer import AsyncComputer, Button, Computer, Environment
|
||||
from .exceptions import (
|
||||
AgentsException,
|
||||
|
|
@ -158,6 +158,7 @@ __all__ = [
|
|||
"OpenAIProvider",
|
||||
"OpenAIResponsesModel",
|
||||
"AgentOutputSchema",
|
||||
"AgentOutputSchemaBase",
|
||||
"Computer",
|
||||
"AsyncComputer",
|
||||
"Environment",
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ from openai.types.responses.response_input_param import ComputerCallOutput
|
|||
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
|
||||
|
||||
from .agent import Agent, ToolsToFinalOutputResult
|
||||
from .agent_output import AgentOutputSchema
|
||||
from .agent_output import AgentOutputSchemaBase
|
||||
from .computer import AsyncComputer, Computer
|
||||
from .exceptions import AgentsException, ModelBehaviorError, UserError
|
||||
from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
|
||||
|
|
@ -195,7 +195,7 @@ class RunImpl:
|
|||
pre_step_items: list[RunItem],
|
||||
new_response: ModelResponse,
|
||||
processed_response: ProcessedResponse,
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
hooks: RunHooks[TContext],
|
||||
context_wrapper: RunContextWrapper[TContext],
|
||||
run_config: RunConfig,
|
||||
|
|
@ -335,7 +335,7 @@ class RunImpl:
|
|||
agent: Agent[Any],
|
||||
all_tools: list[Tool],
|
||||
response: ModelResponse,
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
) -> ProcessedResponse:
|
||||
items: list[RunItem] = []
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
|
|||
|
||||
from typing_extensions import NotRequired, TypeAlias, TypedDict
|
||||
|
||||
from .agent_output import AgentOutputSchemaBase
|
||||
from .guardrail import InputGuardrail, OutputGuardrail
|
||||
from .handoffs import Handoff
|
||||
from .items import ItemHelpers
|
||||
|
|
@ -141,8 +142,14 @@ class Agent(Generic[TContext]):
|
|||
Runs only if the agent produces a final output.
|
||||
"""
|
||||
|
||||
output_type: type[Any] | None = None
|
||||
"""The type of the output object. If not provided, the output will be `str`."""
|
||||
output_type: type[Any] | AgentOutputSchemaBase | None = None
|
||||
"""The type of the output object. If not provided, the output will be `str`. In most cases,
|
||||
you should pass a regular Python type (e.g. a dataclass, Pydantic model, TypedDict, etc).
|
||||
You can customize this in two ways:
|
||||
1. If you want non-strict schemas, pass `AgentOutputSchema(MyClass, strict_json_schema=False)`.
|
||||
2. If you want to use a custom JSON schema (i.e. without using the SDK's automatic schema)
|
||||
creation, subclass and pass an `AgentOutputSchemaBase` subclass.
|
||||
"""
|
||||
|
||||
hooks: AgentHooks[TContext] | None = None
|
||||
"""A class that receives callbacks on various lifecycle events for this agent.
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import abc
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -12,8 +13,46 @@ from .util import _error_tracing, _json
|
|||
_WRAPPER_DICT_KEY = "response"
|
||||
|
||||
|
||||
class AgentOutputSchemaBase(abc.ABC):
|
||||
"""An object that captures the JSON schema of the output, as well as validating/parsing JSON
|
||||
produced by the LLM into the output type.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_plain_text(self) -> bool:
|
||||
"""Whether the output type is plain text (versus a JSON object)."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def name(self) -> str:
|
||||
"""The name of the output type."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def json_schema(self) -> dict[str, Any]:
|
||||
"""Returns the JSON schema of the output. Will only be called if the output type is not
|
||||
plain text.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_strict_json_schema(self) -> bool:
|
||||
"""Whether the JSON schema is in strict mode. Strict mode constrains the JSON schema
|
||||
features, but guarantees valis JSON. See here for details:
|
||||
https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def validate_json(self, json_str: str) -> Any:
|
||||
"""Validate a JSON string against the output type. You must return the validated object,
|
||||
or raise a `ModelBehaviorError` if the JSON is invalid.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class AgentOutputSchema:
|
||||
class AgentOutputSchema(AgentOutputSchemaBase):
|
||||
"""An object that captures the JSON schema of the output, as well as validating/parsing JSON
|
||||
produced by the LLM into the output type.
|
||||
"""
|
||||
|
|
@ -32,7 +71,7 @@ class AgentOutputSchema:
|
|||
_output_schema: dict[str, Any]
|
||||
"""The JSON schema of the output."""
|
||||
|
||||
strict_json_schema: bool
|
||||
_strict_json_schema: bool
|
||||
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
|
||||
as it increases the likelihood of correct JSON input.
|
||||
"""
|
||||
|
|
@ -45,7 +84,7 @@ class AgentOutputSchema:
|
|||
setting this to True, as it increases the likelihood of correct JSON input.
|
||||
"""
|
||||
self.output_type = output_type
|
||||
self.strict_json_schema = strict_json_schema
|
||||
self._strict_json_schema = strict_json_schema
|
||||
|
||||
if output_type is None or output_type is str:
|
||||
self._is_wrapped = False
|
||||
|
|
@ -70,24 +109,35 @@ class AgentOutputSchema:
|
|||
self._type_adapter = TypeAdapter(output_type)
|
||||
self._output_schema = self._type_adapter.json_schema()
|
||||
|
||||
if self.strict_json_schema:
|
||||
self._output_schema = ensure_strict_json_schema(self._output_schema)
|
||||
if self._strict_json_schema:
|
||||
try:
|
||||
self._output_schema = ensure_strict_json_schema(self._output_schema)
|
||||
except UserError as e:
|
||||
raise UserError(
|
||||
"Strict JSON schema is enabled, but the output type is not valid. "
|
||||
"Either make the output type strict, or pass output_schema_strict=False to "
|
||||
"your Agent()"
|
||||
) from e
|
||||
|
||||
def is_plain_text(self) -> bool:
|
||||
"""Whether the output type is plain text (versus a JSON object)."""
|
||||
return self.output_type is None or self.output_type is str
|
||||
|
||||
def is_strict_json_schema(self) -> bool:
|
||||
"""Whether the JSON schema is in strict mode."""
|
||||
return self._strict_json_schema
|
||||
|
||||
def json_schema(self) -> dict[str, Any]:
|
||||
"""The JSON schema of the output type."""
|
||||
if self.is_plain_text():
|
||||
raise UserError("Output type is plain text, so no JSON schema is available")
|
||||
return self._output_schema
|
||||
|
||||
def validate_json(self, json_str: str, partial: bool = False) -> Any:
|
||||
def validate_json(self, json_str: str) -> Any:
|
||||
"""Validate a JSON string against the output type. Returns the validated object, or raises
|
||||
a `ModelBehaviorError` if the JSON is invalid.
|
||||
"""
|
||||
validated = _json.validate_json(json_str, self._type_adapter, partial)
|
||||
validated = _json.validate_json(json_str, self._type_adapter, partial=False)
|
||||
if self._is_wrapped:
|
||||
if not isinstance(validated, dict):
|
||||
_error_tracing.attach_error_to_current_span(
|
||||
|
|
@ -113,7 +163,7 @@ class AgentOutputSchema:
|
|||
return validated[_WRAPPER_DICT_KEY]
|
||||
return validated
|
||||
|
||||
def output_type_name(self) -> str:
|
||||
def name(self) -> str:
|
||||
"""The name of the output type."""
|
||||
return _type_to_str(self.output_type)
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ from openai.types.chat.chat_completion_message_tool_call import Function
|
|||
from openai.types.responses import Response
|
||||
|
||||
from ... import _debug
|
||||
from ...agent_output import AgentOutputSchema
|
||||
from ...agent_output import AgentOutputSchemaBase
|
||||
from ...handoffs import Handoff
|
||||
from ...items import ModelResponse, TResponseInputItem, TResponseStreamEvent
|
||||
from ...logger import logger
|
||||
|
|
@ -68,7 +68,7 @@ class LitellmModel(Model):
|
|||
input: str | list[TResponseInputItem],
|
||||
model_settings: ModelSettings,
|
||||
tools: list[Tool],
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
tracing: ModelTracing,
|
||||
previous_response_id: str | None,
|
||||
|
|
@ -139,7 +139,7 @@ class LitellmModel(Model):
|
|||
input: str | list[TResponseInputItem],
|
||||
model_settings: ModelSettings,
|
||||
tools: list[Tool],
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
tracing: ModelTracing,
|
||||
*,
|
||||
|
|
@ -186,7 +186,7 @@ class LitellmModel(Model):
|
|||
input: str | list[TResponseInputItem],
|
||||
model_settings: ModelSettings,
|
||||
tools: list[Tool],
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
span: Span[GenerationSpanData],
|
||||
tracing: ModelTracing,
|
||||
|
|
@ -200,7 +200,7 @@ class LitellmModel(Model):
|
|||
input: str | list[TResponseInputItem],
|
||||
model_settings: ModelSettings,
|
||||
tools: list[Tool],
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
span: Span[GenerationSpanData],
|
||||
tracing: ModelTracing,
|
||||
|
|
@ -213,7 +213,7 @@ class LitellmModel(Model):
|
|||
input: str | list[TResponseInputItem],
|
||||
model_settings: ModelSettings,
|
||||
tools: list[Tool],
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
span: Span[GenerationSpanData],
|
||||
tracing: ModelTracing,
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ from openai.types.responses import (
|
|||
)
|
||||
from openai.types.responses.response_input_param import FunctionCallOutput, ItemReference, Message
|
||||
|
||||
from ..agent_output import AgentOutputSchema
|
||||
from ..agent_output import AgentOutputSchemaBase
|
||||
from ..exceptions import AgentsException, UserError
|
||||
from ..handoffs import Handoff
|
||||
from ..items import TResponseInputItem, TResponseOutputItem
|
||||
|
|
@ -67,7 +67,7 @@ class Converter:
|
|||
|
||||
@classmethod
|
||||
def convert_response_format(
|
||||
cls, final_output_schema: AgentOutputSchema | None
|
||||
cls, final_output_schema: AgentOutputSchemaBase | None
|
||||
) -> ResponseFormat | NotGiven:
|
||||
if not final_output_schema or final_output_schema.is_plain_text():
|
||||
return NOT_GIVEN
|
||||
|
|
@ -76,7 +76,7 @@ class Converter:
|
|||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "final_output",
|
||||
"strict": final_output_schema.strict_json_schema,
|
||||
"strict": final_output_schema.is_strict_json_schema(),
|
||||
"schema": final_output_schema.json_schema(),
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import enum
|
|||
from collections.abc import AsyncIterator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..agent_output import AgentOutputSchema
|
||||
from ..agent_output import AgentOutputSchemaBase
|
||||
from ..handoffs import Handoff
|
||||
from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent
|
||||
from ..tool import Tool
|
||||
|
|
@ -41,7 +41,7 @@ class Model(abc.ABC):
|
|||
input: str | list[TResponseInputItem],
|
||||
model_settings: ModelSettings,
|
||||
tools: list[Tool],
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
tracing: ModelTracing,
|
||||
*,
|
||||
|
|
@ -72,7 +72,7 @@ class Model(abc.ABC):
|
|||
input: str | list[TResponseInputItem],
|
||||
model_settings: ModelSettings,
|
||||
tools: list[Tool],
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
tracing: ModelTracing,
|
||||
*,
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
|||
from openai.types.responses import Response
|
||||
|
||||
from .. import _debug
|
||||
from ..agent_output import AgentOutputSchema
|
||||
from ..agent_output import AgentOutputSchemaBase
|
||||
from ..handoffs import Handoff
|
||||
from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent
|
||||
from ..logger import logger
|
||||
|
|
@ -49,7 +49,7 @@ class OpenAIChatCompletionsModel(Model):
|
|||
input: str | list[TResponseInputItem],
|
||||
model_settings: ModelSettings,
|
||||
tools: list[Tool],
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
tracing: ModelTracing,
|
||||
previous_response_id: str | None,
|
||||
|
|
@ -110,7 +110,7 @@ class OpenAIChatCompletionsModel(Model):
|
|||
input: str | list[TResponseInputItem],
|
||||
model_settings: ModelSettings,
|
||||
tools: list[Tool],
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
tracing: ModelTracing,
|
||||
*,
|
||||
|
|
@ -160,7 +160,7 @@ class OpenAIChatCompletionsModel(Model):
|
|||
input: str | list[TResponseInputItem],
|
||||
model_settings: ModelSettings,
|
||||
tools: list[Tool],
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
span: Span[GenerationSpanData],
|
||||
tracing: ModelTracing,
|
||||
|
|
@ -174,7 +174,7 @@ class OpenAIChatCompletionsModel(Model):
|
|||
input: str | list[TResponseInputItem],
|
||||
model_settings: ModelSettings,
|
||||
tools: list[Tool],
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
span: Span[GenerationSpanData],
|
||||
tracing: ModelTracing,
|
||||
|
|
@ -187,7 +187,7 @@ class OpenAIChatCompletionsModel(Model):
|
|||
input: str | list[TResponseInputItem],
|
||||
model_settings: ModelSettings,
|
||||
tools: list[Tool],
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
span: Span[GenerationSpanData],
|
||||
tracing: ModelTracing,
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from openai.types.responses import (
|
|||
)
|
||||
|
||||
from .. import _debug
|
||||
from ..agent_output import AgentOutputSchema
|
||||
from ..agent_output import AgentOutputSchemaBase
|
||||
from ..exceptions import UserError
|
||||
from ..handoffs import Handoff
|
||||
from ..items import ItemHelpers, ModelResponse, TResponseInputItem
|
||||
|
|
@ -66,7 +66,7 @@ class OpenAIResponsesModel(Model):
|
|||
input: str | list[TResponseInputItem],
|
||||
model_settings: ModelSettings,
|
||||
tools: list[Tool],
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
tracing: ModelTracing,
|
||||
previous_response_id: str | None,
|
||||
|
|
@ -131,7 +131,7 @@ class OpenAIResponsesModel(Model):
|
|||
input: str | list[TResponseInputItem],
|
||||
model_settings: ModelSettings,
|
||||
tools: list[Tool],
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
tracing: ModelTracing,
|
||||
previous_response_id: str | None,
|
||||
|
|
@ -182,7 +182,7 @@ class OpenAIResponsesModel(Model):
|
|||
input: str | list[TResponseInputItem],
|
||||
model_settings: ModelSettings,
|
||||
tools: list[Tool],
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
previous_response_id: str | None,
|
||||
stream: Literal[True],
|
||||
|
|
@ -195,7 +195,7 @@ class OpenAIResponsesModel(Model):
|
|||
input: str | list[TResponseInputItem],
|
||||
model_settings: ModelSettings,
|
||||
tools: list[Tool],
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
previous_response_id: str | None,
|
||||
stream: Literal[False],
|
||||
|
|
@ -207,7 +207,7 @@ class OpenAIResponsesModel(Model):
|
|||
input: str | list[TResponseInputItem],
|
||||
model_settings: ModelSettings,
|
||||
tools: list[Tool],
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
previous_response_id: str | None,
|
||||
stream: Literal[True] | Literal[False] = False,
|
||||
|
|
@ -307,7 +307,7 @@ class Converter:
|
|||
|
||||
@classmethod
|
||||
def get_response_format(
|
||||
cls, output_schema: AgentOutputSchema | None
|
||||
cls, output_schema: AgentOutputSchemaBase | None
|
||||
) -> ResponseTextConfigParam | NotGiven:
|
||||
if output_schema is None or output_schema.is_plain_text():
|
||||
return NOT_GIVEN
|
||||
|
|
@ -317,7 +317,7 @@ class Converter:
|
|||
"type": "json_schema",
|
||||
"name": "final_output",
|
||||
"schema": output_schema.json_schema(),
|
||||
"strict": output_schema.strict_json_schema,
|
||||
"strict": output_schema.is_strict_json_schema(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from typing_extensions import TypeVar
|
|||
|
||||
from ._run_impl import QueueCompleteSentinel
|
||||
from .agent import Agent
|
||||
from .agent_output import AgentOutputSchema
|
||||
from .agent_output import AgentOutputSchemaBase
|
||||
from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded
|
||||
from .guardrail import InputGuardrailResult, OutputGuardrailResult
|
||||
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
|
||||
|
|
@ -124,7 +124,7 @@ class RunResultStreaming(RunResultBase):
|
|||
final_output: Any
|
||||
"""The final output of the agent. This is None until the agent has finished running."""
|
||||
|
||||
_current_agent_output_schema: AgentOutputSchema | None = field(repr=False)
|
||||
_current_agent_output_schema: AgentOutputSchemaBase | None = field(repr=False)
|
||||
|
||||
_trace: Trace | None = field(repr=False)
|
||||
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from ._run_impl import (
|
|||
get_model_tracing_impl,
|
||||
)
|
||||
from .agent import Agent
|
||||
from .agent_output import AgentOutputSchema
|
||||
from .agent_output import AgentOutputSchema, AgentOutputSchemaBase
|
||||
from .exceptions import (
|
||||
AgentsException,
|
||||
InputGuardrailTripwireTriggered,
|
||||
|
|
@ -185,7 +185,7 @@ class Runner:
|
|||
if current_span is None:
|
||||
handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)]
|
||||
if output_schema := cls._get_output_schema(current_agent):
|
||||
output_type_name = output_schema.output_type_name()
|
||||
output_type_name = output_schema.name()
|
||||
else:
|
||||
output_type_name = "str"
|
||||
|
||||
|
|
@ -517,7 +517,7 @@ class Runner:
|
|||
if current_span is None:
|
||||
handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)]
|
||||
if output_schema := cls._get_output_schema(current_agent):
|
||||
output_type_name = output_schema.output_type_name()
|
||||
output_type_name = output_schema.name()
|
||||
else:
|
||||
output_type_name = "str"
|
||||
|
||||
|
|
@ -789,7 +789,7 @@ class Runner:
|
|||
original_input: str | list[TResponseInputItem],
|
||||
pre_step_items: list[RunItem],
|
||||
new_response: ModelResponse,
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
hooks: RunHooks[TContext],
|
||||
context_wrapper: RunContextWrapper[TContext],
|
||||
|
|
@ -900,7 +900,7 @@ class Runner:
|
|||
agent: Agent[TContext],
|
||||
system_prompt: str | None,
|
||||
input: list[TResponseInputItem],
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
all_tools: list[Tool],
|
||||
handoffs: list[Handoff],
|
||||
context_wrapper: RunContextWrapper[TContext],
|
||||
|
|
@ -930,9 +930,11 @@ class Runner:
|
|||
return new_response
|
||||
|
||||
@classmethod
|
||||
def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchema | None:
|
||||
def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None:
|
||||
if agent.output_type is None or agent.output_type is str:
|
||||
return None
|
||||
elif isinstance(agent.output_type, AgentOutputSchemaBase):
|
||||
return agent.output_type
|
||||
|
||||
return AgentOutputSchema(agent.output_type)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from typing import Any
|
|||
|
||||
from openai.types.responses import Response, ResponseCompletedEvent
|
||||
|
||||
from agents.agent_output import AgentOutputSchema
|
||||
from agents.agent_output import AgentOutputSchemaBase
|
||||
from agents.handoffs import Handoff
|
||||
from agents.items import (
|
||||
ModelResponse,
|
||||
|
|
@ -51,7 +51,7 @@ class FakeModel(Model):
|
|||
input: str | list[TResponseInputItem],
|
||||
model_settings: ModelSettings,
|
||||
tools: list[Tool],
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
tracing: ModelTracing,
|
||||
*,
|
||||
|
|
@ -93,7 +93,7 @@ class FakeModel(Model):
|
|||
input: str | list[TResponseInputItem],
|
||||
model_settings: ModelSettings,
|
||||
tools: list[Tool],
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
tracing: ModelTracing,
|
||||
*,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agents import Agent, Handoff, RunContextWrapper, Runner, handoff
|
||||
from agents import Agent, AgentOutputSchema, Handoff, RunContextWrapper, Runner, handoff
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -160,8 +160,9 @@ async def test_agent_final_output():
|
|||
)
|
||||
|
||||
schema = Runner._get_output_schema(agent)
|
||||
assert isinstance(schema, AgentOutputSchema)
|
||||
assert schema is not None
|
||||
assert schema.output_type == Foo
|
||||
assert schema.strict_json_schema is True
|
||||
assert schema.is_strict_json_schema() is True
|
||||
assert schema.json_schema() is not None
|
||||
assert not schema.is_plain_text()
|
||||
|
|
|
|||
|
|
@ -232,7 +232,7 @@ def test_convert_response_format_returns_not_given_for_plain_text_and_dict_for_s
|
|||
assert resp_format["type"] == "json_schema"
|
||||
assert resp_format["json_schema"]["name"] == "final_output"
|
||||
assert "strict" in resp_format["json_schema"]
|
||||
assert resp_format["json_schema"]["strict"] == schema.strict_json_schema
|
||||
assert resp_format["json_schema"]["strict"] == schema.is_strict_json_schema()
|
||||
assert "schema" in resp_format["json_schema"]
|
||||
assert resp_format["json_schema"]["schema"] == schema.json_schema()
|
||||
|
||||
|
|
|
|||
|
|
@ -92,7 +92,7 @@ def test_get_response_format_plain_text_and_json_schema():
|
|||
assert inner.get("name") == "final_output"
|
||||
assert isinstance(inner.get("schema"), dict)
|
||||
# Should include a strict flag matching the schema's strictness setting.
|
||||
assert inner.get("strict") == out_schema.strict_json_schema
|
||||
assert inner.get("strict") == out_schema.is_strict_json_schema()
|
||||
|
||||
|
||||
def test_convert_tools_basic_types_and_includes():
|
||||
|
|
|
|||
|
|
@ -1,10 +1,18 @@
|
|||
import json
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from agents import Agent, AgentOutputSchema, ModelBehaviorError, Runner, UserError
|
||||
from agents import (
|
||||
Agent,
|
||||
AgentOutputSchema,
|
||||
AgentOutputSchemaBase,
|
||||
ModelBehaviorError,
|
||||
Runner,
|
||||
UserError,
|
||||
)
|
||||
from agents.agent_output import _WRAPPER_DICT_KEY
|
||||
from agents.util import _json
|
||||
|
||||
|
|
@ -27,6 +35,7 @@ def test_structured_output_pydantic():
|
|||
output_schema = Runner._get_output_schema(agent)
|
||||
assert output_schema, "Should have an output tool config with a structured output type"
|
||||
|
||||
assert isinstance(output_schema, AgentOutputSchema)
|
||||
assert output_schema.output_type == Foo, "Should have the correct output type"
|
||||
assert not output_schema._is_wrapped, "Pydantic objects should not be wrapped"
|
||||
for key, value in Foo.model_json_schema().items():
|
||||
|
|
@ -45,6 +54,7 @@ def test_structured_output_typed_dict():
|
|||
agent = Agent(name="test", output_type=Bar)
|
||||
output_schema = Runner._get_output_schema(agent)
|
||||
assert output_schema, "Should have an output tool config with a structured output type"
|
||||
assert isinstance(output_schema, AgentOutputSchema)
|
||||
assert output_schema.output_type == Bar, "Should have the correct output type"
|
||||
assert not output_schema._is_wrapped, "TypedDicts should not be wrapped"
|
||||
|
||||
|
|
@ -57,6 +67,7 @@ def test_structured_output_list():
|
|||
agent = Agent(name="test", output_type=list[str])
|
||||
output_schema = Runner._get_output_schema(agent)
|
||||
assert output_schema, "Should have an output tool config with a structured output type"
|
||||
assert isinstance(output_schema, AgentOutputSchema)
|
||||
assert output_schema.output_type == list[str], "Should have the correct output type"
|
||||
assert output_schema._is_wrapped, "Lists should be wrapped"
|
||||
|
||||
|
|
@ -98,7 +109,7 @@ def test_plain_text_obj_doesnt_produce_schema():
|
|||
|
||||
def test_structured_output_is_strict():
|
||||
output_wrapper = AgentOutputSchema(output_type=Foo)
|
||||
assert output_wrapper.strict_json_schema
|
||||
assert output_wrapper.is_strict_json_schema()
|
||||
for key, value in Foo.model_json_schema().items():
|
||||
assert output_wrapper.json_schema()[key] == value
|
||||
|
||||
|
|
@ -110,6 +121,48 @@ def test_structured_output_is_strict():
|
|||
|
||||
def test_setting_strict_false_works():
|
||||
output_wrapper = AgentOutputSchema(output_type=Foo, strict_json_schema=False)
|
||||
assert not output_wrapper.strict_json_schema
|
||||
assert not output_wrapper.is_strict_json_schema()
|
||||
assert output_wrapper.json_schema() == Foo.model_json_schema()
|
||||
assert output_wrapper.json_schema() == Foo.model_json_schema()
|
||||
|
||||
|
||||
_CUSTOM_OUTPUT_SCHEMA_JSON_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"foo": {"type": "string"},
|
||||
},
|
||||
"required": ["foo"],
|
||||
}
|
||||
|
||||
|
||||
class CustomOutputSchema(AgentOutputSchemaBase):
|
||||
def is_plain_text(self) -> bool:
|
||||
return False
|
||||
|
||||
def name(self) -> str:
|
||||
return "FooBarBaz"
|
||||
|
||||
def json_schema(self) -> dict[str, Any]:
|
||||
return _CUSTOM_OUTPUT_SCHEMA_JSON_SCHEMA
|
||||
|
||||
def is_strict_json_schema(self) -> bool:
|
||||
return False
|
||||
|
||||
def validate_json(self, json_str: str) -> Any:
|
||||
return ["some", "output"]
|
||||
|
||||
|
||||
def test_custom_output_schema():
|
||||
custom_output_schema = CustomOutputSchema()
|
||||
agent = Agent(name="test", output_type=custom_output_schema)
|
||||
output_schema = Runner._get_output_schema(agent)
|
||||
|
||||
assert output_schema, "Should have an output tool config with a structured output type"
|
||||
assert isinstance(output_schema, CustomOutputSchema)
|
||||
assert output_schema.json_schema() == _CUSTOM_OUTPUT_SCHEMA_JSON_SCHEMA
|
||||
assert not output_schema.is_strict_json_schema()
|
||||
assert not output_schema.is_plain_text()
|
||||
|
||||
json_str = json.dumps({"foo": "bar"})
|
||||
validated = output_schema.validate_json(json_str)
|
||||
assert validated == ["some", "output"]
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from openai.types.responses import ResponseCompletedEvent
|
|||
from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent
|
||||
|
||||
from agents import Agent, Model, ModelSettings, ModelTracing, Tool
|
||||
from agents.agent_output import AgentOutputSchema
|
||||
from agents.agent_output import AgentOutputSchemaBase
|
||||
from agents.handoffs import Handoff
|
||||
from agents.items import (
|
||||
ModelResponse,
|
||||
|
|
@ -48,7 +48,7 @@ class FakeStreamingModel(Model):
|
|||
input: str | list[TResponseInputItem],
|
||||
model_settings: ModelSettings,
|
||||
tools: list[Tool],
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
tracing: ModelTracing,
|
||||
*,
|
||||
|
|
@ -62,7 +62,7 @@ class FakeStreamingModel(Model):
|
|||
input: str | list[TResponseInputItem],
|
||||
model_settings: ModelSettings,
|
||||
tools: list[Tool],
|
||||
output_schema: AgentOutputSchema | None,
|
||||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
tracing: ModelTracing,
|
||||
*,
|
||||
|
|
|
|||
Loading…
Reference in a new issue