arcade-mcp/arcade/tests/core/test_executor.py
Sam Partee 68a4caff98
CLI Engine Env passing and Tool Executor cleanup (#95)
This PR introduces the following changes:

- **Engine Environment Configuration**: Adds support for specifying an
environment variables file for the engine via the `arcade dev` CLI
command.
- **Configuration File Handling**: Refactors configuration file handling
in the CLI launcher to generalize logic for locating configuration
files.
- **Tool Execution Logging**: Enhances logging in `BaseActor` to include
execution duration and adjusts logging levels for better visibility.
- **Enhanced Tool Exception Handling**: Improves exception handling in
`ToolExecutor` and updates the `@tool` decorator to ensure proper
propagation and handling of exceptions raised during tool execution.
2024-10-07 17:49:34 -07:00

154 lines
4.6 KiB
Python

from typing import Annotated
import pytest
from arcade.core.catalog import ToolCatalog
from arcade.core.executor import ToolExecutor
from arcade.core.schema import ToolCallError, ToolCallOutput, ToolContext
from arcade.sdk import tool
from arcade.sdk.error import RetryableToolError, ToolExecutionError
# Fixture tool: returns its single string argument unchanged.
# test_tool_executor uses it both for the success case (inp="test") and,
# by passing a dict instead of a string, for the invalid-input case.
# NOTE(review): the docstring and Annotated metadata are presumably read by
# the @tool decorator to build the tool definition -- confirm before rewording.
@tool
def simple_tool(inp: Annotated[str, "input"]) -> Annotated[str, "output"]:
    """Simple tool"""
    return inp
# Fixture tool: always raises RetryableToolError.
# test_tool_executor expects the four positional arguments to surface on the
# resulting ToolCallError as message="test", developer_message="test",
# additional_prompt_content="test" and retry_after_ms=1000, with can_retry=True.
@tool
def retryable_error_tool() -> Annotated[str, "output"]:
    """Tool that raises a retryable error"""
    raise RetryableToolError("test", "test", "test", 1000)
# Fixture tool: always raises ToolExecutionError.
# test_tool_executor expects both arguments to come back on the ToolCallError
# as message="test" and developer_message="test".
@tool
def exec_error_tool() -> Annotated[str, "output"]:
    """Tool that raises an error"""
    raise ToolExecutionError("test", "test")
# Fixture tool: always raises a plain RuntimeError, i.e. an exception that is
# not one of the SDK error types imported by this module.
# test_tool_executor expects it to be reported with the message
# "Error in execution of UnexpectedErrorTool" and the developer message
# "Error in unexpected_error_tool: test".
@tool
def unexpected_error_tool() -> Annotated[str, "output"]:
    """Tool that raises an unexpected error"""
    raise RuntimeError("test")
# Fixture tool: deliberately returns a dict although the return annotation is
# str. The mismatch is the point of the fixture: test_tool_executor expects the
# run to end in a "Failed to serialize tool output" error.
@tool
def bad_output_error_tool() -> Annotated[str, "output"]:
    """tool that returns a bad output type"""
    # Intentionally the wrong type for the declared str output.
    return {"output": "test"}
# ---- Test Driver ----
# One shared catalog for the whole module. Every fixture tool is registered
# under the same toolkit name, in declaration order, so test_tool_executor can
# look up each tool's definition and input/output models.
catalog = ToolCatalog()
for _fixture_tool in (
    simple_tool,
    retryable_error_tool,
    exec_error_tool,
    unexpected_error_tool,
    bad_output_error_tool,
):
    catalog.add_tool(_fixture_tool, "simple_toolkit")
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("tool_func", "inputs", "expected_output"),
    [
        # Happy path: the value is passed straight through.
        pytest.param(
            simple_tool,
            {"inp": "test"},
            ToolCallOutput(value="test"),
            id="simple_tool",
        ),
        # Retryable SDK error: all retry metadata must be carried over.
        pytest.param(
            retryable_error_tool,
            {},
            ToolCallOutput(
                error=ToolCallError(
                    message="test",
                    developer_message="test",
                    additional_prompt_content="test",
                    retry_after_ms=1000,
                    can_retry=True,
                )
            ),
            id="retryable_error_tool",
        ),
        # Non-retryable SDK error.
        pytest.param(
            exec_error_tool,
            {},
            ToolCallOutput(
                error=ToolCallError(
                    message="test",
                    developer_message="test",
                )
            ),
            id="exec_error_tool",
        ),
        # Arbitrary exception raised from inside the tool body.
        pytest.param(
            unexpected_error_tool,
            {},
            ToolCallOutput(
                error=ToolCallError(
                    message="Error in execution of UnexpectedErrorTool",
                    developer_message="Error in unexpected_error_tool: test",
                )
            ),
            id="unexpected_error_tool",
        ),
        # The tool takes a string, not a dict, so input handling must fail.
        pytest.param(
            simple_tool,
            {"inp": {"test": "test"}},
            ToolCallOutput(
                error=ToolCallError(
                    message="Error in tool input deserialization",
                    # Can't guarantee the developer message will be the same,
                    # so it is left unset and check_output skips it.
                    developer_message=None,
                )
            ),
            id="invalid_input_type",
        ),
        # The tool returns a dict where a string output is declared.
        pytest.param(
            bad_output_error_tool,
            {},
            ToolCallOutput(
                error=ToolCallError(
                    message="Failed to serialize tool output",
                    # Can't guarantee the developer message will be the same,
                    # so it is left unset and check_output skips it.
                    developer_message=None,
                )
            ),
            id="bad_output_type",
        ),
    ],
)
async def test_tool_executor(tool_func, inputs, expected_output):
    """Run one fixture tool through ToolExecutor.run and verify the result.

    Args:
        tool_func: The decorated fixture tool to execute.
        inputs: Keyword arguments forwarded to the tool via the executor.
        expected_output: The ToolCallOutput the run is compared against.
    """
    context = ToolContext()
    definition = catalog.find_tool_by_func(tool_func)
    # The catalog entry carries the input/output models the executor needs.
    materialized = catalog.get_tool(definition.get_fully_qualified_name())

    result = await ToolExecutor.run(
        func=tool_func,
        definition=definition,
        input_model=materialized.input_model,
        output_model=materialized.output_model,
        context=context,
        **inputs,
    )

    check_output(result, expected_output)
def check_output(output: ToolCallOutput, expected_output: ToolCallOutput):
    """Assert that an executor result matches the expected result.

    When ``output`` carries an error, the error fields are compared with those
    of ``expected_output.error``. ``developer_message`` and ``traceback_info``
    are only compared when the expectation sets them (truthy), so a test case
    can leave them unset when their exact text is not stable. When ``output``
    carries no error, the returned values are compared instead.

    Args:
        output: The result produced by the executor.
        expected_output: The result the test case expects.

    Raises:
        AssertionError: If the two results disagree, including when one side
            has an error and the other does not.
    """
    actual_error = output.error
    expected_error = expected_output.error

    if actual_error:
        # execution error in tool
        # Fail with a readable message instead of an AttributeError on
        # ``None.message`` when the tool errored but no error was expected.
        assert expected_error is not None, (
            f"tool returned an unexpected error: {actual_error!r}"
        )
        assert actual_error.message == expected_error.message
        if expected_error.developer_message:
            assert actual_error.developer_message == expected_error.developer_message
        if expected_error.traceback_info:
            assert actual_error.traceback_info == expected_error.traceback_info
        assert actual_error.can_retry == expected_error.can_retry
        assert (
            actual_error.additional_prompt_content
            == expected_error.additional_prompt_content
        )
        assert actual_error.retry_after_ms == expected_error.retry_after_ms
    else:
        # normal tool execution
        # An expected error that never happened must not slip through just
        # because both ``value`` fields happen to be equal (e.g. both None).
        assert expected_error is None, (
            f"expected an error but the tool returned a value: {output.value!r}"
        )
        assert output.value == expected_output.value