"""Tests for ``function_tool``-generated tools and manually constructed ``FunctionTool``s."""

import json
from typing import Any

import pytest
from pydantic import BaseModel
from typing_extensions import TypedDict

from agents import FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool
from agents.tool import default_tool_error_function


# NOTE: the functions and models below are the *subjects* wrapped by `function_tool`.
# They deliberately carry no docstrings, because a docstring would feed into the
# generated tool description / JSON schema and change what is being tested.


# Tool with no parameters at all.
def argless_function() -> str:
    return "ok"


@pytest.mark.asyncio
async def test_argless_function():
    """A parameterless function becomes a tool named after the function."""
    tool = function_tool(argless_function)
    assert tool.name == "argless_function"

    result = await tool.on_invoke_tool(RunContextWrapper(None), "")
    assert result == "ok"


# Tool whose only parameter is the run context, so it exposes no model-facing arguments.
def argless_with_context(ctx: RunContextWrapper[str]) -> str:
    return "ok"


@pytest.mark.asyncio
async def test_argless_with_context():
    """A context-only function is invocable with empty input and tolerates extra JSON."""
    tool = function_tool(argless_with_context)
    assert tool.name == "argless_with_context"

    result = await tool.on_invoke_tool(RunContextWrapper(None), "")
    assert result == "ok"

    # Extra JSON should not raise an error
    result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}')
    assert result == "ok"


# Tool with one required and one defaulted primitive argument.
def simple_function(a: int, b: int = 5):
    return a + b


@pytest.mark.asyncio
async def test_simple_function():
    """Defaults are honoured, explicit values override them, missing args raise."""
    # failure_error_function=None makes errors propagate instead of being turned into strings.
    tool = function_tool(simple_function, failure_error_function=None)
    assert tool.name == "simple_function"

    result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}')
    assert result == 6

    result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1, "b": 2}')
    assert result == 3

    # Missing required argument should raise an error
    with pytest.raises(ModelBehaviorError):
        await tool.on_invoke_tool(RunContextWrapper(None), "")


# Pydantic model argument with a defaulted field.
class Foo(BaseModel):
    a: int
    b: int = 5


# TypedDict argument; both keys are required.
class Bar(TypedDict):
    x: str
    y: int


# Tool mixing a pydantic model, a TypedDict and a defaulted primitive.
def complex_args_function(foo: Foo, bar: Bar, baz: str = "hello"):
    return f"{foo.a + foo.b} {bar['x']}{bar['y']} {baz}"


@pytest.mark.asyncio
async def test_complex_args_function():
    """Nested model / TypedDict arguments are parsed from JSON, including their defaults."""
    tool = function_tool(complex_args_function, failure_error_function=None)
    assert tool.name == "complex_args_function"

    # Foo.b and baz both fall back to their defaults.
    valid_json = json.dumps(
        {
            "foo": Foo(a=1).model_dump(),
            "bar": Bar(x="hello", y=10),
        }
    )
    result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json)
    assert result == "6 hello10 hello"

    # Foo.b is explicit; baz still defaults.
    valid_json = json.dumps(
        {
            "foo": Foo(a=1, b=2).model_dump(),
            "bar": Bar(x="hello", y=10),
        }
    )
    result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json)
    assert result == "3 hello10 hello"

    # Every argument is explicit.
    valid_json = json.dumps(
        {
            "foo": Foo(a=1, b=2).model_dump(),
            "bar": Bar(x="hello", y=10),
            "baz": "world",
        }
    )
    result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json)
    assert result == "3 hello10 world"

    # Missing required argument should raise an error
    with pytest.raises(ModelBehaviorError):
        await tool.on_invoke_tool(RunContextWrapper(None), '{"foo": {"a": 1}}')


def test_function_config_overrides():
    """name_override / description_override apply independently and together."""
    tool = function_tool(simple_function, name_override="custom_name")
    assert tool.name == "custom_name"

    tool = function_tool(simple_function, description_override="custom description")
    assert tool.description == "custom description"

    tool = function_tool(
        simple_function,
        name_override="custom_name",
        description_override="custom description",
    )
    assert tool.name == "custom_name"
    assert tool.description == "custom description"


def _assert_strict_schema(tool: FunctionTool) -> None:
    """Assert ``tool`` is strict and its params schema disallows additional properties."""
    assert tool.strict_json_schema, "Should be strict by default"
    # The key must be present *and* falsy; an absent key would allow extra properties.
    assert (
        "additionalProperties" in tool.params_json_schema
        and not tool.params_json_schema["additionalProperties"]
    )


def test_func_schema_is_strict():
    """Tools built by `function_tool` default to a strict schema, for simple and nested args."""
    _assert_strict_schema(function_tool(simple_function))
    _assert_strict_schema(function_tool(complex_args_function))


@pytest.mark.asyncio
async def test_manual_function_tool_creation_works():
    """A hand-built FunctionTool works in both strict (default) and non-strict mode."""

    def do_some_work(data: str) -> str:
        return f"{data}_done"

    class FunctionArgs(BaseModel):
        data: str

    async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
        parsed = FunctionArgs.model_validate_json(args)
        return do_some_work(data=parsed.data)

    tool = FunctionTool(
        name="test",
        description="Processes extracted user data",
        params_json_schema=FunctionArgs.model_json_schema(),
        on_invoke_tool=run_function,
    )

    assert tool.name == "test"
    assert tool.description == "Processes extracted user data"
    # The strict schema may gain extra keys (e.g. additionalProperties), so only check
    # that every key pydantic produced is still present with the same value.
    for key, value in FunctionArgs.model_json_schema().items():
        assert tool.params_json_schema[key] == value
    assert tool.strict_json_schema

    result = await tool.on_invoke_tool(RunContextWrapper(None), '{"data": "hello"}')
    assert result == "hello_done"

    tool_not_strict = FunctionTool(
        name="test",
        description="Processes extracted user data",
        params_json_schema=FunctionArgs.model_json_schema(),
        on_invoke_tool=run_function,
        strict_json_schema=False,
    )

    assert not tool_not_strict.strict_json_schema
    assert "additionalProperties" not in tool_not_strict.params_json_schema

    # With a non-strict schema, the unknown "bar" key is tolerated by run_function.
    result = await tool_not_strict.on_invoke_tool(
        RunContextWrapper(None), '{"data": "hello", "bar": "baz"}'
    )
    assert result == "hello_done"


@pytest.mark.asyncio
async def test_function_tool_default_error_works():
    """Without a custom handler, errors are returned as strings instead of raised."""

    def my_func(a: int, b: int = 5):
        raise ValueError("test")

    tool = function_tool(my_func)
    ctx = RunContextWrapper(None)

    # Empty input and input missing the required "a" both report invalid JSON.
    result = await tool.on_invoke_tool(ctx, "")
    assert "Invalid JSON" in str(result)

    result = await tool.on_invoke_tool(ctx, "{}")
    assert "Invalid JSON" in str(result)

    # Exceptions raised by the function itself go through default_tool_error_function.
    result = await tool.on_invoke_tool(ctx, '{"a": 1}')
    assert result == default_tool_error_function(ctx, ValueError("test"))

    result = await tool.on_invoke_tool(ctx, '{"a": 1, "b": 2}')
    assert result == default_tool_error_function(ctx, ValueError("test"))


def _error_type_name_handler(ctx: RunContextWrapper[Any], error: Exception) -> str:
    """Custom (synchronous) failure handler: encode the exception's class name."""
    return f"error_{error.__class__.__name__}"


async def _assert_custom_error_results(tool: FunctionTool) -> None:
    """Check a tool wired to `_error_type_name_handler` whose function always raises ValueError.

    Invalid input surfaces to the handler as ModelBehaviorError; valid input reaches the
    function, whose ValueError is then passed to the handler.
    """
    ctx = RunContextWrapper(None)

    result = await tool.on_invoke_tool(ctx, "")
    assert result == "error_ModelBehaviorError"

    result = await tool.on_invoke_tool(ctx, "{}")
    assert result == "error_ModelBehaviorError"

    result = await tool.on_invoke_tool(ctx, '{"a": 1}')
    assert result == "error_ValueError"

    result = await tool.on_invoke_tool(ctx, '{"a": 1, "b": 2}')
    assert result == "error_ValueError"


@pytest.mark.asyncio
async def test_sync_custom_error_function_works():
    """A custom failure handler is used when the tool function is synchronous."""

    def my_func(a: int, b: int = 5):
        raise ValueError("test")

    tool = function_tool(my_func, failure_error_function=_error_type_name_handler)
    await _assert_custom_error_results(tool)


@pytest.mark.asyncio
async def test_async_custom_error_function_works():
    """A custom failure handler is used when the tool function is a coroutine."""

    async def my_func(a: int, b: int = 5):
        raise ValueError("test")

    tool = function_tool(my_func, failure_error_function=_error_type_name_handler)
    await _assert_custom_error_results(tool)