feat: Add strict mode option to function_schema and function_tool

This commit is contained in:
Jai0401 2025-03-12 11:45:32 +05:30
parent 9a32ff5172
commit b7e7fdee55
3 changed files with 55 additions and 1 deletions

View file

@ -33,7 +33,10 @@ class FuncSchema:
"""The signature of the function."""
takes_context: bool = False
"""Whether the function takes a RunContextWrapper argument (must be the first argument)."""
strict_json_schema: bool = True
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
as it increases the likelihood of correct JSON input."""
def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
"""
Converts validated data from the Pydantic model into (args, kwargs), suitable for calling
@ -337,4 +340,5 @@ def function_schema(
params_json_schema=json_schema,
signature=sig,
takes_context=takes_context,
strict_json_schema=strict_json_schema,
)

View file

@ -137,6 +137,7 @@ def function_tool(
docstring_style: DocstringStyle | None = None,
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True,
) -> FunctionTool:
"""Overload for usage as @function_tool (no parentheses)."""
...
@ -150,6 +151,7 @@ def function_tool(
docstring_style: DocstringStyle | None = None,
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True,
) -> Callable[[ToolFunction[...]], FunctionTool]:
"""Overload for usage as @function_tool(...)."""
...
@ -163,6 +165,7 @@ def function_tool(
docstring_style: DocstringStyle | None = None,
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
strict_mode: bool = True,
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
"""
Decorator to create a FunctionTool from a function. By default, we will:
@ -186,6 +189,7 @@ def function_tool(
failure_error_function: If provided, use this function to generate an error message when
the tool call fails. The error message is sent to the LLM. If you pass None, then no
error message will be sent and instead an Exception will be raised.
strict_mode: If False, allows optional parameters in the function schema.
"""
def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
@ -195,6 +199,7 @@ def function_tool(
description_override=description_override,
docstring_style=docstring_style,
use_docstring_info=use_docstring_info,
strict_json_schema=strict_mode,
)
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> str:
@ -273,6 +278,7 @@ def function_tool(
description=schema.description or "",
params_json_schema=schema.params_json_schema,
on_invoke_tool=_on_invoke_tool,
strict_json_schema=strict_mode,
)
# If func is actually a callable, we were used as @function_tool with no parentheses

View file

@ -142,3 +142,47 @@ async def test_no_error_on_invalid_json_async():
tool = will_not_fail_on_bad_json_async
result = await tool.on_invoke_tool(ctx_wrapper(), "{not valid json}")
assert result == "error_ModelBehaviorError"
@function_tool(strict_mode=False)
def optional_param_function(a: int, b: int | None = None) -> str:
if b is None:
return f"{a}_no_b"
return f"{a}_{b}"
@pytest.mark.asyncio
async def test_optional_param_function():
tool = optional_param_function
input_data = {"a": 5}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "5_no_b"
input_data = {"a": 5, "b": 10}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "5_10"
@function_tool(strict_mode=False)
def multiple_optional_params_function(x: int = 42, y: str = "hello", z: int | None = None) -> str:
if z is None:
return f"{x}_{y}_no_z"
return f"{x}_{y}_{z}"
@pytest.mark.asyncio
async def test_multiple_optional_params_function():
tool = multiple_optional_params_function
input_data = {}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "42_hello_no_z"
input_data = {"x": 10, "y": "world"}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "10_world_no_z"
input_data = {"x": 10, "y": "world", "z": 99}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "10_world_99"