feat: Add strict mode option to function_schema and function_tool
This commit is contained in:
parent
9a32ff5172
commit
b7e7fdee55
3 changed files with 55 additions and 1 deletions
|
|
@ -33,7 +33,10 @@ class FuncSchema:
|
|||
"""The signature of the function."""
|
||||
takes_context: bool = False
|
||||
"""Whether the function takes a RunContextWrapper argument (must be the first argument)."""
|
||||
|
||||
strict_json_schema: bool = True
|
||||
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
|
||||
as it increases the likelihood of correct JSON input."""
|
||||
|
||||
def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
|
||||
"""
|
||||
Converts validated data from the Pydantic model into (args, kwargs), suitable for calling
|
||||
|
|
@ -337,4 +340,5 @@ def function_schema(
|
|||
params_json_schema=json_schema,
|
||||
signature=sig,
|
||||
takes_context=takes_context,
|
||||
strict_json_schema=strict_json_schema,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -137,6 +137,7 @@ def function_tool(
|
|||
docstring_style: DocstringStyle | None = None,
|
||||
use_docstring_info: bool = True,
|
||||
failure_error_function: ToolErrorFunction | None = None,
|
||||
strict_mode: bool = True,
|
||||
) -> FunctionTool:
|
||||
"""Overload for usage as @function_tool (no parentheses)."""
|
||||
...
|
||||
|
|
@ -150,6 +151,7 @@ def function_tool(
|
|||
docstring_style: DocstringStyle | None = None,
|
||||
use_docstring_info: bool = True,
|
||||
failure_error_function: ToolErrorFunction | None = None,
|
||||
strict_mode: bool = True,
|
||||
) -> Callable[[ToolFunction[...]], FunctionTool]:
|
||||
"""Overload for usage as @function_tool(...)."""
|
||||
...
|
||||
|
|
@ -163,6 +165,7 @@ def function_tool(
|
|||
docstring_style: DocstringStyle | None = None,
|
||||
use_docstring_info: bool = True,
|
||||
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
|
||||
strict_mode: bool = True,
|
||||
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
|
||||
"""
|
||||
Decorator to create a FunctionTool from a function. By default, we will:
|
||||
|
|
@ -186,6 +189,7 @@ def function_tool(
|
|||
failure_error_function: If provided, use this function to generate an error message when
|
||||
the tool call fails. The error message is sent to the LLM. If you pass None, then no
|
||||
error message will be sent and instead an Exception will be raised.
|
||||
strict_mode: If False, allows optional parameters in the function schema.
|
||||
"""
|
||||
|
||||
def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
|
||||
|
|
@ -195,6 +199,7 @@ def function_tool(
|
|||
description_override=description_override,
|
||||
docstring_style=docstring_style,
|
||||
use_docstring_info=use_docstring_info,
|
||||
strict_json_schema=strict_mode,
|
||||
)
|
||||
|
||||
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> str:
|
||||
|
|
@ -273,6 +278,7 @@ def function_tool(
|
|||
description=schema.description or "",
|
||||
params_json_schema=schema.params_json_schema,
|
||||
on_invoke_tool=_on_invoke_tool,
|
||||
strict_json_schema=strict_mode,
|
||||
)
|
||||
|
||||
# If func is actually a callable, we were used as @function_tool with no parentheses
|
||||
|
|
|
|||
|
|
@ -142,3 +142,47 @@ async def test_no_error_on_invalid_json_async():
|
|||
tool = will_not_fail_on_bad_json_async
|
||||
result = await tool.on_invoke_tool(ctx_wrapper(), "{not valid json}")
|
||||
assert result == "error_ModelBehaviorError"
|
||||
|
||||
|
||||
@function_tool(strict_mode=False)
|
||||
def optional_param_function(a: int, b: int | None = None) -> str:
|
||||
if b is None:
|
||||
return f"{a}_no_b"
|
||||
return f"{a}_{b}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_optional_param_function():
|
||||
tool = optional_param_function
|
||||
|
||||
input_data = {"a": 5}
|
||||
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
|
||||
assert output == "5_no_b"
|
||||
|
||||
input_data = {"a": 5, "b": 10}
|
||||
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
|
||||
assert output == "5_10"
|
||||
|
||||
|
||||
@function_tool(strict_mode=False)
|
||||
def multiple_optional_params_function(x: int = 42, y: str = "hello", z: int | None = None) -> str:
|
||||
if z is None:
|
||||
return f"{x}_{y}_no_z"
|
||||
return f"{x}_{y}_{z}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_optional_params_function():
|
||||
tool = multiple_optional_params_function
|
||||
|
||||
input_data = {}
|
||||
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
|
||||
assert output == "42_hello_no_z"
|
||||
|
||||
input_data = {"x": 10, "y": "world"}
|
||||
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
|
||||
assert output == "10_world_no_z"
|
||||
|
||||
input_data = {"x": 10, "y": "world", "z": 99}
|
||||
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
|
||||
assert output == "10_world_99"
|
||||
Loading…
Reference in a new issue