arcade-mcp/arcade/tests/core/test_executor.py
Eric Gustin e636b686c1
Add @tool.deprecated (#247)
## PR Description
Add the ability to mark a tool as deprecated and surface the deprecation
warning in the user's runtime. This PR also lays the groundwork for
emitting other log levels (debug, info, etc.) produced during a tool's
execution.

NOTE: Updates to the Arcade Clients (Python and JS) still need to be
done before the deprecation warning is emitted, but this PR needs to be
merged before those updates!

Let's cross our fingers that we'll never need to deprecate
`@tool.deprecated`!

### Example

1. Mark your tool as deprecated
```python
from typing import Annotated

from arcade.sdk import tool


@tool.deprecated("Use the 'Math.AddInt' tool instead.") # order of decorators does not matter
@tool
def add(
    a: Annotated[int, "The first number"], b: Annotated[int, "The second number"]
) -> Annotated[int, "The sum of the two numbers"]:
    """
    Add two numbers together
    """
    return a + b
```

2. Call the deprecated tool
```python
from arcadepy import Arcade

client = Arcade()

tool_input = {"a": 9001, "b": 42}

response = client.tools.execute(
    tool_name="Math.Add",
    input=tool_input,
    user_id="me@example.com",
)
print(f"The result of adding {tool_input['a']} and {tool_input['b']} is: {response.output.value}")
```

3. Observe the DeprecationWarning:
``` 
❯ python examples/call_a_tool_directly.py 
/Users/ericgustin/repos/Team/arcade-ai/examples/call_a_tool_directly.py:22: DeprecationWarning: 'Math.Add' is deprecated: Use the `Math.AddInt` tool instead.
  response = client.tools.execute(
The result of adding 9001 and 42 is: 9043
```
2025-02-18 13:27:49 -08:00

186 lines
5.7 KiB
Python

from typing import Annotated
import pytest
from arcade.core.catalog import ToolCatalog
from arcade.core.executor import ToolExecutor
from arcade.core.schema import ToolCallError, ToolCallLog, ToolCallOutput, ToolContext
from arcade.sdk import tool
from arcade.sdk.errors import RetryableToolError, ToolExecutionError
@tool
def simple_tool(inp: Annotated[str, "input"]) -> Annotated[str, "output"]:
    """Simple tool"""
    # Happy-path fixture: the input string is handed back untouched.
    # NOTE(review): docstring left verbatim -- presumably @tool reads it; confirm.
    echoed = inp
    return echoed
@tool.deprecated("Use simple_tool instead")
@tool
def simple_deprecated_tool(inp: Annotated[str, "input"]) -> Annotated[str, "output"]:
    """Simple tool that is deprecated"""
    # Same echo behaviour as simple_tool; the only difference is the
    # deprecation marker applied by the outer decorator.
    echoed = inp
    return echoed
@tool
def retryable_error_tool() -> Annotated[str, "output"]:
    """Tool that raises a retryable error"""
    # Always fails with a retryable error. The parametrized executor test
    # expects "test" for each text field and a 1000 ms retry delay.
    failure = RetryableToolError("test", "test", "test", 1000)
    raise failure
@tool
def exec_error_tool() -> Annotated[str, "output"]:
    """Tool that raises an error"""
    # Always fails with a (non-retryable) tool execution error.
    failure = ToolExecutionError("test", "test")
    raise failure
@tool
def unexpected_error_tool() -> Annotated[str, "output"]:
    """Tool that raises an unexpected error"""
    # Raises a plain built-in exception rather than one of the SDK error
    # types, so the executor's handling of unanticipated failures is exercised.
    failure = RuntimeError("test")
    raise failure
@tool
def bad_output_error_tool() -> Annotated[str, "output"]:
    """tool that returns a bad output type"""
    # Deliberately returns a dict although the signature promises a str;
    # the mismatch is the point of this fixture.
    wrong_type_payload = {"output": "test"}
    return wrong_type_payload
# ---- Test Driver ----
# Every fixture tool above is registered, in definition order, under the
# same toolkit name so the parametrized test can look each one up.
catalog = ToolCatalog()
for _fixture_tool in (
    simple_tool,
    simple_deprecated_tool,
    retryable_error_tool,
    exec_error_tool,
    unexpected_error_tool,
    bad_output_error_tool,
):
    catalog.add_tool(_fixture_tool, "simple_toolkit")
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "tool_func, inputs, expected_output",
    [
        pytest.param(
            simple_tool,
            {"inp": "test"},
            ToolCallOutput(value="test"),
            id="simple_tool",
        ),
        pytest.param(
            simple_deprecated_tool,
            {"inp": "test"},
            ToolCallOutput(
                value="test",
                logs=[
                    ToolCallLog(
                        message="Use simple_tool instead",
                        level="warning",
                        subtype="deprecation",
                    )
                ],
            ),
            id="simple_deprecated_tool",
        ),
        pytest.param(
            retryable_error_tool,
            {},
            ToolCallOutput(
                error=ToolCallError(
                    message="test",
                    developer_message="test",
                    additional_prompt_content="test",
                    retry_after_ms=1000,
                    can_retry=True,
                )
            ),
            id="retryable_error_tool",
        ),
        pytest.param(
            exec_error_tool,
            {},
            ToolCallOutput(
                error=ToolCallError(
                    message="test",
                    developer_message="test",
                )
            ),
            id="exec_error_tool",
        ),
        pytest.param(
            unexpected_error_tool,
            {},
            ToolCallOutput(
                error=ToolCallError(
                    message="Error in execution of UnexpectedErrorTool",
                    developer_message="Error in unexpected_error_tool: test",
                )
            ),
            id="unexpected_error_tool",
        ),
        pytest.param(
            simple_tool,
            {"inp": {"test": "test"}},  # the tool takes a string, not a dict
            ToolCallOutput(
                error=ToolCallError(
                    message="Error in tool input deserialization",
                    # Can't guarantee the exact text, so it is left unset;
                    # check_output skips the comparison when this is falsy.
                    developer_message=None,
                )
            ),
            id="invalid_input_type",
        ),
        pytest.param(
            bad_output_error_tool,
            {},
            ToolCallOutput(
                error=ToolCallError(
                    message="Failed to serialize tool output",
                    # Can't guarantee the exact text, so it is left unset;
                    # check_output skips the comparison when this is falsy.
                    developer_message=None,
                )
            ),
            id="bad_output_type",
        ),
    ],
)
async def test_tool_executor(tool_func, inputs, expected_output):
    """Run one fixture tool through ToolExecutor.run and verify its output.

    Each parametrized case supplies the tool function, the keyword inputs
    to call it with, and the ToolCallOutput the run is expected to produce.
    """
    definition = catalog.find_tool_by_func(tool_func)
    context = ToolContext()
    # The input/output models live on the catalog's entry for the tool,
    # which is looked up by the definition's fully qualified name.
    catalog_entry = catalog.get_tool(definition.get_fully_qualified_name())

    actual_output = await ToolExecutor.run(
        func=tool_func,
        definition=definition,
        input_model=catalog_entry.input_model,
        output_model=catalog_entry.output_model,
        context=context,
        **inputs,
    )

    check_output(actual_output, expected_output)
def check_output(output: ToolCallOutput, expected_output: ToolCallOutput):
    """Assert that an executor result matches the expected result.

    Args:
        output: The result actually produced by running the tool.
        expected_output: The result the test case expects.

    Raises:
        AssertionError: If one side carries an error and the other does not,
            or if any compared field differs.
    """
    actual_error = output.error
    expected_error = expected_output.error

    # Both sides must agree on whether the call failed. Without this check an
    # expected error that never happened could pass (value None == None), and
    # an unexpected error would surface as an AttributeError on None.
    assert bool(actual_error) == bool(expected_error), (
        f"error presence mismatch: got {actual_error!r}, expected {expected_error!r}"
    )

    # execution error in tool
    if actual_error:
        assert actual_error.message == expected_error.message
        # developer_message and traceback_info are only compared when the
        # expectation pins them; their exact text is not always predictable.
        if expected_error.developer_message:
            assert actual_error.developer_message == expected_error.developer_message
        if expected_error.traceback_info:
            assert actual_error.traceback_info == expected_error.traceback_info
        assert actual_error.can_retry == expected_error.can_retry
        assert (
            actual_error.additional_prompt_content
            == expected_error.additional_prompt_content
        )
        assert actual_error.retry_after_ms == expected_error.retry_after_ms
        return

    # normal tool execution
    assert output.value == expected_output.value

    # check logs (a missing log list is treated the same as an empty one)
    output_logs = output.logs or []
    expected_logs = expected_output.logs or []
    assert len(output_logs) == len(expected_logs)
    for output_log, expected_log in zip(output_logs, expected_logs):
        assert output_log.message == expected_log.message
        assert output_log.level == expected_log.level
        assert output_log.subtype == expected_log.subtype