openai-agents-python/tests/test_function_schema.py
Rohan Mehta 6d2806f10c
Fix function_schema name override bug (#872)
## Summary
- ensure `name_override` is always used in `function_schema`
- test name override when docstring info is disabled

## Testing
- `make format`
- `make lint`
- `make mypy`
- `make tests`

Resolves #860
------
https://chatgpt.com/codex/tasks/task_i_684f1cf885b08321b4dd3f4294e24ca2
2025-06-16 10:50:20 -04:00

453 lines
15 KiB
Python

from collections.abc import Mapping
from enum import Enum
from typing import Any, Literal
import pytest
from pydantic import BaseModel, ValidationError
from typing_extensions import TypedDict
from agents import RunContextWrapper
from agents.exceptions import UserError
from agents.function_schema import function_schema
def no_args_function():
"""This function has no args."""
return "ok"
def test_no_args_function():
func_schema = function_schema(no_args_function)
assert func_schema.params_json_schema.get("title") == "no_args_function_args"
assert func_schema.description == "This function has no args."
assert not func_schema.takes_context
parsed = func_schema.params_pydantic_model()
args, kwargs_dict = func_schema.to_call_args(parsed)
result = no_args_function(*args, **kwargs_dict)
assert result == "ok"
def no_args_function_with_context(ctx: RunContextWrapper[str]):
return "ok"
def test_no_args_function_with_context() -> None:
func_schema = function_schema(no_args_function_with_context)
assert func_schema.takes_context
context = RunContextWrapper(context="test")
parsed = func_schema.params_pydantic_model()
args, kwargs_dict = func_schema.to_call_args(parsed)
result = no_args_function_with_context(context, *args, **kwargs_dict)
assert result == "ok"
def simple_function(a: int, b: int = 5):
"""
Args:
a: The first argument
b: The second argument
Returns:
The sum of a and b
"""
return a + b
def test_simple_function():
"""Test a function that has simple typed parameters and defaults."""
func_schema = function_schema(simple_function)
# Check that the JSON schema is a dictionary with title, type, etc.
assert isinstance(func_schema.params_json_schema, dict)
assert func_schema.params_json_schema.get("title") == "simple_function_args"
assert (
func_schema.params_json_schema.get("properties", {}).get("a").get("description")
== "The first argument"
)
assert (
func_schema.params_json_schema.get("properties", {}).get("b").get("description")
== "The second argument"
)
assert not func_schema.takes_context
# Valid input
valid_input = {"a": 3}
parsed = func_schema.params_pydantic_model(**valid_input)
args_tuple, kwargs_dict = func_schema.to_call_args(parsed)
result = simple_function(*args_tuple, **kwargs_dict)
assert result == 8 # 3 + 5
# Another valid input
valid_input2 = {"a": 3, "b": 10}
parsed2 = func_schema.params_pydantic_model(**valid_input2)
args_tuple2, kwargs_dict2 = func_schema.to_call_args(parsed2)
result2 = simple_function(*args_tuple2, **kwargs_dict2)
assert result2 == 13 # 3 + 10
# Invalid input: 'a' must be int
with pytest.raises(ValidationError):
func_schema.params_pydantic_model(**{"a": "not an integer"})
def varargs_function(x: int, *numbers: float, flag: bool = False, **kwargs: Any):
return x, numbers, flag, kwargs
def test_varargs_function():
"""Test a function that uses *args and **kwargs."""
func_schema = function_schema(varargs_function, strict_json_schema=False)
# Check JSON schema structure
assert isinstance(func_schema.params_json_schema, dict)
assert func_schema.params_json_schema.get("title") == "varargs_function_args"
# Valid input including *args in 'numbers' and **kwargs in 'kwargs'
valid_input = {
"x": 10,
"numbers": [1.1, 2.2, 3.3],
"flag": True,
"kwargs": {"extra1": "hello", "extra2": 42},
}
parsed = func_schema.params_pydantic_model(**valid_input)
args, kwargs_dict = func_schema.to_call_args(parsed)
result = varargs_function(*args, **kwargs_dict)
# result should be (10, (1.1, 2.2, 3.3), True, {"extra1": "hello", "extra2": 42})
assert result[0] == 10
assert result[1] == (1.1, 2.2, 3.3)
assert result[2] is True
assert result[3] == {"extra1": "hello", "extra2": 42}
# Missing 'x' should raise error
with pytest.raises(ValidationError):
func_schema.params_pydantic_model(**{"numbers": [1.1, 2.2]})
# 'flag' can be omitted because it has a default
valid_input_no_flag = {"x": 7, "numbers": [9.9], "kwargs": {"some_key": "some_value"}}
parsed2 = func_schema.params_pydantic_model(**valid_input_no_flag)
args2, kwargs_dict2 = func_schema.to_call_args(parsed2)
result2 = varargs_function(*args2, **kwargs_dict2)
# result2 should be (7, (9.9,), False, {'some_key': 'some_value'})
assert result2 == (7, (9.9,), False, {"some_key": "some_value"})
class Foo(TypedDict):
a: int
b: str
class InnerModel(BaseModel):
a: int
b: str
class OuterModel(BaseModel):
inner: InnerModel
foo: Foo
def complex_args_function(model: OuterModel) -> str:
return f"{model.inner.a}, {model.inner.b}, {model.foo['a']}, {model.foo['b']}"
def test_nested_data_function():
func_schema = function_schema(complex_args_function)
assert isinstance(func_schema.params_json_schema, dict)
assert func_schema.params_json_schema.get("title") == "complex_args_function_args"
# Valid input
model = OuterModel(inner=InnerModel(a=1, b="hello"), foo=Foo(a=2, b="world"))
valid_input = {
"model": model.model_dump(),
}
parsed = func_schema.params_pydantic_model(**valid_input)
args, kwargs_dict = func_schema.to_call_args(parsed)
result = complex_args_function(*args, **kwargs_dict)
assert result == "1, hello, 2, world"
def complex_args_and_docs_function(model: OuterModel, some_flag: int = 0) -> str:
"""
This function takes a model and a flag, and returns a string.
Args:
model: A model with an inner and foo field
some_flag: An optional flag with a default of 0
Returns:
A string with the values of the model and flag
"""
return f"{model.inner.a}, {model.inner.b}, {model.foo['a']}, {model.foo['b']}, {some_flag or 0}"
def test_complex_args_and_docs_function():
func_schema = function_schema(complex_args_and_docs_function)
assert isinstance(func_schema.params_json_schema, dict)
assert func_schema.params_json_schema.get("title") == "complex_args_and_docs_function_args"
# Check docstring is parsed correctly
properties = func_schema.params_json_schema.get("properties", {})
assert properties.get("model").get("description") == "A model with an inner and foo field"
assert properties.get("some_flag").get("description") == "An optional flag with a default of 0"
# Valid input
model = OuterModel(inner=InnerModel(a=1, b="hello"), foo=Foo(a=2, b="world"))
valid_input = {
"model": model.model_dump(),
}
parsed = func_schema.params_pydantic_model(**valid_input)
args, kwargs_dict = func_schema.to_call_args(parsed)
result = complex_args_and_docs_function(*args, **kwargs_dict)
assert result == "1, hello, 2, world, 0"
# Invalid input: 'some_flag' must be int
with pytest.raises(ValidationError):
func_schema.params_pydantic_model(
**{"model": model.model_dump(), "some_flag": "not an int"}
)
# Valid input: 'some_flag' can be omitted because it has a default
valid_input_no_flag = {"model": model.model_dump()}
parsed2 = func_schema.params_pydantic_model(**valid_input_no_flag)
args2, kwargs_dict2 = func_schema.to_call_args(parsed2)
result2 = complex_args_and_docs_function(*args2, **kwargs_dict2)
assert result2 == "1, hello, 2, world, 0"
def function_with_context(ctx: RunContextWrapper[str], a: int, b: int = 5):
return a + b
def test_function_with_context():
func_schema = function_schema(function_with_context)
assert func_schema.takes_context
context = RunContextWrapper(context="test")
input = {"a": 1, "b": 2}
parsed = func_schema.params_pydantic_model(**input)
args, kwargs_dict = func_schema.to_call_args(parsed)
result = function_with_context(context, *args, **kwargs_dict)
assert result == 3
class MyClass:
def foo(self, a: int, b: int = 5):
return a + b
def foo_ctx(self, ctx: RunContextWrapper[str], a: int, b: int = 5):
return a + b
@classmethod
def bar(cls, a: int, b: int = 5):
return a + b
@classmethod
def bar_ctx(cls, ctx: RunContextWrapper[str], a: int, b: int = 5):
return a + b
@staticmethod
def baz(a: int, b: int = 5):
return a + b
@staticmethod
def baz_ctx(ctx: RunContextWrapper[str], a: int, b: int = 5):
return a + b
def test_class_based_functions():
context = RunContextWrapper(context="test")
# Instance method
instance = MyClass()
func_schema = function_schema(instance.foo)
assert isinstance(func_schema.params_json_schema, dict)
assert func_schema.params_json_schema.get("title") == "foo_args"
input = {"a": 1, "b": 2}
parsed = func_schema.params_pydantic_model(**input)
args, kwargs_dict = func_schema.to_call_args(parsed)
result = instance.foo(*args, **kwargs_dict)
assert result == 3
# Instance method with context
func_schema = function_schema(instance.foo_ctx)
assert isinstance(func_schema.params_json_schema, dict)
assert func_schema.params_json_schema.get("title") == "foo_ctx_args"
assert func_schema.takes_context
input = {"a": 1, "b": 2}
parsed = func_schema.params_pydantic_model(**input)
args, kwargs_dict = func_schema.to_call_args(parsed)
result = instance.foo_ctx(context, *args, **kwargs_dict)
assert result == 3
# Class method
func_schema = function_schema(MyClass.bar)
assert isinstance(func_schema.params_json_schema, dict)
assert func_schema.params_json_schema.get("title") == "bar_args"
input = {"a": 1, "b": 2}
parsed = func_schema.params_pydantic_model(**input)
args, kwargs_dict = func_schema.to_call_args(parsed)
result = MyClass.bar(*args, **kwargs_dict)
assert result == 3
# Class method with context
func_schema = function_schema(MyClass.bar_ctx)
assert isinstance(func_schema.params_json_schema, dict)
assert func_schema.params_json_schema.get("title") == "bar_ctx_args"
assert func_schema.takes_context
input = {"a": 1, "b": 2}
parsed = func_schema.params_pydantic_model(**input)
args, kwargs_dict = func_schema.to_call_args(parsed)
result = MyClass.bar_ctx(context, *args, **kwargs_dict)
assert result == 3
# Static method
func_schema = function_schema(MyClass.baz)
assert isinstance(func_schema.params_json_schema, dict)
assert func_schema.params_json_schema.get("title") == "baz_args"
input = {"a": 1, "b": 2}
parsed = func_schema.params_pydantic_model(**input)
args, kwargs_dict = func_schema.to_call_args(parsed)
result = MyClass.baz(*args, **kwargs_dict)
assert result == 3
# Static method with context
func_schema = function_schema(MyClass.baz_ctx)
assert isinstance(func_schema.params_json_schema, dict)
assert func_schema.params_json_schema.get("title") == "baz_ctx_args"
assert func_schema.takes_context
input = {"a": 1, "b": 2}
parsed = func_schema.params_pydantic_model(**input)
args, kwargs_dict = func_schema.to_call_args(parsed)
result = MyClass.baz_ctx(context, *args, **kwargs_dict)
assert result == 3
class MyEnum(str, Enum):
FOO = "foo"
BAR = "bar"
BAZ = "baz"
def enum_and_literal_function(a: MyEnum, b: Literal["a", "b", "c"]) -> str:
return f"{a.value} {b}"
def test_enum_and_literal_function():
func_schema = function_schema(enum_and_literal_function)
assert isinstance(func_schema.params_json_schema, dict)
assert func_schema.params_json_schema.get("title") == "enum_and_literal_function_args"
# Check that the enum values are included in the JSON schema
assert func_schema.params_json_schema.get("$defs", {}).get("MyEnum", {}).get("enum") == [
"foo",
"bar",
"baz",
]
# Check that the enum is expressed as a def
assert (
func_schema.params_json_schema.get("properties", {}).get("a", {}).get("$ref")
== "#/$defs/MyEnum"
)
# Check that the literal values are included in the JSON schema
assert func_schema.params_json_schema.get("properties", {}).get("b", {}).get("enum") == [
"a",
"b",
"c",
]
# Valid input
valid_input = {"a": "foo", "b": "a"}
parsed = func_schema.params_pydantic_model(**valid_input)
args, kwargs_dict = func_schema.to_call_args(parsed)
result = enum_and_literal_function(*args, **kwargs_dict)
assert result == "foo a"
# Invalid input: 'a' must be a valid enum value
with pytest.raises(ValidationError):
func_schema.params_pydantic_model(**{"a": "not an enum value", "b": "a"})
# Invalid input: 'b' must be a valid literal value
with pytest.raises(ValidationError):
func_schema.params_pydantic_model(**{"a": "foo", "b": "not a literal value"})
def test_run_context_in_non_first_position_raises_value_error():
# When a parameter (after the first) is annotated as RunContextWrapper,
# function_schema() should raise a UserError.
def func(a: int, context: RunContextWrapper) -> None:
pass
with pytest.raises(UserError):
function_schema(func, use_docstring_info=False)
def test_var_positional_tuple_annotation():
# When a function has a var-positional parameter annotated with a tuple type,
# function_schema() should convert it into a field with type List[<tuple-element>].
def func(*args: tuple[int, ...]) -> int:
total = 0
for arg in args:
total += sum(arg)
return total
fs = function_schema(func, use_docstring_info=False)
properties = fs.params_json_schema.get("properties", {})
assert properties.get("args").get("type") == "array"
assert properties.get("args").get("items").get("type") == "integer"
def test_var_keyword_dict_annotation():
# Case 3:
# When a function has a var-keyword parameter annotated with a dict type,
# function_schema() should convert it into a field with type Dict[<key>, <value>].
def func(**kwargs: dict[str, int]):
return kwargs
fs = function_schema(func, use_docstring_info=False, strict_json_schema=False)
properties = fs.params_json_schema.get("properties", {})
# The name of the field is "kwargs", and it's a JSON object i.e. a dict.
assert properties.get("kwargs").get("type") == "object"
# The values in the dict are integers.
assert properties.get("kwargs").get("additionalProperties").get("type") == "integer"
def test_schema_with_mapping_raises_strict_mode_error():
"""A mapping type is not allowed in strict mode. Same for dicts. Ensure we raise a UserError."""
def func_with_mapping(test_one: Mapping[str, int]) -> str:
return "foo"
with pytest.raises(UserError):
function_schema(func_with_mapping)
def test_name_override_without_docstring() -> None:
"""name_override should be used even when not parsing docstrings."""
def foo(x: int) -> int:
return x
fs = function_schema(foo, use_docstring_info=False, name_override="custom")
assert fs.name == "custom"
assert fs.params_json_schema.get("title") == "custom_args"