feat: Add strict_mode option to function_schema and function_tool (#60)
This PR introduces a `strict_mode: bool = True` option to `@function_tool`, allowing optional parameters when set to False. This change enables more flexibility while maintaining strict JSON schema validation by default. resolves #43 ## Changes: - Added `strict_mode` parameter to `@function_tool` and passed it to `function_schema` and `FunctionTool`. - Updated `function_schema.py` to respect `strict_mode` and allow optional parameters when set to False. - Added unit tests to verify optional parameters work correctly, including multiple optional params with different types. ## Tests: - Verified function calls with missing optional parameters behave as expected. - Added async tests to validate behavior under different configurations.
This commit is contained in:
commit
951193bd21
3 changed files with 61 additions and 1 deletions
|
|
@ -33,6 +33,9 @@ class FuncSchema:
|
|||
"""The signature of the function."""
|
||||
takes_context: bool = False
|
||||
"""Whether the function takes a RunContextWrapper argument (must be the first argument)."""
|
||||
strict_json_schema: bool = True
|
||||
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
|
||||
as it increases the likelihood of correct JSON input."""
|
||||
|
||||
def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
|
||||
"""
|
||||
|
|
@ -337,4 +340,5 @@ def function_schema(
|
|||
params_json_schema=json_schema,
|
||||
signature=sig,
|
||||
takes_context=takes_context,
|
||||
strict_json_schema=strict_json_schema,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -137,6 +137,7 @@ def function_tool(
|
|||
docstring_style: DocstringStyle | None = None,
|
||||
use_docstring_info: bool = True,
|
||||
failure_error_function: ToolErrorFunction | None = None,
|
||||
strict_mode: bool = True,
|
||||
) -> FunctionTool:
|
||||
"""Overload for usage as @function_tool (no parentheses)."""
|
||||
...
|
||||
|
|
@ -150,6 +151,7 @@ def function_tool(
|
|||
docstring_style: DocstringStyle | None = None,
|
||||
use_docstring_info: bool = True,
|
||||
failure_error_function: ToolErrorFunction | None = None,
|
||||
strict_mode: bool = True,
|
||||
) -> Callable[[ToolFunction[...]], FunctionTool]:
|
||||
"""Overload for usage as @function_tool(...)."""
|
||||
...
|
||||
|
|
@ -163,6 +165,7 @@ def function_tool(
|
|||
docstring_style: DocstringStyle | None = None,
|
||||
use_docstring_info: bool = True,
|
||||
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
|
||||
strict_mode: bool = True,
|
||||
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
|
||||
"""
|
||||
Decorator to create a FunctionTool from a function. By default, we will:
|
||||
|
|
@ -186,6 +189,8 @@ def function_tool(
|
|||
failure_error_function: If provided, use this function to generate an error message when
|
||||
the tool call fails. The error message is sent to the LLM. If you pass None, then no
|
||||
error message will be sent and instead an Exception will be raised.
|
||||
strict_mode: If False, parameters with default values become optional in the
|
||||
function schema.
|
||||
"""
|
||||
|
||||
def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
|
||||
|
|
@ -195,6 +200,7 @@ def function_tool(
|
|||
description_override=description_override,
|
||||
docstring_style=docstring_style,
|
||||
use_docstring_info=use_docstring_info,
|
||||
strict_json_schema=strict_mode,
|
||||
)
|
||||
|
||||
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> str:
|
||||
|
|
@ -273,6 +279,7 @@ def function_tool(
|
|||
description=schema.description or "",
|
||||
params_json_schema=schema.params_json_schema,
|
||||
on_invoke_tool=_on_invoke_tool,
|
||||
strict_json_schema=strict_mode,
|
||||
)
|
||||
|
||||
# If func is actually a callable, we were used as @function_tool with no parentheses
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import asyncio
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -142,3 +142,52 @@ async def test_no_error_on_invalid_json_async():
|
|||
tool = will_not_fail_on_bad_json_async
|
||||
result = await tool.on_invoke_tool(ctx_wrapper(), "{not valid json}")
|
||||
assert result == "error_ModelBehaviorError"
|
||||
|
||||
|
||||
@function_tool(strict_mode=False)
|
||||
def optional_param_function(a: int, b: Optional[int] = None) -> str:
|
||||
if b is None:
|
||||
return f"{a}_no_b"
|
||||
return f"{a}_{b}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_optional_param_function():
|
||||
tool = optional_param_function
|
||||
|
||||
input_data = {"a": 5}
|
||||
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
|
||||
assert output == "5_no_b"
|
||||
|
||||
input_data = {"a": 5, "b": 10}
|
||||
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
|
||||
assert output == "5_10"
|
||||
|
||||
|
||||
@function_tool(strict_mode=False)
|
||||
def multiple_optional_params_function(
|
||||
x: int = 42,
|
||||
y: str = "hello",
|
||||
z: Optional[int] = None,
|
||||
) -> str:
|
||||
if z is None:
|
||||
return f"{x}_{y}_no_z"
|
||||
return f"{x}_{y}_{z}"
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_optional_params_function():
|
||||
tool = multiple_optional_params_function
|
||||
|
||||
input_data: dict[str,Any] = {}
|
||||
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
|
||||
assert output == "42_hello_no_z"
|
||||
|
||||
input_data = {"x": 10, "y": "world"}
|
||||
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
|
||||
assert output == "10_world_no_z"
|
||||
|
||||
input_data = {"x": 10, "y": "world", "z": 99}
|
||||
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
|
||||
assert output == "10_world_99"
|
||||
|
|
|
|||
Loading…
Reference in a new issue