Add `GET /v1/tools/list` (#100)

Add the ability to retrieve the list of available tool definitions
that can be called.

This is essential for working with frameworks like LangChain/LangGraph.
This commit is contained in:
Sam Partee 2024-10-09 21:02:23 -07:00 committed by GitHub
parent 20170d04e0
commit 6b716d6dde
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 1159 additions and 586 deletions

View file

@ -28,15 +28,22 @@ check: ## Run code quality tools.
.PHONY: test
test: ## Test the code with pytest
@echo "🚀 Testing code: Running pytest"
@cd arcade && poetry run pytest -v --cov --cov-config=pyproject.toml --cov-report=xml
@cd arcade && poetry run pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml
.PHONY: test-toolkits
test-toolkits: ## Iterate over all toolkits and run pytest on each one
@echo "🚀 Testing code in toolkits: Running pytest"
@for dir in toolkits/*/ ; do \
(cd $$dir && poetry run pytest -v --cov --cov-config=pyproject.toml --cov-report=xml || exit 1); \
(cd $$dir && poetry run pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml || exit 1); \
done
.PHONY: coverage
coverage: ## Generate coverage report
@echo "coverage report"
@cd arcade && coverage report
@echo "Generating coverage report"
@cd arcade && coverage html
.PHONY: set-version
set-version: ## Set the version in the pyproject.toml file
@echo "🚀 Setting version in pyproject.toml"

View file

@ -0,0 +1,263 @@
from typing import TYPE_CHECKING, Any
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from arcade.core.config_model import Config
from arcade.core.schema import ToolDefinition
if TYPE_CHECKING:
from arcade.sdk.eval.eval import EvaluationResult
console = Console()
def display_tools_table(tools: list[ToolDefinition]) -> None:
    """
    Print a table listing each tool's name, short description, package, and version.

    Rows are ordered by toolkit name. Only the first line of a tool's
    description is shown; a missing description is rendered as an empty cell.
    """
    table = Table(show_header=True, header_style="bold magenta")
    for heading in ("Name", "Description", "Package", "Version"):
        table.add_column(heading)

    ordered = list(tools)
    ordered.sort(key=lambda definition: definition.toolkit.name)

    for definition in ordered:
        qualified_name = str(definition.get_fully_qualified_name())
        if definition.description:
            summary = definition.description.split("\n")[0]
        else:
            summary = ""
        table.add_row(
            qualified_name,
            summary,
            definition.toolkit.name,
            definition.toolkit.version,
        )

    console.print(table)
def display_tool_details(tool: ToolDefinition) -> None:
    """
    Print three panels describing a single tool: description, inputs, and output.

    All panels are built first and then printed in order, so nothing is
    printed if building any of them fails.
    """
    # --- Description -----------------------------------------------------
    description_panel = Panel(
        tool.description or "No description available.",
        title=f"Tool: {tool.name}",
        border_style="cyan",
    )

    # --- Inputs ----------------------------------------------------------
    parameters = tool.inputs.parameters
    if not parameters:
        inputs_body = "No input parameters."
    else:
        inputs_body = Table(show_header=True, header_style="bold green")
        column_styles = (
            ("Name", "cyan"),
            ("Type", "magenta"),
            ("Required", "yellow"),
            ("Description", "white"),
            ("Default", "blue"),
        )
        for heading, colour in column_styles:
            inputs_body.add_column(heading, style=colour)

        for param in parameters:
            # There is no default value to read from the parameter here, so the
            # "Default" column shows the allowed enum values or "N/A".
            allowed_values = param.value_schema.enum
            default_text = f"One of {allowed_values}" if allowed_values else "N/A"
            inputs_body.add_row(
                param.name,
                param.value_schema.val_type,
                str(param.required),
                param.description or "",
                default_text,
            )
    inputs_panel = Panel(inputs_body, title="Input Parameters", border_style="green")

    # --- Output ----------------------------------------------------------
    output = tool.output
    if not output:
        output_body = "No output information available."
    else:
        description_text = output.description or "No description available."
        modes_text = ", ".join(output.available_modes)
        if output.value_schema:
            value_type_text = output.value_schema.val_type
        else:
            value_type_text = "N/A"
        output_body = Text.assemble(
            ("Description: ", "bold"),
            (description_text, ""),
            "\n",
            ("Available Modes: ", "bold"),
            (modes_text, ""),
            "\n",
            ("Value Type: ", "bold"),
            (value_type_text, ""),
        )
    output_panel = Panel(output_body, title="Expected Output", border_style="blue")

    # Print the panels one after another, top to bottom.
    for panel in (description_panel, inputs_panel, output_panel):
        console.print(panel)
def display_tool_messages(tool_messages: list[dict]) -> None:
    """
    Print a dimmed trace of tool calls and tool results found in chat messages.

    Args:
        tool_messages: Chat messages as dictionaries. Messages with role
            "assistant" are scanned for "tool_calls"; messages with role
            "tool" are printed as tool results. Other roles are ignored.
    """
    # Local import: keeps this function self-contained with respect to the
    # module's import block.
    from rich.markup import escape

    for message in tool_messages:
        role = message["role"]
        if role == "assistant":
            # "tool_calls" may be missing *or* explicitly None (presumably when
            # a response object is dumped to a dict -- TODO confirm);
            # `.get(key, [])` only covers the missing case.
            for tool_call in message.get("tool_calls") or []:
                function = tool_call["function"]
                # Escape dynamic text so bracketed content (e.g. "[/INST]")
                # is printed literally instead of being parsed as Rich markup.
                tool_name = escape(str(function["name"]))
                arguments = escape(str(function["arguments"]))
                console.print(
                    f"[bright_black][bold]Called tool '{tool_name}'[/bold]\n[bold]Parameters:[/bold]{arguments}[/bright_black]"
                )
        elif role == "tool":
            tool_name = escape(str(message["name"]))
            content = escape(str(message["content"]))
            console.print(
                f"[bright_black][bold]'{tool_name}' tool returned:[/bold]{content}[/bright_black]"
            )
def display_eval_results(results: list[list[dict[str, Any]]], show_details: bool = False) -> None:
"""
Display evaluation results in a format inspired by pytest's output.
Prints a header per model, one status line per case, and a final summary
line with the total, passed, warned, and failed counts.
Args:
results: List of evaluation suites; each suite is a list of per-model
result dictionaries read via the keys "model", "rubric", and "cases".
show_details: Whether to also print the rubric and per-case details.
"""
# Running totals across every suite and model; reported in the summary line.
total_passed = 0
total_failed = 0
total_warned = 0
total_cases = 0
for eval_suite in results:
for model_results in eval_suite:
# Missing keys fall back to placeholder labels / an empty case list.
model = model_results.get("model", "Unknown Model")
rubric = model_results.get("rubric", "Unknown Rubric")
cases = model_results.get("cases", [])
total_cases += len(cases)
console.print(f"[bold]Model:[/bold] [bold magenta]{model}[/bold magenta]")
if show_details:
console.print(f"[bold magenta]{rubric}[/bold magenta]")
for case in cases:
evaluation = case["evaluation"]
# Status label precedence: PASSED, then WARNED, otherwise FAILED.
status = (
"[green]PASSED[/green]"
if evaluation.passed
else "[yellow]WARNED[/yellow]"
if evaluation.warning
else "[red]FAILED[/red]"
)
# Counters use the same precedence as the status label above.
if evaluation.passed:
total_passed += 1
elif evaluation.warning:
total_warned += 1
else:
total_failed += 1
# Display one-line summary for each case with score as a percentage
# (the score is multiplied by 100, so it is presumably a 0-1 fraction -- TODO confirm).
score_percentage = evaluation.score * 100
console.print(f"{status} {case['name']} -- Score: {score_percentage:.2f}%")
if show_details:
# Show detailed information for each case
console.print(f"[bold]User Input:[/bold] {case['input']}\n")
console.print("[bold]Details:[/bold]")
console.print(_format_evaluation(evaluation))
# 80-character horizontal rule used as a visual separator.
console.print("-" * 80)
# Summary: total and passed are always shown; the warned and failed counts
# are appended only when they are non-zero.
summary = (
f"[bold]Summary -- [/bold]Total: {total_cases} -- [green]Passed: {total_passed}[/green]"
)
if total_warned > 0:
summary += f" -- [yellow]Warnings: {total_warned}[/yellow]"
if total_failed > 0:
summary += f" -- [red]Failed: {total_failed}[/red]"
console.print(summary + "\n")
def _format_evaluation(evaluation: "EvaluationResult") -> str:
"""
Format evaluation results with color-coded matches and scores.
Args:
evaluation: An EvaluationResult object containing the evaluation results.
Returns:
A formatted string representation of the evaluation details.
"""
result_lines = []
if evaluation.failure_reason:
result_lines.append(f"[bold red]Failure Reason:[/bold red] {evaluation.failure_reason}")
else:
for critic_result in evaluation.results:
match_color = "green" if critic_result["match"] else "red"
field = critic_result["field"]
score = critic_result["score"]
weight = critic_result["weight"]
expected = critic_result["expected"]
actual = critic_result["actual"]
result_lines.append(
f"[bold]{field}:[/bold] "
f"[{match_color}]Match: {critic_result['match']}"
f"\n Score: {score:.2f}/{weight:.2f}[/{match_color}]"
f"\n Expected: {expected}"
f"\n Actual: {actual}"
)
return "\n".join(result_lines)
def display_arcade_chat_header(config: Config, stream: bool) -> None:
    """
    Print the chat banner showing which Arcade Engine URL is being used.

    Args:
        config: Configuration whose `engine_url` is shown in the banner.
        stream: When true, " (streaming)" is appended to the banner.
    """
    title_segment = ("=== Arcade AI Chat ===", "bold magenta underline")
    url_segment = (config.engine_url, "bold blue")
    segments = [
        "\n",
        title_segment,
        "\n",
        "\n",
        "Chatting with Arcade Engine at ",
        url_segment,
    ]
    banner = Text.assemble(*segments)
    if stream:
        banner.append(" (streaming)")
    console.print(banner)
def display_config_as_table(config) -> None:  # type: ignore[no-untyped-def]
    """
    Display the configuration details as a table using the Rich library.

    Each top-level field of `config` is treated as a section. Falsy sections
    are skipped. For every other section, one row is printed per setting, with
    the section name shown only on its first row, followed by a blank spacer
    row.

    Args:
        config: A Pydantic model whose truthy top-level fields are themselves
            Pydantic models (they must provide `model_dump()`).
    """
    table = Table(show_header=True, header_style="bold magenta")
    table.add_column("Section")
    table.add_column("Name")
    table.add_column("Value")

    for section_name in config.model_dump():
        section = getattr(config, section_name)
        if not section:
            continue
        # `model_dump()` is the Pydantic v2 replacement for the deprecated
        # `.dict()`, and matches the call already used on `config` above.
        settings = section.model_dump()
        for index, (name, value) in enumerate(settings.items()):
            # Only the first row of a section carries the section label.
            label = section_name if index == 0 else ""
            table.add_row(label, name, str(value))
        # Blank row to visually separate sections.
        table.add_row("", "", "")

    console.print(table)

View file

@ -11,27 +11,32 @@ import typer
from openai import OpenAIError
from rich.console import Console
from rich.markup import escape
from rich.table import Table
from rich.text import Text
from arcade.cli.authn import LocalAuthCallbackServer, check_existing_login
from arcade.cli.display import (
display_arcade_chat_header,
display_config_as_table,
display_eval_results,
display_tool_details,
display_tool_messages,
display_tools_table,
)
from arcade.cli.launcher import start_servers
from arcade.cli.utils import (
OrderCommands,
apply_config_overrides,
create_cli_catalog,
display_eval_results,
display_tool_messages,
get_config_with_overrides,
get_eval_files,
get_tools_from_engine,
handle_chat_interaction,
is_authorization_pending,
load_eval_suites, # Import the new function
load_eval_suites,
log_engine_health,
validate_and_get_config,
wait_for_authorization_completion,
)
from arcade.client import Arcade
from arcade.client.errors import EngineNotHealthyError, EngineOfflineError
from arcade.core.config_model import Config
cli = typer.Typer(
cls=OrderCommands,
@ -44,27 +49,6 @@ cli = typer.Typer(
console = Console()
def _get_config_with_overrides(
force_tls: bool,
force_no_tls: bool,
host_input: str | None = None,
port_input: int | None = None,
) -> Config:
"""
Get the config with CLI-specific optional overrides applied.
"""
config = validate_and_get_config()
if not force_tls and not force_no_tls:
tls_input = None
elif force_no_tls:
tls_input = False
else:
tls_input = True
apply_config_overrides(config, host_input, port_input, tls_input)
return config
@cli.command(help="Log in to Arcade Cloud", rich_help_panel="User")
def login(
host: str = typer.Option(
@ -136,44 +120,75 @@ def new(
@cli.command(
help="Show the installed toolkits",
help="Show the installed toolkits or details of a specific tool",
rich_help_panel="Tool Development",
)
def show(
toolkit: Optional[str] = typer.Option(
None, "-t", "--toolkit", help="The toolkit to show the tools of"
),
tool: Optional[str] = typer.Option(
None, "-T", "--tool", help="The specific tool to show details for"
),
host: Optional[str] = typer.Option(
None,
"-h",
"--host",
help="The Arcade Engine address to send chat requests to.",
),
port: Optional[int] = typer.Option(
None,
"-p",
"--port",
help="The port of the Arcade Engine.",
),
force_tls: bool = typer.Option(
False,
"--tls",
help="Whether to force TLS for the connection to the Arcade Engine. If not specified, the connection will use TLS if the engine URL uses a 'https' scheme.",
),
force_no_tls: bool = typer.Option(
False,
"--no-tls",
help="Whether to disable TLS for the connection to the Arcade Engine.",
),
debug: bool = typer.Option(False, "--debug", "-d", help="Show debug information"),
) -> None:
"""
Show the available tools in an actor or toolkit
Show the available toolkits or detailed information about a specific tool.
"""
try:
catalog = create_cli_catalog(toolkit=toolkit)
if not host:
catalog = create_cli_catalog(toolkit=toolkit)
tools = [t.definition for t in list(catalog)]
else:
tools = get_tools_from_engine(host, port, force_tls, force_no_tls, toolkit)
# Create a table with Rich library
table = Table(show_header=True, header_style="bold magenta")
table.add_column("Name")
table.add_column("Description")
table.add_column("Package")
table.add_column("Version")
if tool:
# Display detailed information for the specified tool
tool_def = next(
(
t
for t in tools
if t.get_fully_qualified_name().name == tool
or str(t.get_fully_qualified_name()) == tool
),
None,
)
if not tool_def:
console.print(f"❌ Tool '{tool}' not found.", style="bold red")
typer.Exit(code=1)
else:
display_tool_details(tool_def)
else:
# Display the list of tools as a table
display_tools_table(tools)
tool_names = catalog.get_tool_names()
for tool_name in tool_names:
tool = catalog.get_tool(tool_name)
package = tool.meta.package if tool.meta.package else tool.meta.toolkit
table.add_row(str(tool_name), tool.description, package, tool.version)
console.print(table)
# used when debugging a broken package on import.
# `arcade show` is the first command used after
# a toolkit package is created.
except Exception as e:
if debug:
raise
error_message = f"❌ Failed to List tools: {escape(str(e))}"
error_message = f"❌ Failed to list tools: {escape(str(e))}"
console.print(error_message, style="bold red")
@ -211,7 +226,7 @@ def chat(
"""
Chat with a language model.
"""
config = _get_config_with_overrides(force_tls, force_no_tls, host, port)
config = get_config_with_overrides(force_tls, force_no_tls, host, port)
client = Arcade(api_key=config.api.key, base_url=config.engine_url)
user_email = config.user.email if config.user else None
@ -321,73 +336,6 @@ def config(
raise typer.Exit(code=1)
def display_arcade_chat_header(config: Config, stream: bool) -> None:
chat_header = Text.assemble(
"\n",
(
"=== Arcade AI Chat ===",
"bold magenta underline",
),
"\n",
"\n",
"Chatting with Arcade Engine at ",
(
config.engine_url,
"bold blue",
),
)
if stream:
chat_header.append(" (streaming)")
console.print(chat_header)
def log_engine_health(client: Arcade) -> None:
try:
client.health.check()
except EngineNotHealthyError as e:
console.print(
"[bold][yellow]⚠️ Warning: "
+ str(e)
+ " ("
+ "[/yellow]"
+ "[red]"
+ str(e.status_code)
+ "[/red]"
+ "[yellow])[/yellow][/bold]"
)
except EngineOfflineError:
console.print(
"⚠️ Warning: Arcade Engine was unreachable. (Is it running?)",
style="bold yellow",
)
def display_config_as_table(config) -> None: # type: ignore[no-untyped-def]
"""
Display the configuration details as a table using Rich library.
"""
table = Table(show_header=True, header_style="bold magenta")
table.add_column("Section")
table.add_column("Name")
table.add_column("Value")
for section_name in config.model_dump():
section = getattr(config, section_name)
if section:
section = section.dict()
first = True
for name, value in section.items():
if first:
table.add_row(section_name, name, str(value))
first = False
else:
table.add_row("", name, str(value))
table.add_row("", "", "")
console.print(table)
@cli.command(help="Run tool calling evaluations", rich_help_panel="Tool Development")
def evals(
directory: str = typer.Argument(".", help="Directory containing evaluation files"),
@ -428,7 +376,7 @@ def evals(
Find all files starting with 'eval_' in the given directory,
execute any functions decorated with @tool_eval, and display the results.
"""
config = _get_config_with_overrides(force_tls, force_no_tls, host, port)
config = get_config_with_overrides(force_tls, force_no_tls, host, port)
models_list = models.split(",") # Use 'models_list' to avoid shadowing

View file

@ -1,7 +1,7 @@
import importlib.util
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Union
from typing import Callable, Union
import typer
from openai.resources.chat.completions import ChatCompletionChunk, Stream
@ -13,16 +13,14 @@ from typer.core import TyperGroup
from typer.models import Context
from arcade.client.client import Arcade
from arcade.client.errors import APITimeoutError
from arcade.client.errors import APITimeoutError, EngineNotHealthyError, EngineOfflineError
from arcade.client.schema import AuthResponse
from arcade.core.catalog import ToolCatalog
from arcade.core.config_model import Config
from arcade.core.errors import ToolkitLoadError
from arcade.core.schema import ToolDefinition
from arcade.core.toolkit import Toolkit
if TYPE_CHECKING:
from arcade.sdk.eval.eval import EvaluationResult
console = Console()
@ -64,17 +62,37 @@ def create_cli_catalog(
return catalog
def display_tool_messages(tool_messages: list[dict]) -> None:
for message in tool_messages:
if message["role"] == "assistant":
for tool_call in message.get("tool_calls", []):
console.print(
f"[bright_black][bold]Called tool '{tool_call['function']['name']}'[/bold]\n[bold]Parameters:[/bold]{tool_call['function']['arguments']}[/bright_black]"
)
elif message["role"] == "tool":
console.print(
f"[bright_black][bold]'{message['name']}' tool returned:[/bold]{message['content']}[/bright_black]"
)
def get_config_with_overrides(
    force_tls: bool,
    force_no_tls: bool,
    host_input: str | None = None,
    port_input: int | None = None,
) -> Config:
    """
    Load the validated config and apply the CLI's optional overrides to it.

    The two TLS flags collapse into a single tri-state value: False when
    `force_no_tls` is set (it wins if both flags are set), True when only
    `force_tls` is set, and None when neither is set.
    """
    config = validate_and_get_config()

    tls_input: bool | None
    if force_no_tls:
        tls_input = False
    elif force_tls:
        tls_input = True
    else:
        tls_input = None

    apply_config_overrides(config, host_input, port_input, tls_input)
    return config
def get_tools_from_engine(
    host: str,
    port: int | None = None,
    force_tls: bool = False,
    force_no_tls: bool = False,
    toolkit: str | None = None,
) -> list[ToolDefinition]:
    """
    Fetch tool definitions from an Arcade Engine via the client's tools API.

    The connection settings come from the validated config with the given
    host/port/TLS overrides applied. `toolkit` is forwarded unchanged to the
    client's `list_tools` call.
    """
    engine_config = get_config_with_overrides(
        force_tls,
        force_no_tls,
        host_input=host,
        port_input=port,
    )
    engine_client = Arcade(
        api_key=engine_config.api.key,
        base_url=engine_config.engine_url,
    )
    tool_definitions = engine_client.tools.list_tools(toolkit=toolkit)
    return tool_definitions
def get_tool_messages(choice: dict) -> list[dict]:
@ -199,96 +217,26 @@ def apply_config_overrides(
config.engine.tls = tls_input
def display_eval_results(results: list[list[dict[str, Any]]], show_details: bool = False) -> None:
"""
Display evaluation results in a format inspired by pytest's output.
def log_engine_health(client: Arcade) -> None:
try:
client.health.check()
Args:
results: List of dictionaries containing evaluation results for each model.
show_details: Whether to show detailed results for each case.
"""
total_passed = 0
total_failed = 0
total_warned = 0
total_cases = 0
for eval_suite in results:
for model_results in eval_suite:
model = model_results.get("model", "Unknown Model")
rubric = model_results.get("rubric", "Unknown Rubric")
cases = model_results.get("cases", [])
total_cases += len(cases)
console.print(f"[bold]Model:[/bold] [bold magenta]{model}[/bold magenta]")
if show_details:
console.print(f"[bold magenta]{rubric}[/bold magenta]")
for case in cases:
evaluation = case["evaluation"]
status = (
"[green]PASSED[/green]"
if evaluation.passed
else "[yellow]WARNED[/yellow]"
if evaluation.warning
else "[red]FAILED[/red]"
)
if evaluation.passed:
total_passed += 1
elif evaluation.warning:
total_warned += 1
else:
total_failed += 1
# Display one-line summary for each case
console.print(f"{status} {case['name']} -- Score: {evaluation.score:.2f}")
if show_details:
# Show detailed information for each case
console.print(f"[bold]User Input:[/bold] {case['input']}\n")
console.print("[bold]Details:[/bold]")
console.print(_format_evaluation(evaluation))
console.print("-" * 80)
# Summary
summary = (
f"[bold]Summary -- [/bold]Total: {total_cases} -- [green]Passed: {total_passed}[/green]"
)
if total_warned > 0:
summary += f" -- [yellow]Warnings: {total_warned}[/yellow]"
if total_failed > 0:
summary += f" -- [red]Failed: {total_failed}[/red]"
console.print(summary + "\n")
def _format_evaluation(evaluation: "EvaluationResult") -> str:
"""
Format evaluation results with color-coded matches and scores.
Args:
evaluation: An EvaluationResult object containing the evaluation results.
Returns:
A formatted string representation of the evaluation details.
"""
result_lines = []
if evaluation.failure_reason:
result_lines.append(f"[bold red]Failure Reason:[/bold red] {evaluation.failure_reason}")
else:
for critic_result in evaluation.results:
match_color = "green" if critic_result["match"] else "red"
field = critic_result["field"]
score = critic_result["score"]
weight = critic_result["weight"]
expected = critic_result["expected"]
actual = critic_result["actual"]
result_lines.append(
f"[bold]{field}:[/bold] "
f"[{match_color}]Match: {critic_result['match']}, "
f"Score: {score:.2f}/{weight:.2f}[/{match_color}]"
f"\n Expected: {expected}"
f"\n Actual: {actual}"
)
return "\n".join(result_lines)
except EngineNotHealthyError as e:
console.print(
"[bold][yellow]⚠️ Warning: "
+ str(e)
+ " ("
+ "[/yellow]"
+ "[red]"
+ str(e.status_code)
+ "[/red]"
+ "[yellow])[/yellow][/bold]"
)
except EngineOfflineError:
console.print(
"⚠️ Warning: Arcade Engine was unreachable. (Is it running?)",
style="bold yellow",
)
@dataclass

View file

@ -166,6 +166,17 @@ class ToolResource(BaseResource[ClientT]):
)
return AuthResponse(**data)
def list_tools(self, toolkit: str | None = None) -> list[ToolDefinition]:
    """
    List the available tool definitions.

    Sends a GET request to `<resource path>/list` with `toolkit` as a query
    parameter and builds one ToolDefinition per item in the response.
    """
    endpoint = f"{self._resource_path}/list"
    query = {"toolkit": toolkit}
    payload = self._client._execute_request(  # type: ignore[attr-defined]
        "GET",
        endpoint,
        params=query,
    )
    definitions = []
    for entry in payload:
        definitions.append(ToolDefinition(**entry))
    return definitions
class HealthResource(BaseResource[ClientT]):
"""Health check resource."""
@ -331,6 +342,17 @@ class AsyncToolResource(BaseResource[AsyncArcadeClient]):
)
return AuthResponse(**data)
async def list_tools(self, toolkit: str | None = None) -> list[ToolDefinition]:
    """
    List the available tool definitions (asynchronous variant).

    Awaits a GET request to `<resource path>/list` with `toolkit` as a query
    parameter and builds one ToolDefinition per item in the response.
    """
    endpoint = f"{self._resource_path}/list"
    query = {"toolkit": toolkit}
    payload = await self._client._execute_request(  # type: ignore[attr-defined]
        "GET",
        endpoint,
        params=query,
    )
    definitions = []
    for entry in payload:
        definitions.append(ToolDefinition(**entry))
    return definitions
class AsyncHealthResource(BaseResource[AsyncArcadeClient]):
"""Asynchronous Health check resource."""

View file

@ -25,6 +25,7 @@ from pydantic_core import PydanticUndefined
from arcade.core.errors import ToolDefinitionError
from arcade.core.schema import (
TOOL_NAME_SEPARATOR,
FullyQualifiedName,
InputParameter,
OAuth2Requirement,
@ -215,6 +216,45 @@ class ToolCatalog(BaseModel):
return tool.definition
raise ValueError(f"Tool {func} not found in the catalog.")
def get_tool_by_name(
    self, name: str, version: Optional[str] = None, separator: str = TOOL_NAME_SEPARATOR
) -> MaterializedTool:
    """
    Look up a tool by name, with or without a toolkit prefix.

    Args:
        name: The tool name. If it contains `separator`, the part before the
            first separator is taken as the toolkit name and the rest as the
            tool name.
        version: The toolkit version to match. Defaults to None (any version
            when the name is unqualified).
        separator: The separator between toolkit and tool names. Defaults to
            `TOOL_NAME_SEPARATOR`.

    Returns:
        MaterializedTool: The matching tool. For an unqualified name, the
        first case-insensitive match in catalog order is returned.

    Raises:
        ValueError: If no tool matches an unqualified name. Qualified lookups
            are delegated to `get_tool`, which decides how misses are reported.
    """
    # Qualified name: build the fully-qualified key and delegate.
    if separator in name:
        toolkit_name, tool_name = name.split(separator, 1)
        qualified = FullyQualifiedName(
            name=tool_name, toolkit_name=toolkit_name, toolkit_version=version
        )
        return self.get_tool(qualified)

    # Unqualified name: scan the catalog for the first tool whose name (and,
    # if requested, toolkit version) matches case-insensitively.
    for candidate_name, candidate in self._tools.items():
        if candidate_name.name.lower() != name.lower():
            continue
        if version is not None:
            candidate_version = (candidate_name.toolkit_version or "").lower()
            if candidate_version != (version or "").lower():
                continue
        return candidate

    raise ValueError(f"Tool {name} not found in the catalog.")
def get_tool(self, name: FullyQualifiedName) -> MaterializedTool:
"""
Get a tool from the catalog by fully-qualified name and version.

View file

@ -1,10 +1,11 @@
from .critic import BinaryCritic, NumericCritic, SimilarityCritic
from .critic import BinaryCritic, DatetimeCritic, NumericCritic, SimilarityCritic
from .eval import EvalRubric, EvalSuite, ExpectedToolCall, tool_eval
__all__ = [
"BinaryCritic",
"SimilarityCritic",
"NumericCritic",
"DatetimeCritic",
"EvalRubric",
"EvalSuite",
"ExpectedToolCall",

View file

@ -1,7 +1,11 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, ClassVar
import pytz
from dateutil import parser
from arcade.sdk.error import WeightError
@ -47,6 +51,17 @@ class BinaryCritic(Critic):
Raises:
TypeError: If the casting is not possible.
"""
# In case both are strings.
if actual == "None":
actual = None
if expected == "None":
expected = None
if expected is None:
# No need to cast; return actual as is
return actual
if actual is None:
# No need to cast; return None
return None
expected_type = type(expected)
try:
return expected_type(actual)
@ -60,14 +75,18 @@ class BinaryCritic(Critic):
Evaluates whether the expected and actual values are exactly equal after casting.
Args:
expected (Any): The expected value.
actual (Any): The actual value to compare, cast to the type of expected.
expected: The expected value.
actual: The actual value to compare, cast to the type of expected.
Returns:
dict[str, float | bool]: A dictionary containing the match status and score.
dict: A dictionary containing the match status and score.
"""
# Cast actual to the type of expected
actual_casted = self.cast_actual(expected, actual)
try:
actual_casted = self.cast_actual(expected, actual)
# TODO log or something better here
except TypeError:
actual_casted = actual
match = expected == actual_casted
return {"match": match, "score": self.weight if match else 0.0}
@ -187,3 +206,68 @@ class SimilarityCritic(Critic):
"match": similarity >= self.similarity_threshold,
"score": min(similarity * self.weight, self.weight),
}
@dataclass
class DatetimeCritic(Critic):
    """
    A critic that evaluates the closeness of datetime values within a specified tolerance.

    Attributes:
        critic_field: Name of the field this critic scores.
        weight: Score awarded for a full match; partial scores scale down from it.
        tolerance: Acceptable timedelta between expected and actual datetimes.
        max_difference: Maximum timedelta for a partial score.
    """

    critic_field: str
    weight: float
    tolerance: timedelta = timedelta(seconds=500)
    max_difference: timedelta = timedelta(hours=2)

    def evaluate(self, expected: str, actual: str) -> dict[str, float | bool]:
        """
        Score how close `actual` is to `expected` in time.

        Args:
            expected: The expected datetime, as a string parseable by dateutil.
            actual: The actual datetime, as a string parseable by dateutil.

        Returns:
            A dict with "match" and "score":
            - difference <= tolerance: match True, full weight;
            - difference >= max_difference: match False, score 0.0;
            - otherwise: match False, weight scaled linearly by
              `1 - difference / max_difference`.
            Unparseable input yields match False and score 0.0.
        """
        try:
            expected_dt = parser.parse(expected)
            actual_dt = parser.parse(actual)
        except (ValueError, TypeError, OverflowError):
            # dateutil raises OverflowError for out-of-range dates in addition
            # to ValueError/TypeError; all mean "not a usable datetime".
            return {"match": False, "score": 0.0}

        expected_is_aware = expected_dt.tzinfo is not None
        actual_is_aware = actual_dt.tzinfo is not None

        if expected_is_aware and actual_is_aware:
            # Both carry a timezone: compare the same instant in UTC.
            expected_dt = expected_dt.astimezone(pytz.utc)
            actual_dt = actual_dt.astimezone(pytz.utc)
        elif expected_is_aware != actual_is_aware:
            # Only one carries a timezone: drop it and compare wall-clock
            # values, since aware and naive datetimes cannot be subtracted.
            expected_dt = expected_dt.replace(tzinfo=None)
            actual_dt = actual_dt.replace(tzinfo=None)
        # Otherwise both are naive and can be compared directly.

        time_diff_seconds = abs((expected_dt - actual_dt).total_seconds())
        tolerance_seconds = self.tolerance.total_seconds()
        max_difference_seconds = self.max_difference.total_seconds()

        if time_diff_seconds <= tolerance_seconds:
            # Full score if within tolerance
            return {"match": True, "score": self.weight}
        if time_diff_seconds >= max_difference_seconds:
            # No score if beyond max_difference. This check also guarantees
            # the division below never sees a zero denominator.
            return {"match": False, "score": 0.0}

        # Partial score based on time difference, clamped at zero.
        ratio = max(1 - (time_diff_seconds / max_difference_seconds), 0)
        return {"match": False, "score": self.weight * ratio}

View file

@ -1,11 +1,12 @@
import asyncio
import functools
import inspect
import json
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable
from arcade.core.config_model import Config
from arcade.core.schema import FullyQualifiedName
from arcade.core.schema import TOOL_NAME_SEPARATOR
try:
import numpy as np
@ -218,7 +219,10 @@ class EvalCase:
expected_count = len(self.expected_tool_calls)
return self.rubric.fail_on_tool_call_quantity and expected_count != actual_count
def evaluate(self, actual_tool_calls: list[tuple[str, dict[str, Any]]]) -> EvaluationResult:
def evaluate(
self,
actual_tool_calls: list[tuple[str, dict[str, Any]]],
) -> EvaluationResult:
"""
Evaluate the actual tool calls against the expected tool calls and critics.
@ -229,13 +233,13 @@ class EvalCase:
An EvaluationResult object containing the evaluation results.
"""
evaluation_result = EvaluationResult()
actual_tools = [tool for tool, _ in actual_tool_calls]
actual_tools = [tool_name for tool_name, _ in actual_tool_calls]
actual_count = len(actual_tool_calls)
if self.check_tool_call_quantity_failure(actual_count):
evaluation_result.score = 0.0
evaluation_result.passed = False
evaluation_result.warning = False
expected_count = len(self.expected_tool_calls)
expected_tool_names = ", ".join(
tool_call.name for tool_call in self.expected_tool_calls
@ -246,35 +250,27 @@ class EvalCase:
)
return evaluation_result
# check if no tools should be called and none were called
if not self.expected_tool_calls and not actual_tools:
evaluation_result.score = 1.0
evaluation_result.passed = True
evaluation_result.warning = False
return evaluation_result
if self.check_tool_selection_failure(actual_tools):
evaluation_result.score = 0.0
evaluation_result.passed = False
evaluation_result.warning = False
expected_tools = [tc.name for tc in self.expected_tool_calls]
evaluation_result.failure_reason = f"Tool selection mismatch. Expected tools: {expected_tools}, but got: {actual_tools}"
return evaluation_result
# if no critics for tool call arguments, then return
# passing score as only tool selection and quantity is checked
if not self.critics or len(self.critics) == 0:
if not self.critics:
evaluation_result.score = 1.0
evaluation_result.passed = True
evaluation_result.warning = False
# TODO passing reason should be added
return evaluation_result
# Create a cost matrix for the assignment problem
cost_matrix = self._create_cost_matrix(actual_tool_calls)
cost_matrix = self._create_cost_matrix(actual_tool_calls, self.expected_tool_calls)
# Use the Linear Sum Assignment (LSA) algorithm to find the optimal assignment
# The algorithm maximizes the total score of the assignment
# Use the Linear Sum Assignment algorithm to find the optimal assignment
row_ind, col_ind = linear_sum_assignment(cost_matrix, maximize=True)
total_score = 0.0
@ -283,10 +279,11 @@ class EvalCase:
for i, j in zip(row_ind, col_ind):
if i < len(self.expected_tool_calls) and j < len(actual_tool_calls):
expected = self.expected_tool_calls[i]
actual_tool, actual_args = actual_tool_calls[j]
actual_name, actual_args = actual_tool_calls[j]
# Tool selection
tool_selection_score = evaluation_result.score_tool_selection(
expected.name, actual_tool, self.rubric.tool_selection_weight
expected.name, actual_name, self.rubric.tool_selection_weight
)
total_score += tool_selection_score
total_weight += self.rubric.tool_selection_weight
@ -295,32 +292,35 @@ class EvalCase:
for critic in self.critics:
expected_value = expected.args.get(critic.critic_field)
actual_value = actual_args.get(critic.critic_field)
if expected_value is not None and actual_value is not None:
try:
result = critic.evaluate(expected_value, actual_value)
total_score += result["score"]
total_weight += critic.weight
evaluation_result.add(
critic.critic_field,
result,
critic.weight,
expected_value,
actual_value,
)
except Exception as e:
print(
f"Critic evaluation failed for field '{critic.critic_field}': {e}"
)
# Depending on requirements, you might want to continue or handle differently
continue
# Compute the final score using the method from EvaluationResult
try:
result = critic.evaluate(expected_value, actual_value)
total_score += result["score"]
total_weight += critic.weight
evaluation_result.add(
critic.critic_field,
result,
critic.weight,
expected_value,
actual_value,
)
except Exception as e:
# TODO: log or console
print(f"Critic evaluation failed for field '{critic.critic_field}': {e}")
evaluation_result.add(
critic.critic_field,
{"match": False, "score": 0.0},
critic.weight,
expected_value,
actual_value,
)
continue
# Compute the final score
evaluation_result.compute_final_score(total_weight)
# Set the pass/fail status based on the fail_threshold
# Set pass/fail and warning status
evaluation_result.passed = evaluation_result.score >= self.rubric.fail_threshold
# Set the warning status based on the warn_threshold
evaluation_result.warning = (
not evaluation_result.passed and evaluation_result.score >= self.rubric.warn_threshold
)
@ -328,103 +328,52 @@ class EvalCase:
return evaluation_result
def _create_cost_matrix(
self, actual_tool_calls: list[tuple[str, dict[str, Any]]]
self,
actual_tool_calls: list[tuple[str, dict[str, Any]]],
expected_tool_calls: list[ExpectedToolCall],
) -> np.ndarray:
"""
Create a cost matrix for the Hungarian algorithm.
This method computes the score for each possible pairing of expected and actual tool calls.
The resulting matrix is used by the Hungarian algorithm to find the optimal assignment.
Create a cost matrix for the assignment problem.
Args:
actual_tool_calls: A list of tuples containing the actual tool calls and their arguments.
actual_tool_calls: A list of tuples of actual tool calls.
expected_tool_calls: A list of ExpectedToolCall instances.
Returns:
A numpy array representing the cost matrix.
"""
num_expected = len(self.expected_tool_calls)
num_expected = len(expected_tool_calls)
num_actual = len(actual_tool_calls)
n = max(num_expected, num_actual)
# Initialize a score matrix with zeros
score_matrix = np.zeros((n, n))
cost_matrix = np.zeros((n, n))
for i in range(n):
for j in range(n):
if i < num_expected and j < num_actual:
expected = self.expected_tool_calls[i]
expected_tool = expected.name
expected_args = expected.args
actual_tool, actual_args = actual_tool_calls[j]
expected = expected_tool_calls[i]
actual_name, actual_args = actual_tool_calls[j]
score = 0.0
# Tool selection
if compare_tool_name(expected_tool, actual_tool):
if compare_tool_name(expected.name, actual_name):
score += self.rubric.tool_selection_weight
# Critics evaluation
if self.critics:
for critic in self.critics:
expected_value = expected_args.get(critic.critic_field)
actual_value = actual_args.get(critic.critic_field)
if expected_value is not None and actual_value is not None:
try:
result = critic.evaluate(expected_value, actual_value)
score += result.get("score", 0.0)
except Exception as e:
print(
f"Critic evaluation failed for field '{critic.critic_field}': {e}"
)
continue
for critic in self.critics: # type: ignore[union-attr]
expected_value = expected.args.get(critic.critic_field)
actual_value = actual_args.get(critic.critic_field)
if expected_value is not None and actual_value is not None:
try:
result = critic.evaluate(expected_value, actual_value)
score += result.get("score", 0.0)
except Exception as e:
print(
f"Critic evaluation failed for field '{critic.critic_field}': {e}"
)
cost_matrix[i, j] = score
score_matrix[i, j] = score
else:
# Assign a score of 0 for dummy assignments
score_matrix[i, j] = 0.0
return score_matrix
async def run(
self, client: AsyncArcade, model: str, tool_names: list[FullyQualifiedName]
) -> dict[str, Any]:
"""
Run the evaluation case asynchronously.
Args:
client: The AsyncArcade client instance.
model: The model to evaluate.
tool_names: The list of tool names to use for the evaluation.
Returns:
A dictionary containing the evaluation result for the case.
"""
messages = [{"role": "system", "content": self.system_message}]
messages.extend(list(self.additional_messages))
messages.append({"role": "user", "content": self.user_message})
response = await client.chat.completions.create( # type: ignore[call-overload]
model=model,
messages=messages,
tool_choice="auto",
tools=(str(name) for name in tool_names),
user="eval_user",
stream=False,
)
predicted_args = get_tool_args(response)
evaluation = self.evaluate(predicted_args)
result = {
"name": self.name,
"input": self.user_message,
"expected_tool_calls": [
{"name": tc.name, "args": tc.args} for tc in self.expected_tool_calls
],
"predicted_tool_calls": [{"name": tool, "args": args} for tool, args in predicted_args],
"evaluation": evaluation,
}
return result
return cost_matrix
@dataclass
@ -467,19 +416,19 @@ class EvalSuite:
Args:
name: The name of the evaluation case.
user_message: The user's input message.
system_message: The system message to be sent to the AI model.
expected_tool_calls: A list of expected tool calls.
expected_tool_calls: A list of expected tool calls as tuples of (function, args).
critics: List of critics to evaluate the tool arguments.
system_message: The system message to be used.
rubric: The evaluation rubric for this case.
additional_messages: Optional list of additional messages for context.
"""
expected = [
ExpectedToolCall(
name=str(self.catalog.find_tool_by_func(func).get_fully_qualified_name()),
args=args,
)
for func, args in expected_tool_calls
]
expected = []
for func, args in expected_tool_calls:
# Fill in default arguments here
args_with_defaults = self._fill_args_with_defaults(func, args)
tool_name = str(self.catalog.find_tool_by_func(func).get_fully_qualified_name())
expected.append(ExpectedToolCall(name=tool_name, args=args_with_defaults))
case = EvalCase(
name=name,
system_message=system_message or self.system_message,
@ -491,6 +440,30 @@ class EvalSuite:
)
self.cases.append(case)
def _fill_args_with_defaults(
self, func: Callable, provided_args: dict[str, Any]
) -> dict[str, Any]:
"""
Fill in default arguments for a tool function.
Args:
func: The tool function.
provided_args: The provided arguments.
Returns:
A dictionary with default arguments filled in.
"""
sig = inspect.signature(func)
args_with_defaults = {}
for param in sig.parameters.values():
if param.name in provided_args:
args_with_defaults[param.name] = provided_args[param.name]
elif param.default is not inspect.Parameter.empty:
args_with_defaults[param.name] = param.default
else:
args_with_defaults[param.name] = None # or raise an error
return args_with_defaults
def extend_case(
self,
name: str,
@ -528,13 +501,12 @@ class EvalSuite:
expected = last_case.expected_tool_calls
if expected_tool_calls:
expected = [
ExpectedToolCall(
name=str(self.catalog.find_tool_by_func(func).get_fully_qualified_name()),
args=args,
)
for func, args in expected_tool_calls
]
expected = []
for func, args in expected_tool_calls:
# Fill in default arguments here
args_with_defaults = self._fill_args_with_defaults(func, args)
tool_name = str(self.catalog.find_tool_by_func(func).get_fully_qualified_name())
expected.append(ExpectedToolCall(name=tool_name, args=args_with_defaults))
# Create a new case, copying from the last one and updating fields
new_case = EvalCase(
@ -550,9 +522,10 @@ class EvalSuite:
async def run(self, client: AsyncArcade, model: str) -> dict[str, Any]:
"""
Run the evaluation suite asynchronously.
Run the evaluation suite.
Args:
client: The AsyncArcade client instance.
model: The model to evaluate.
Returns:
@ -565,7 +538,48 @@ class EvalSuite:
async def sem_task(case: EvalCase) -> dict[str, Any]:
async with semaphore:
return await case.run(client, model, tool_names)
# Prepare messages
messages = [{"role": "system", "content": case.system_message}]
messages.extend(case.additional_messages)
messages.append({"role": "user", "content": case.user_message})
# Get the model response
response = await client.chat.completions.create( # type: ignore[call-overload]
model=model,
messages=messages,
tool_choice="auto",
tools=(str(name) for name in tool_names),
user="eval_user",
stream=False,
)
# Extract and fill default arguments for actual tool calls
predicted_args = get_tool_args(response)
filled_actual_tool_calls = []
for tool_name, args in predicted_args:
tool = self.catalog.get_tool_by_name(tool_name)
if tool is None:
raise ValueError(f"Tool '{tool_name}' not found in catalog.")
func = tool.tool
args_with_defaults = self._fill_args_with_defaults(func, args)
filled_actual_tool_calls.append((tool_name, args_with_defaults))
# Evaluate the case
evaluation = case.evaluate(filled_actual_tool_calls)
# Prepare the result
result = {
"name": case.name,
"input": case.user_message,
"expected_tool_calls": [
{"name": tc.name, "args": tc.args} for tc in case.expected_tool_calls
],
"predicted_tool_calls": [
{"name": name, "args": args} for name, args in filled_actual_tool_calls
],
"evaluation": evaluation,
}
return result
tasks = [sem_task(case) for case in self.cases]
case_results = await asyncio.gather(*tasks)
@ -589,7 +603,7 @@ def get_tool_args(chat_completion: Any) -> list[tuple[str, dict[str, Any]]]:
if message.tool_calls:
for tool_call in message.tool_calls:
tool_args_list.append((
tool_call.function.name,
normalize_name(tool_call.function.name),
json.loads(tool_call.function.arguments),
))
return tool_args_list
@ -597,17 +611,31 @@ def get_tool_args(chat_completion: Any) -> list[tuple[str, dict[str, Any]]]:
def compare_tool_name(expected: str, actual: str) -> bool:
"""
Compare the tool name without penalizing for mismatch in separators
between module names and tool names ex. '-' vs '_' vs '.' vs ' '
"""
# TODO optimize this
# Remove all separators from both names
separators = "-_."
expected_clean = "".join(char for char in expected if char not in separators)
actual_clean = "".join(char for char in actual if char not in separators)
Compare the tool names by replacing all separators with the TOOL_NAME_SEPARATOR
and comparing the normalized names.
# Compare the cleaned names
return expected_clean.lower() == actual_clean.lower()
Converts names like 'Google_ListEmails' to 'Google.ListEmails' if
TOOL_NAME_SEPARATOR is '.'.
Args:
expected: The expected tool name.
actual: The actual tool name.
Returns:
True if the normalized tool names match, False otherwise.
"""
separators = "-_."
expected_normalized = normalize_name(expected, separators)
actual_normalized = normalize_name(actual, separators)
return expected_normalized.lower() == actual_normalized.lower()
def normalize_name(name: str, separators: str = "-_.") -> str:
    """
    Rewrite the separator characters in ``name`` to TOOL_NAME_SEPARATOR.

    Each character of ``separators`` is handled in turn; a character equal
    to TOOL_NAME_SEPARATOR is left alone, every other one is replaced by
    TOOL_NAME_SEPARATOR wherever it occurs in ``name``.

    Args:
        name: The tool name to normalize.
        separators: The characters to treat as separators.

    Returns:
        The name with its separators replaced.
    """
    normalized = name
    for candidate in separators:
        if candidate == TOOL_NAME_SEPARATOR:
            # Already the canonical separator; nothing to rewrite.
            continue
        normalized = normalized.replace(candidate, TOOL_NAME_SEPARATOR)
    return normalized
def tool_eval() -> Callable[[Callable], Callable]:

View file

@ -7,3 +7,5 @@ coverage:
default:
target: 90%
threshold: 0.5%
exclude:
- arcade/cli/**

View file

@ -30,10 +30,12 @@ uvicorn = {version = "^0.30.0", optional = true}
scipy = {version = "^1.14.0", optional = true}
numpy = {version = "^2.0.0", optional = true}
scikit-learn = {version = "^1.5.0", optional = true}
pytz = {version = "^2024.1", optional = true}
python-dateutil = {version = "^2.8.2", optional = true}
[tool.poetry.extras]
fastapi = ["fastapi", "uvicorn", "opentelemetry-instrumentation-fastapi", "opentelemetry-exporter-otlp-proto-http", "opentelemetry-exporter-otlp-proto-common"]
evals = ["scipy", "numpy", "scikit-learn"]
evals = ["scipy", "numpy", "scikit-learn", "pytz", "python-dateutil"]
[tool.poetry.group.dev.dependencies]
pytest = "^8.1.2"
@ -43,6 +45,8 @@ pre-commit = "^3.4.0"
tox = "^4.11.1"
pytest-asyncio = "^0.23.7"
types-toml = "^0.10.8"
types-pytz = "^2024.1"
types-python-dateutil = "^2.8.2"
poetry-plugin-export = "^1.7.0"
[tool.poetry.scripts]
@ -63,9 +67,11 @@ ignore_missing_imports = "True"
[tool.pytest.ini_options]
testpaths = ["tests"]
[tool.coverage.report]
skip_empty = true
[tool.coverage.run]
branch = true
source = ["arcade"]
omit = ["arcade/cli/*"]
[tool.coverage.report]
skip_empty = true

View file

@ -373,3 +373,24 @@ async def test_async_arcade_health_check_raises_error(
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
with pytest.raises(EngineNotHealthyError):
await test_async_client.health.check()
def test_arcade_tool_list_tools(test_sync_client, mock_response, monkeypatch):
    """Arcade.tools.list_tools returns ToolDefinition objects built from the raw payload."""
    raw_payload = [TOOL_DEFINITION_DATA]

    def fake_execute_request(*args, **kwargs):
        # Replaces Arcade._execute_request so the client receives the canned payload.
        return raw_payload

    monkeypatch.setattr(Arcade, "_execute_request", fake_execute_request)

    listed = test_sync_client.tools.list_tools(toolkit="TestToolkit")

    assert listed == [ToolDefinition(**TOOL_DEFINITION_DATA)]
@pytest.mark.asyncio
async def test_async_arcade_tool_list_tools(test_async_client, mock_async_response, monkeypatch):
    """AsyncArcade.tools.list_tools returns ToolDefinition objects built from the raw payload."""
    raw_payload = [TOOL_DEFINITION_DATA]

    async def fake_execute_request(*args, **kwargs):
        # Replaces AsyncArcade._execute_request so the client receives the canned payload.
        return raw_payload

    monkeypatch.setattr(AsyncArcade, "_execute_request", fake_execute_request)

    listed = await test_async_client.tools.list_tools(toolkit="TestToolkit")

    assert listed == [ToolDefinition(**TOOL_DEFINITION_DATA)]

View file

@ -68,7 +68,6 @@ def test_get_tool(toolkit_version: str | None, expected_tool):
name="SampleTool", toolkit_name="SampleToolkit", toolkit_version=toolkit_version
)
tool = catalog.get_tool(fq_name)
assert tool.tool == expected_tool
@ -102,3 +101,38 @@ def test_add_toolkit_type_error():
assert "Type error encountered while adding tool invalid_tool from mock_module" in str(
exc_info.value
)
def test_get_tool_by_name():
    """A registered tool is found by its qualified name; an unknown toolkit raises ValueError."""
    tool_catalog = ToolCatalog()
    tool_catalog.add_tool(sample_tool, "sample_toolkit")

    found = tool_catalog.get_tool_by_name("SampleToolkit.SampleTool")

    # The lookup returns the entry that wraps the registered function.
    assert found.tool == sample_tool
    assert found.name == "SampleTool"
    assert found.meta.toolkit == "sample_toolkit"
    assert found.version is None

    with pytest.raises(ValueError):
        tool_catalog.get_tool_by_name("nonexistent_toolkit.SampleTool")
def test_get_tool_by_name_with_version():
    """A lookup without a version succeeds; asking for version "2.0.0" raises ValueError."""
    tool_catalog = ToolCatalog()
    tool_catalog.add_tool(sample_tool, "sample_toolkit")

    found = tool_catalog.get_tool_by_name("SampleToolkit.SampleTool")

    assert found.tool == sample_tool
    assert found.name == "SampleTool"
    assert found.meta.toolkit == "sample_toolkit"

    # The tool was added without a version, so a pinned version cannot match.
    with pytest.raises(ValueError):
        tool_catalog.get_tool_by_name("SampleToolkit.SampleTool", version="2.0.0")
def test_get_tool_by_name_with_invalid_version():
    """Requesting version "2.0.0" of a tool added without a version raises ValueError."""
    tool_catalog = ToolCatalog()
    tool_catalog.add_tool(sample_tool, "SampleToolkit")

    with pytest.raises(ValueError):
        tool_catalog.get_tool_by_name("SampleToolkit.SampleTool", version="2.0.0")

View file

@ -1,8 +1,13 @@
from datetime import timedelta
import pytest
import pytz
from dateutil import parser
from arcade.sdk.error import WeightError
from arcade.sdk.eval import (
BinaryCritic,
DatetimeCritic,
EvalRubric,
ExpectedToolCall,
NumericCritic,
@ -255,19 +260,13 @@ def test_eval_case_multiple_critics():
# Test EvalCase with missing expected and actual values in args
@pytest.mark.parametrize(
"expected_args, actual_args, expected_score",
[
({"param": "value"}, {}, 1.0), # Missing actual value
({}, {"param": "value"}, 1.0), # Missing expected value
({"param": "value"}, {"param": "value"}, 2.0), # Both values present
],
)
def test_eval_case_missing_values(expected_args, actual_args, expected_score):
def test_eval_case_with_none_values():
"""
Test that when either expected or actual values are missing for a critic,
the critic evaluation is skipped, and the total score is computed accordingly.
Test that when expected or actual values are None, the critic evaluates them appropriately.
"""
expected_args = {"param": None}
actual_args = {"param": None}
expected_tool_calls = [ExpectedToolCall(name="ToolA", args=expected_args)]
actual_tool_calls = [("ToolA", actual_args)]
@ -284,15 +283,8 @@ def test_eval_case_missing_values(expected_args, actual_args, expected_score):
result = case.evaluate(actual_tool_calls)
# If critic is skipped, only tool selection score is counted
# Otherwise, tool selection + critic score
total_weight = 1.0 # At least tool selection weight
if "param" in expected_args and "param" in actual_args:
total_weight += 1.0 # Critic weight
expected_total_score = expected_score / total_weight
assert result.score == expected_total_score
# Both values are None, so the critic should return a match
assert result.score == 2.0 / 2.0 # Full score (tool selection + critic score)
# Test that WeightError is raised for invalid critic weights
@ -340,3 +332,136 @@ def test_similarity_critic_unsupported_metric():
"""
with pytest.raises(ValueError):
SimilarityCritic(critic_field="text", weight=1.0, metric="unsupported_metric")
# Test DatetimeCritic
# Parameterized tests for DatetimeCritic with various datetime formats and default timezones
@pytest.mark.parametrize(
    "critic_params, expected, actual, expected_match, expected_score",
    [
        # Identical timestamps carrying the same explicit UTC offset
        (
            {"critic_field": "start_datetime", "weight": 1.0},
            "2024-09-26T12:00:00-07:00",
            "2024-09-26T12:00:00-07:00",
            True,
            1.0,
        ),
        # Date-only strings (no time component)
        (
            {"critic_field": "start_datetime", "weight": 1.0},
            "2024-09-26",
            "2024-09-26",
            True,
            1.0,
        ),
        # "Z"-suffixed value against the same wall-clock time without a zone;
        # the pair is expected to match
        (
            {"critic_field": "start_datetime", "weight": 1.0},
            "2024-09-26T12:00:00Z",
            "2024-09-26T12:00:00",
            True,
            1.0,
        ),
        # Both values naive (no timezone information)
        (
            {"critic_field": "start_datetime", "weight": 1.0},
            "2024-09-26T12:00:00",
            "2024-09-26T12:00:00",
            True,
            1.0,
        ),
    ],
)
def test_datetime_critic_basic(critic_params, expected, actual, expected_match, expected_score):
    """
    DatetimeCritic reports a match and full score for equal datetimes in several formats.
    """
    outcome = DatetimeCritic(**critic_params).evaluate(expected, actual)

    assert outcome["match"] == expected_match
    assert outcome["score"] == expected_score
# Parameterized tests for DatetimeCritic's handling of tolerances and max differences
@pytest.mark.parametrize(
    "critic_params, expected, actual, expected_match, expected_score_func",
    [
        # 30 s apart with a 60 s tolerance: match, full weight
        (
            {"critic_field": "start_datetime", "weight": 1.0, "tolerance": timedelta(seconds=60)},
            "2024-09-26T12:00:00",
            "2024-09-26T12:00:30",
            True,
            lambda critic: critic.weight,
        ),
        # 240 s apart: beyond the 60 s tolerance, inside the 5 min max_difference,
        # so the expected score is scaled down
        (
            {
                "critic_field": "start_datetime",
                "weight": 1.0,
                "tolerance": timedelta(seconds=60),
                "max_difference": timedelta(minutes=5),
            },
            "2024-09-26T12:00:00",
            "2024-09-26T12:04:00",
            False,
            lambda critic: critic.weight * (1 - (240 / 300)),
        ),
        # 10 min apart with a 5 min max_difference: no credit
        (
            {
                "critic_field": "start_datetime",
                "weight": 1.0,
                "max_difference": timedelta(minutes=5),
            },
            "2024-09-26T12:00:00",
            "2024-09-26T12:10:00",
            False,
            lambda critic: 0.0,
        ),
    ],
)
def test_datetime_critic_tolerances(
    critic_params, expected, actual, expected_match, expected_score_func
):
    """
    DatetimeCritic scores time differences according to tolerance and max_difference.
    """
    critic = DatetimeCritic(**critic_params)
    outcome = critic.evaluate(expected, actual)

    assert outcome["match"] == expected_match

    # The expected score is derived from the constructed critic, so each
    # case can refer to the critic's weight.
    wanted_score = expected_score_func(critic)
    assert pytest.approx(outcome["score"], abs=1e-6) == wanted_score
def test_datetime_critic_naive_and_timezone_aware():
    """
    DatetimeCritic comparing a "Z"-suffixed datetime with a naive one five hours earlier.
    """
    critic = DatetimeCritic(critic_field="start_datetime", weight=1.0)
    expected = "2024-09-26T12:00:00Z"
    actual = "2024-09-26T07:00:00"

    outcome = critic.evaluate(expected, actual)
    assert outcome["match"] is False

    # Recompute the score the test expects: parse both strings and treat
    # any naive value as UTC before measuring the gap.
    expected_dt = parser.parse(expected)
    actual_dt = parser.parse(actual)
    if actual_dt.tzinfo is None:
        actual_dt = pytz.utc.localize(actual_dt)
    if expected_dt.tzinfo is None:
        expected_dt = pytz.utc.localize(expected_dt)

    gap_seconds = abs((expected_dt - actual_dt).total_seconds())

    if gap_seconds <= critic.tolerance.total_seconds():
        wanted_score = critic.weight
    elif gap_seconds >= critic.max_difference.total_seconds():
        wanted_score = 0.0
    else:
        # Between tolerance and max_difference the score falls off linearly.
        remaining = 1 - (gap_seconds / critic.max_difference.total_seconds())
        wanted_score = critic.weight * remaining

    assert pytest.approx(outcome["score"], abs=1e-6) == wanted_score

View file

@ -48,7 +48,7 @@ async def test_error_responses(
if status_code == 422:
await list_org_repositories(mock_context, "org", repo_type=RepoType.ALL)
elif status_code == 301:
await count_stargazers("owner", "repo")
await count_stargazers(mock_context, "owner", "repo")
elif status_code == 404:
await list_org_repositories(mock_context, "non_existent_org")
elif status_code == 503:
@ -66,8 +66,8 @@ async def test_list_repository_activities_invalid_cursor(mock_context, mock_clie
@pytest.mark.asyncio
async def test_count_stargazers_success(mock_client):
async def test_count_stargazers_success(mock_context, mock_client):
mock_client.get.return_value = Response(200, json={"stargazers_count": 42})
result = await count_stargazers("owner", "repo")
assert result == "The repository owner/repo has 42 stargazers."
result = await count_stargazers(mock_context, "owner", "repo")
assert result == 42

View file

@ -1,8 +1,12 @@
import arcade_github
from arcade_github.tools.models import DiffSide, ReviewCommentSubjectType # Add these imports
from arcade_github.tools.models import (
DiffSide,
ReviewCommentSubjectType,
SortDirection,
)
from arcade_github.tools.pull_requests import (
create_reply_for_review_comment,
create_review_comment, # Add this import
create_review_comment,
get_pull_request,
list_pull_request_commits,
list_pull_requests,
@ -169,7 +173,7 @@ def github_pull_requests_eval_suite() -> EvalSuite:
"repo": "test",
"pull_number": 72,
"sort": "updated",
"direction": "asc",
"direction": SortDirection.ASC,
},
)
],

View file

@ -1,4 +1,5 @@
import arcade_github
from arcade_github.tools.models import SortDirection
from arcade_github.tools.repositories import (
count_stargazers,
get_repository,
@ -66,7 +67,7 @@ def github_repositories_eval_suite() -> EvalSuite:
"org": "ArcadeAI",
"repo_type": "all",
"sort": "created",
"sort_direction": "desc",
"sort_direction": SortDirection.DESC,
},
)
],
@ -108,7 +109,7 @@ def github_repositories_eval_suite() -> EvalSuite:
{
"owner": "ArcadeAI",
"repo": "test",
"direction": "desc",
"direction": SortDirection.DESC,
"per_page": 30,
"actor": "TestUser",
"time_period": "month",
@ -138,7 +139,7 @@ def github_repositories_eval_suite() -> EvalSuite:
"owner": "ArcadeAI",
"repo": "test",
"sort": "created",
"direction": "desc",
"direction": SortDirection.DESC,
"per_page": 30,
"page": 1,
"include_extra_data": False,

View file

@ -1,17 +1,16 @@
from datetime import datetime, timedelta
from typing import Annotated
from zoneinfo import ZoneInfo
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from arcade.core.errors import RetryableToolError, ToolExecutionError
from arcade.core.errors import RetryableToolError
from arcade.core.schema import ToolContext
from arcade.sdk import tool
from arcade.sdk.auth import Google
from arcade_google.tools.models import Day, EventVisibility, SendUpdatesOptions, TimeSlot
from arcade_google.tools.utils import _update_datetime
from arcade_google.tools.models import EventVisibility, SendUpdatesOptions
from arcade_google.tools.utils import parse_datetime
@tool(
@ -25,97 +24,87 @@ from arcade_google.tools.utils import _update_datetime
async def create_event(
context: ToolContext,
summary: Annotated[str, "The title of the event"],
start_date: Annotated[Day, "The day that the event starts"],
start_time: Annotated[TimeSlot, "The time of the day that the event starts"],
end_date: Annotated[Day, "The day that the event ends"],
end_time: Annotated[TimeSlot, "The time of the day that the event ends"],
start_datetime: Annotated[
str,
"The datetime when the event starts in ISO 8601 format, e.g., '2024-12-31T15:30:00'.",
],
end_datetime: Annotated[
str,
"The datetime when the event ends in ISO 8601 format, e.g., '2024-12-31T17:30:00'.",
],
calendar_id: Annotated[
str, "The ID of the calendar to create the event in, usually 'primary'"
str, "The ID of the calendar to create the event in, usually 'primary'."
] = "primary",
description: Annotated[str | None, "The description of the event"] = None,
location: Annotated[str | None, "The location of the event"] = None,
visibility: Annotated[EventVisibility, "The visibility of the event"] = EventVisibility.DEFAULT,
attendee_emails: Annotated[
list[str] | None,
"The list of attendee emails. Must be valid email addresses e.g., username@domain.com",
"The list of attendee emails. Must be valid email addresses e.g., username@domain.com.",
] = None,
) -> Annotated[dict, "A dictionary containing the created event details"]:
"""Create a new event/meeting/sync/meetup in the specified calendar."""
service = build("calendar", "v3", credentials=Credentials(context.authorization.token))
try:
# Get the calendar's time zone
calendar = service.calendars().get(calendarId=calendar_id).execute()
time_zone = calendar["timeZone"]
# Get the calendar's time zone
calendar = service.calendars().get(calendarId=calendar_id).execute()
time_zone = calendar["timeZone"]
# Convert enum values to datetime objects
start_datetime = datetime.combine(start_date.to_date(time_zone), start_time.to_time())
end_datetime = datetime.combine(end_date.to_date(time_zone), end_time.to_time())
# Parse datetime strings
start_dt = parse_datetime(start_datetime, time_zone)
end_dt = parse_datetime(end_datetime, time_zone)
event = {
"summary": summary,
"description": description,
"location": location,
"start": {"dateTime": start_datetime.isoformat(), "timeZone": time_zone},
"end": {"dateTime": end_datetime.isoformat(), "timeZone": time_zone},
"visibility": visibility.value,
}
event = {
"summary": summary,
"description": description,
"location": location,
"start": {"dateTime": start_dt.isoformat(), "timeZone": time_zone},
"end": {"dateTime": end_dt.isoformat(), "timeZone": time_zone},
"visibility": visibility.value,
}
if attendee_emails:
event["attendees"] = [{"email": email} for email in attendee_emails]
if attendee_emails:
event["attendees"] = [{"email": email} for email in attendee_emails]
created_event = service.events().insert(calendarId=calendar_id, body=event).execute()
except HttpError as e:
raise ToolExecutionError(
f"HttpError during execution of '{create_event.__name__}' tool.", str(e)
)
except Exception as e:
raise ToolExecutionError(
f"Unexpected Error encountered during execution of '{create_event.__name__}' tool.",
str(e),
)
else:
return {"event": created_event}
created_event = service.events().insert(calendarId=calendar_id, body=event).execute()
return {"event": created_event}
@tool(
requires_auth=Google(
scopes=["https://www.googleapis.com/auth/calendar.events.readonly"],
scopes=[
"https://www.googleapis.com/auth/calendar.readonly",
"https://www.googleapis.com/auth/calendar.events",
],
)
)
async def list_events(
context: ToolContext,
min_day: Annotated[
Day, "Filter by events that end on or after this day. Combined with min_time_slot"
min_end_datetime: Annotated[
str,
"Filter by events that end on or after this datetime in ISO 8601 format, e.g., '2024-09-15T09:00:00'.",
],
min_time_slot: Annotated[
TimeSlot, "Filter by events that end after this time. Combined with min_day"
],
max_day: Annotated[
Day, "Filter by events that start on or before this day. Combined with max_time_slot"
],
max_time_slot: Annotated[
TimeSlot, "Filter by events that start before this time. Combined with max_day"
max_start_datetime: Annotated[
str,
"Filter by events that start before this datetime in ISO 8601 format, e.g., '2024-09-16T17:00:00'.",
],
calendar_id: Annotated[str, "The ID of the calendar to list events from"] = "primary",
max_results: Annotated[int, "The maximum number of events to return"] = 10,
) -> Annotated[dict, "A dictionary containing the list of events"]:
"""
List events from the specified calendar within the given date range.
List events from the specified calendar within the given datetime range.
min_day and min_time_slot are combined to form the lower bound (exclusive) for an event's end time to filter by
max_day and max_time_slot are combined to form the upper bound (exclusive) for an event's start time to filter by
min_end_datetime serves as the lower bound (exclusive) for an event's end time.
max_start_datetime serves as the upper bound (exclusive) for an event's start time.
For example:
If min_day is set to Day.TODAY and min_time_slot is set to TimeSlot._09:00,
and max_day is set to Day.TOMORROW and max_time_slot is set to TimeSlot._17:00,
If min_end_datetime is set to 2024-09-15T09:00:00 and max_start_datetime is set to 2024-09-16T17:00:00,
the function will return events that:
1. End after 09:00 today (exclusive)
2. Start before 17:00 tomorrow (exclusive)
This means an event starting at 08:00 today and ending at 10:00 today would be included,
but an event starting at 17:00 tomorrow would not be included.
1. End after 09:00 on September 15, 2024 (exclusive)
2. Start before 17:00 on September 16, 2024 (exclusive)
This means an event starting at 08:00 on September 15 and ending at 10:00 on September 15 would be included,
but an event starting at 17:00 on September 16 would not be included.
"""
service = build("calendar", "v3", credentials=Credentials(context.authorization.token))
@ -123,23 +112,19 @@ async def list_events(
calendar = service.calendars().get(calendarId=calendar_id).execute()
time_zone = calendar["timeZone"]
# Convert enum values to datetime with timezone offset
start_datetime = datetime.combine(
min_day.to_date(time_zone), min_time_slot.to_time()
).astimezone(ZoneInfo(time_zone))
end_datetime = datetime.combine(max_day.to_date(time_zone), max_time_slot.to_time()).astimezone(
ZoneInfo(time_zone)
)
# Parse datetime strings
min_end_dt = parse_datetime(min_end_datetime, time_zone)
max_start_dt = parse_datetime(max_start_datetime, time_zone)
if start_datetime > end_datetime:
start_datetime, end_datetime = end_datetime, start_datetime
if min_end_dt > max_start_dt:
min_end_dt, max_start_dt = max_start_dt, min_end_dt
events_result = (
service.events()
.list(
calendarId=calendar_id,
timeMin=start_datetime.isoformat(),
timeMax=end_datetime.isoformat(),
timeMin=min_end_dt.isoformat(),
timeMax=max_start_dt.isoformat(),
maxResults=max_results,
singleEvents=True,
orderBy="startTime",
@ -179,21 +164,16 @@ async def list_events(
async def update_event(
context: ToolContext,
event_id: Annotated[str, "The ID of the event to update"],
updated_start_day: Annotated[
Day | None,
"The updated day that the event starts. Combined with updated_start_time to form the new start time",
updated_start_datetime: Annotated[
str | None,
"The updated datetime that the event starts in ISO 8601 format, e.g., '2024-12-31T15:30:00'.",
] = None,
updated_start_time: Annotated[
TimeSlot | None,
"The updated time that the event starts. Combined with updated_start_day to form the new start time",
updated_end_datetime: Annotated[
str | None,
"The updated datetime that the event ends in ISO 8601 format, e.g., '2024-12-31T17:30:00'.",
] = None,
updated_end_day: Annotated[
Day | None,
"The updated day that the event ends. Combined with updated_end_time to form the new end time",
] = None,
updated_end_time: Annotated[TimeSlot | None, "The updated time that the event ends"] = None,
updated_calendar_id: Annotated[
str | None, "The updated ID of the calendar containing the event"
str | None, "The updated ID of the calendar containing the event."
] = None,
updated_summary: Annotated[str | None, "The updated title of the event"] = None,
updated_description: Annotated[str | None, "The updated description of the event"] = None,
@ -201,25 +181,24 @@ async def update_event(
updated_visibility: Annotated[EventVisibility | None, "The visibility of the event"] = None,
attendee_emails_to_add: Annotated[
list[str] | None,
"The list of updated attendee emails to add. Must be valid email addresses e.g., username@domain.com",
"The list of attendee emails to add. Must be valid email addresses e.g., username@domain.com.",
] = None,
attendee_emails_to_remove: Annotated[
list[str] | None,
"The list of attendee emails to remove. Must be valid email addresses e.g., username@domain.com",
"The list of attendee emails to remove. Must be valid email addresses e.g., username@domain.com.",
] = None,
send_updates: Annotated[
SendUpdatesOptions, "Guests who should receive notifications about the event update"
SendUpdatesOptions, "Should attendees be notified of the update? (none, all, external_only)"
] = SendUpdatesOptions.ALL,
) -> Annotated[
str,
"A string containing the updated event details, including the event ID, update timestamp, and a link to view the updated event",
"A string containing the updated event details, including the event ID, update timestamp, and a link to view the updated event.",
]:
"""
Update an existing event in the specified calendar with the provided details.
Only the provided fields will be updated; others will remain unchanged.
`updated_start_day` and `updated_start_time` must be provided together.
`updated_end_day` and `updated_end_time` must be provided together.
`updated_start_datetime` and `updated_end_datetime` are independent and can be provided separately.
"""
service = build("calendar", "v3", credentials=Credentials(context.authorization.token))
@ -228,13 +207,13 @@ async def update_event(
try:
event = service.events().get(calendarId="primary", eventId=event_id).execute()
except HttpError: # TODO: This is a first pass. We should do better.
except HttpError:
valid_events_with_id = (
service.events()
.list(
calendarId="primary",
timeMin=(datetime.now() - timedelta(days=2)).isoformat(),
timeMax=(datetime.now() - timedelta(days=2)).isoformat(),
timeMax=(datetime.now() + timedelta(days=365)).isoformat(),
maxResults=50,
singleEvents=True,
orderBy="startTime",
@ -243,14 +222,18 @@ async def update_event(
)
raise RetryableToolError(
f"Event with ID {event_id} not found.",
additional_prompt_content=f"Here is list of valid events. The event_id parameter must match one of these: {valid_events_with_id}",
additional_prompt_content=f"Here is a list of valid events. The event_id parameter must match one of these: {valid_events_with_id}",
retry_after_ms=1000,
developer_message=f"Event with ID {event_id} not found. Please try again with a valid event ID.",
)
update_fields = {
"start": _update_datetime(updated_start_day, updated_start_time, time_zone),
"end": _update_datetime(updated_end_day, updated_end_time, time_zone),
"start": {"dateTime": updated_start_datetime.isoformat(), "timeZone": time_zone}
if updated_start_datetime
else None,
"end": {"dateTime": updated_end_datetime.isoformat(), "timeZone": time_zone}
if updated_end_datetime
else None,
"calendarId": updated_calendar_id,
"sendUpdates": send_updates.value if send_updates else None,
"summary": updated_summary,
@ -265,12 +248,20 @@ async def update_event(
event["attendees"] = [
attendee
for attendee in event.get("attendees", [])
if attendee.get("email", "") not in attendee_emails_to_remove
if attendee.get("email", "").lower()
not in [email.lower() for email in attendee_emails_to_remove]
]
if attendee_emails_to_add:
event["attendees"] = event.get("attendees", []) + [
{"email": email} for email in attendee_emails_to_add
existing_emails = {
attendee.get("email", "").lower() for attendee in event.get("attendees", [])
}
new_attendees = [
{"email": email}
for email in attendee_emails_to_add
if email.lower() not in existing_emails
]
event["attendees"] = event.get("attendees", []) + new_attendees
updated_event = (
service.events()

View file

@ -1,8 +1,9 @@
import datetime
import re
from base64 import urlsafe_b64decode
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Optional
from zoneinfo import ZoneInfo
from bs4 import BeautifulSoup
from google.oauth2.credentials import Credentials
@ -11,6 +12,31 @@ from googleapiclient.discovery import build
from arcade_google.tools.models import Day, TimeSlot
def parse_datetime(datetime_str: str, time_zone: str) -> datetime:
    """
    Parse a datetime string in ISO 8601 format and ensure it is timezone-aware.

    If the parsed value already carries UTC offset information it is returned
    as parsed; otherwise `time_zone` is attached to the naive value (the
    wall-clock fields are kept, no conversion is performed).

    Args:
        datetime_str (str): The datetime string to parse. Expected format: 'YYYY-MM-DDTHH:MM:SS'.
        time_zone (str): The timezone key to apply if the datetime string is naive.

    Returns:
        datetime: A timezone-aware datetime object.

    Raises:
        ValueError: If the datetime string is not in the correct format.
            Errors raised by `ZoneInfo(time_zone)` for an unusable time zone key
            propagate unchanged.
    """
    # Only the parse itself lives in the try body: a ValueError raised by
    # ZoneInfo for a bad time zone key must not be reported as a datetime
    # format problem.
    try:
        dt = datetime.fromisoformat(datetime_str)
    except ValueError as e:
        raise ValueError(
            f"Invalid datetime format: '{datetime_str}'. Expected ISO 8601 format, e.g., '2024-12-31T15:30:00'."
        ) from e

    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=ZoneInfo(time_zone))
    return dt
class DateRange(Enum):
TODAY = "today"
YESTERDAY = "yesterday"
@ -21,24 +47,24 @@ class DateRange(Enum):
THIS_YEAR = "this_year"
def to_date_query(self):
today = datetime.datetime.now()
today = datetime.now()
result = "after:"
comparison_date = today
if self == DateRange.YESTERDAY:
comparison_date = today - datetime.timedelta(days=1)
comparison_date = today - timedelta(days=1)
elif self == DateRange.LAST_7_DAYS:
comparison_date = today - datetime.timedelta(days=7)
comparison_date = today - timedelta(days=7)
elif self == DateRange.LAST_30_DAYS:
comparison_date = today - datetime.timedelta(days=30)
comparison_date = today - timedelta(days=30)
elif self == DateRange.THIS_MONTH:
comparison_date = today.replace(day=1)
elif self == DateRange.LAST_MONTH:
comparison_date = (today.replace(day=1) - datetime.timedelta(days=1)).replace(day=1)
comparison_date = (today.replace(day=1) - timedelta(days=1)).replace(day=1)
elif self == DateRange.THIS_YEAR:
comparison_date = today.replace(month=1, day=1)
elif self == DateRange.LAST_MONTH:
comparison_date = (today.replace(month=1, day=1) - datetime.timedelta(days=1)).replace(
comparison_date = (today.replace(month=1, day=1) - timedelta(days=1)).replace(
month=1, day=1
)

View file

@ -1,10 +1,19 @@
from datetime import timedelta
import arcade_google
from arcade_google.tools.calendar import create_event, delete_event, list_events, update_event
from arcade_google.tools.models import Day, EventVisibility, TimeSlot
from arcade_google.tools.calendar import (
EventVisibility,
SendUpdatesOptions,
create_event,
delete_event,
list_events,
update_event,
)
from arcade.core.catalog import ToolCatalog
from arcade.sdk.eval import (
BinaryCritic,
DatetimeCritic,
EvalRubric,
EvalSuite,
tool_eval,
@ -16,11 +25,9 @@ rubric = EvalRubric(
warn_threshold=0.95,
)
catalog = ToolCatalog()
catalog.add_module(arcade_google)
history_after_list_events = [
{"role": "user", "content": "do i have any events on my calendar for today?"},
{
@ -36,7 +43,7 @@ history_after_list_events = [
"type": "function",
"function": {
"name": "Google_ListEvents",
"arguments": '{"max_day":"today","max_time_slot":"23:45","min_day":"today","min_time_slot":"00:00"}',
"arguments": '{"min_end_datetime":"2024-09-26T00:00:00-07:00","max_start_datetime":"2024-09-27T00:00:00-07:00"}',
},
}
],
@ -59,7 +66,9 @@ def calendar_eval_suite() -> EvalSuite:
"""Create an evaluation suite for Calendar tools."""
suite = EvalSuite(
name="Calendar Tools Evaluation",
system_message="You are an AI assistant that can create and list events using the provided tools.",
system_message=(
"You are an AI assistant that can create, list, update, and delete events using the provided tools. Today is 2024-09-26"
),
catalog=catalog,
rubric=rubric,
)
@ -67,32 +76,34 @@ def calendar_eval_suite() -> EvalSuite:
# Cases for create_event
suite.add_case(
name="Create calendar event",
user_message="Create a meeting for 'Team Meeting' starting next thursday from 11:45pm to 12:15am. Invite johndoe@example.com",
user_message=(
"Create a meeting for 'Team Meeting' starting on September 26, 2024, from 11:45pm to 12:15am. Invite johndoe@example.com"
),
expected_tool_calls=[
(
create_event,
{
"summary": "Team Meeting",
"start_date": Day.NEXT_THURSDAY.value,
"start_time": TimeSlot._2345.value,
"end_date": Day.NEXT_FRIDAY.value,
"end_time": TimeSlot._0015.value,
"start_datetime": "2024-09-26T23:45:00",
"end_datetime": "2024-09-27T00:15:00",
"calendar_id": "primary",
"attendee_emails": ["johndoe@example.com"],
"description": None,
"location": None,
"visibility": EventVisibility.DEFAULT,
"description": "Team Meeting",
},
)
],
critics=[
BinaryCritic(critic_field="summary", weight=0.15),
BinaryCritic(critic_field="start_date", weight=0.15),
BinaryCritic(critic_field="start_time", weight=0.15),
BinaryCritic(critic_field="end_date", weight=0.15),
BinaryCritic(critic_field="end_time", weight=0.15),
BinaryCritic(critic_field="attendee_emails", weight=0.15),
BinaryCritic(critic_field="summary", weight=0.2),
DatetimeCritic(
critic_field="start_datetime", weight=0.2, tolerance=timedelta(seconds=10)
),
DatetimeCritic(
critic_field="end_datetime", weight=0.2, tolerance=timedelta(seconds=10)
),
BinaryCritic(critic_field="attendee_emails", weight=0.2),
BinaryCritic(critic_field="description", weight=0.1),
BinaryCritic(critic_field="location", weight=0.1),
],
)
@ -104,49 +115,61 @@ def calendar_eval_suite() -> EvalSuite:
(
list_events,
{
"min_day": Day.TODAY.value,
"min_time_slot": TimeSlot._0000.value,
"max_day": Day.TOMORROW.value,
"max_time_slot": TimeSlot._0000.value,
"min_end_datetime": "2024-09-26T00:00:00",
"max_start_datetime": "2024-09-27T00:00:00",
"calendar_id": "primary",
"event_types": None,
"max_results": 10,
},
)
],
critics=[
BinaryCritic(critic_field="min_day", weight=0.1),
BinaryCritic(critic_field="min_time_slot", weight=0.1),
BinaryCritic(critic_field="max_day", weight=0.1),
BinaryCritic(critic_field="max_time_slot", weight=0.1),
BinaryCritic(critic_field="calendar_id", weight=0.1),
BinaryCritic(critic_field="event_types", weight=0.1),
BinaryCritic(critic_field="max_results", weight=0.1),
DatetimeCritic(
critic_field="min_end_datetime", weight=0.3, tolerance=timedelta(hours=1)
),
DatetimeCritic(
critic_field="max_start_datetime", weight=0.3, tolerance=timedelta(hours=1)
),
BinaryCritic(critic_field="calendar_id", weight=0.2),
BinaryCritic(critic_field="max_results", weight=0.2),
],
)
# Cases for update_event
suite.add_case(
name="Update a calendar event",
user_message="Oh no! I cant make it to the API Test since i have lunch with an old friend at that time. Change the meeting to 3pm to 4pm please.",
user_message=(
"Oh no! I can't make it to the API Test since I have lunch with an old friend at that time. "
"Change the meeting my meeting tomorrow at 3pm to 4pm. Let everyone know."
),
expected_tool_calls=[
(
update_event,
{
"event_id": "00099992228181818181",
"updated_start_day": Day.TODAY.value,
"updated_start_time": TimeSlot._1500.value,
"updated_end_day": Day.TODAY.value,
"updated_end_time": TimeSlot._1600.value,
"updated_start_datetime": "2024-09-27T16:00:00",
"updated_end_datetime": "2024-09-27T18:00:00",
"updated_calendar_id": "primary",
"updated_summary": "API Test",
"updated_description": "API Test",
"updated_location": "611 Gateway Blvd",
"updated_visibility": EventVisibility.DEFAULT,
"attendee_emails_to_add": None,
"attendee_emails_to_remove": None,
"send_updates": SendUpdatesOptions.ALL,
},
)
],
critics=[
BinaryCritic(critic_field="event_id", weight=0.2),
BinaryCritic(critic_field="updated_start_day", weight=0.1),
BinaryCritic(critic_field="updated_start_time", weight=0.1),
BinaryCritic(critic_field="updated_end_day", weight=0.1),
BinaryCritic(critic_field="updated_end_time", weight=0.1),
BinaryCritic(critic_field="event_id", weight=0.4),
DatetimeCritic(
critic_field="updated_start_datetime", weight=0.2, tolerance=timedelta(minutes=15)
),
DatetimeCritic(
critic_field="updated_end_datetime",
weight=0.2,
tolerance=timedelta(minutes=15),
),
BinaryCritic(critic_field="send_updates", weight=0.2),
],
additional_messages=history_after_list_events,
)
@ -154,14 +177,16 @@ def calendar_eval_suite() -> EvalSuite:
# Cases for delete_event
suite.add_case(
name="Delete a calendar event",
user_message="I don't need to have focus time today. Please delete it from my calendar. Don't send any notifications.",
user_message=(
"I don't need to have focus time today. Please delete it from my calendar. Don't send any notifications."
),
expected_tool_calls=[
(
delete_event,
{
"event_id": "gr5g18lf88tfpp3vkareukkc7g",
"calendar_id": "primary",
"send_updates": "none",
"send_updates": SendUpdatesOptions.NONE,
},
)
],

View file

@ -2,7 +2,7 @@ from unittest.mock import MagicMock, patch
import pytest
from arcade_google.tools.calendar import create_event, delete_event, list_events, update_event
from arcade_google.tools.models import Day, EventVisibility, SendUpdatesOptions, TimeSlot
from arcade_google.tools.models import EventVisibility, SendUpdatesOptions
from googleapiclient.errors import HttpError
from arcade.core.errors import ToolExecutionError
@ -24,7 +24,7 @@ async def test_create_event(mock_build, mock_context):
# Mock the calendar's time zone
mock_service.calendars().get().execute.return_value = {"timeZone": "America/Los_Angeles"}
# Case: HttpError
# Case: HttpError during event creation
mock_service.events().insert().execute.side_effect = HttpError(
resp=MagicMock(status=400),
content=b'{"error": {"message": "Invalid request"}}',
@ -34,10 +34,8 @@ async def test_create_event(mock_build, mock_context):
await create_event(
context=mock_context,
summary="Test Event",
start_date=Day.TODAY,
start_time=TimeSlot._1615,
end_date=Day.TODAY,
end_time=TimeSlot._1715,
start_datetime="2024-12-31T15:30:00",
end_datetime="2024-12-31T17:30:00",
description="Test Description",
location="Test Location",
visibility=EventVisibility.PRIVATE,
@ -53,34 +51,34 @@ async def test_list_events(mock_build, mock_context):
# Mock the calendar's time zone
mock_service.calendars().get().execute.return_value = {"timeZone": "America/Los_Angeles"}
# Case: min time is after max time. list_events tool should swap the times and still return the events
# Mock the events list response
mock_events_list_response = {
"items": [
{
"creator": {"email": "example@arcade-ai.com", "self": True},
"end": {"dateTime": "2024-09-27T01:00:00-07:00", "timeZone": "America/Los_Angeles"},
"eventType": "default",
"htmlLink": "https://www.google.com/calendar/event?eid=N2pmYjZ0ZmNnMGNydG5scmhkY2JvZWc4OGIgZXJpY0BhcmNhZGUtYWku",
"id": "7jfb6tfcg0crtnlrhdcboeg88b",
"htmlLink": "https://www.google.com/calendar/event?eid=event1",
"id": "event1",
"organizer": {"email": "example@arcade-ai.com", "self": True},
"start": {
"dateTime": "2024-09-27T00:00:00-07:00",
"timeZone": "America/Los_Angeles",
},
"summary": "teST",
"summary": "Event 1",
},
{
"creator": {"email": "example@arcade-ai.com", "self": True},
"end": {"dateTime": "2024-09-27T17:00:00-07:00", "timeZone": "America/Los_Angeles"},
"eventType": "default",
"htmlLink": "https://www.google.com/calendar/event?eid=MjZvYnRoc2xtMWMzbG5mdG10bzk4cDcxaGMgZXJpY0BhcmNhZGUtYWku",
"id": "26obthslm1c3lnftmto98p71hc",
"htmlLink": "https://www.google.com/calendar/event?eid=event2",
"id": "event2",
"organizer": {"email": "example@arcade-ai.com", "self": True},
"start": {
"dateTime": "2024-09-27T14:00:00-07:00",
"timeZone": "America/Los_Angeles",
},
"summary": "New Event",
"summary": "Event 2",
},
]
}
@ -89,16 +87,14 @@ async def test_list_events(mock_build, mock_context):
"events": mock_events_list_response["items"],
}
mock_service.events().list().execute.return_value = mock_events_list_response
message = await list_events(
response = await list_events(
context=mock_context,
min_day=Day.TODAY,
min_time_slot=TimeSlot._1615,
max_day=Day.TODAY,
max_time_slot=TimeSlot._1515,
min_end_datetime="2024-09-15T09:00:00",
max_start_datetime="2024-09-16T17:00:00",
)
assert message == expected_tool_response
assert response == expected_tool_response
# Case: HttpError
# Case: HttpError during events listing
mock_service.events().list().execute.side_effect = HttpError(
resp=MagicMock(status=400),
content=b'{"error": {"message": "Invalid request"}}',
@ -107,10 +103,8 @@ async def test_list_events(mock_build, mock_context):
with pytest.raises(ToolExecutionError):
await list_events(
context=mock_context,
min_day=Day.TODAY,
min_time_slot=TimeSlot._1615,
max_day=Day.TOMORROW,
max_time_slot=TimeSlot._1815,
min_end_datetime="2024-09-15T09:00:00",
max_start_datetime="2024-09-16T17:00:00",
)
@ -119,8 +113,10 @@ async def test_list_events(mock_build, mock_context):
async def test_update_event(mock_build, mock_context):
mock_service = MagicMock()
mock_build.return_value = mock_service
mock_service.events().update().execute.side_effect = HttpError(
resp=MagicMock(status=400),
# Mock retrieval of the event
mock_service.events().get().execute.side_effect = HttpError(
resp=MagicMock(status=404),
content=b'{"error": {"message": "Event not found"}}',
)
@ -128,10 +124,8 @@ async def test_update_event(mock_build, mock_context):
await update_event(
context=mock_context,
event_id="1234567890",
updated_start_day=Day.NEXT_FRIDAY,
updated_start_time=TimeSlot._0015,
updated_end_day=Day.NEXT_FRIDAY,
updated_end_time=TimeSlot._0115,
updated_start_datetime="2024-12-31T00:15:00",
updated_end_datetime="2024-12-31T01:15:00",
updated_summary="Updated Event",
updated_description="Updated Description",
updated_location="Updated Location",
@ -148,7 +142,7 @@ async def test_delete_event(mock_build, mock_context):
mock_service = MagicMock()
mock_build.return_value = mock_service
mock_service.events().delete().execute.side_effect = HttpError(
resp=MagicMock(status=400),
resp=MagicMock(status=404),
content=b'{"error": {"message": "Event not found"}}',
)
@ -156,4 +150,5 @@ async def test_delete_event(mock_build, mock_context):
await delete_event(
context=mock_context,
event_id="nonexistent_event",
send_updates=SendUpdatesOptions.ALL,
)

View file

@ -9,6 +9,8 @@ from arcade_math.tools.arithmetic import (
sum_range,
)
from arcade.sdk.error import ToolExecutionError
def test_add():
assert add(1, 2) == 3
@ -29,7 +31,7 @@ def test_multiply():
def test_divide():
assert divide(6, 3) == 2.0
assert divide(5, 2) == 2.5
with pytest.raises(ZeroDivisionError):
with pytest.raises(ToolExecutionError):
divide(1, 0)