144 lines
5.2 KiB
Python
144 lines
5.2 KiB
Python
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
from pydantic import BaseModel, TypeAdapter
|
|
from typing_extensions import TypedDict, get_args, get_origin
|
|
|
|
from . import _utils
|
|
from .exceptions import ModelBehaviorError, UserError
|
|
from .strict_schema import ensure_strict_json_schema
|
|
from .tracing import SpanError
|
|
|
|
# Key under which an output type that would not produce a JSON Schema object on its own is
# nested, so the schema sent to the model is always an object (e.g. `{"response": [...]}`).
_WRAPPER_DICT_KEY: str = "response"
|
|
|
|
|
|
@dataclass(init=False)
class AgentOutputSchema:
    """Holds the JSON schema for an agent's output type, and validates/parses the JSON the LLM
    produces into that type.
    """

    output_type: type[Any]
    """The output type this schema describes."""

    _type_adapter: TypeAdapter[Any]
    """Pydantic adapter used to validate JSON; it targets the wrapper dict when `_is_wrapped`."""

    _is_wrapped: bool
    """True when the output type is nested under a single-key dict. That happens when the type
    on its own cannot be represented as a JSON Schema object.
    """

    _output_schema: dict[str, Any]
    """The JSON schema describing the (possibly wrapped) output."""

    strict_json_schema: bool
    """Whether the JSON schema is in strict mode. Keeping this True is **strongly** recommended,
    since it makes correct JSON input from the model more likely.
    """

    def __init__(self, output_type: type[Any], strict_json_schema: bool = True):
        """
        Args:
            output_type: The output type. `None` or `str` means plain-text output.
            strict_json_schema: Whether to make the JSON schema strict. Keeping this True is
                **strongly** recommended, since it makes correct JSON input more likely.
        """
        self.output_type = output_type
        self.strict_json_schema = strict_json_schema

        if self.is_plain_text():
            # Plain text is neither wrapped nor made strict.
            self._is_wrapped = False
            self._type_adapter = TypeAdapter(output_type)
            self._output_schema = self._type_adapter.json_schema()
            return

        # Anything that is not plain text and would not naturally yield a JSON Schema *object*
        # (e.g. `list[int]`) gets nested under a single-key dict.
        self._is_wrapped = not _is_subclass_of_base_model_or_dict(output_type)
        if self._is_wrapped:
            adapted_type: Any = TypedDict(
                "OutputType",
                {_WRAPPER_DICT_KEY: output_type},  # type: ignore
            )
        else:
            adapted_type = output_type

        self._type_adapter = TypeAdapter(adapted_type)
        self._output_schema = self._type_adapter.json_schema()

        if self.strict_json_schema:
            self._output_schema = ensure_strict_json_schema(self._output_schema)

    def is_plain_text(self) -> bool:
        """True if the output is plain text (output type `None` or `str`) rather than JSON."""
        return self.output_type is None or self.output_type is str

    def json_schema(self) -> dict[str, Any]:
        """Return the JSON schema for the output type.

        Raises:
            UserError: If the output type is plain text, which has no JSON schema.
        """
        if not self.is_plain_text():
            return self._output_schema
        raise UserError("Output type is plain text, so no JSON schema is available")

    def validate_json(self, json_str: str, partial: bool = False) -> Any:
        """Validate `json_str` against the output type and return the parsed value.

        For a wrapped output type, the value stored under the wrapper key is returned.

        Args:
            json_str: The raw JSON text produced by the model.
            partial: Forwarded to `_utils.validate_json`; presumably allows incomplete JSON
                (TODO confirm).

        Raises:
            ModelBehaviorError: If the JSON is invalid or, for a wrapped type, is not a dict
                containing the wrapper key.
        """
        validated = _utils.validate_json(json_str, self._type_adapter, partial)
        if not self._is_wrapped:
            return validated

        if not isinstance(validated, dict):
            raise self._invalid_json_error(
                f"Expected a dict, got {type(validated)}",
                f"Expected a dict, got {type(validated)} for JSON: {json_str}",
            )

        if _WRAPPER_DICT_KEY not in validated:
            raise self._invalid_json_error(
                f"Could not find key {_WRAPPER_DICT_KEY} in JSON",
                f"Could not find key {_WRAPPER_DICT_KEY} in JSON: {json_str}",
            )

        return validated[_WRAPPER_DICT_KEY]

    @staticmethod
    def _invalid_json_error(details: str, message: str) -> ModelBehaviorError:
        """Record an "Invalid JSON" error on the current span and build the error to raise."""
        _utils.attach_error_to_current_span(
            SpanError(message="Invalid JSON", data={"details": details})
        )
        return ModelBehaviorError(message)

    def output_type_name(self) -> str:
        """A readable name for the output type, e.g. `list[int]`."""
        return _type_to_str(self.output_type)
|
|
|
|
|
|
def _is_subclass_of_base_model_or_dict(t: Any) -> bool:
    """Return True if `t` is a class deriving from pydantic `BaseModel` or from `dict`.

    Anything that is not a class (per `isinstance(t, type)`) yields False.
    """
    if isinstance(t, type):
        # For a parameterized generic that passes the class check, test its origin
        # (e.g. `list` for `list[str]`) instead of the alias itself.
        # NOTE(review): whether generic aliases pass `isinstance(t, type)` may depend on the
        # Python version, so this origin branch may not always be reachable -- confirm.
        candidate = get_origin(t) or t
        return issubclass(candidate, (BaseModel, dict))
    return False
|
|
|
|
|
|
def _type_to_str(t: type[Any]) -> str:
    """Render a (possibly parameterized) type as a short, readable string.

    Examples: `int` -> "int", `dict[str, list[int]]` -> "dict[str, list[int]]".

    Values that are not classes and have no `__name__` are rendered with `repr` instead of
    raising AttributeError. That covers `None` (a valid plain-text output type) and `Literal`
    arguments such as "a". The `...` in `tuple[int, ...]` is rendered literally.
    """
    origin = get_origin(t)
    args = get_args(t)

    if origin is None:
        # A simple type like `str` or `int`, or a non-class value nested inside a generic.
        if t is Ellipsis:
            return "..."
        name = getattr(t, "__name__", None)
        return name if isinstance(name, str) else repr(t)
    elif args:
        args_str = ", ".join(_type_to_str(arg) for arg in args)
        # Some special-form origins may lack a `__name__`; fall back to their str().
        origin_name = getattr(origin, "__name__", None)
        if not isinstance(origin_name, str):
            origin_name = str(origin)
        return f"{origin_name}[{args_str}]"
    else:
        return str(t)
|