Add `GET /v1/tools/list` (#100)
Add retrieving the list of available tool definitions that can be called. essential to working with frameworks like langchain/graph
This commit is contained in:
parent
20170d04e0
commit
6b716d6dde
22 changed files with 1159 additions and 586 deletions
11
Makefile
11
Makefile
|
|
@ -28,15 +28,22 @@ check: ## Run code quality tools.
|
|||
.PHONY: test
|
||||
test: ## Test the code with pytest
|
||||
@echo "🚀 Testing code: Running pytest"
|
||||
@cd arcade && poetry run pytest -v --cov --cov-config=pyproject.toml --cov-report=xml
|
||||
@cd arcade && poetry run pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml
|
||||
|
||||
.PHONY: test-toolkits
|
||||
test-toolkits: ## Iterate over all toolkits and run pytest on each one
|
||||
@echo "🚀 Testing code in toolkits: Running pytest"
|
||||
@for dir in toolkits/*/ ; do \
|
||||
(cd $$dir && poetry run pytest -v --cov --cov-config=pyproject.toml --cov-report=xml || exit 1); \
|
||||
(cd $$dir && poetry run pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml || exit 1); \
|
||||
done
|
||||
|
||||
.PHONY: coverage
|
||||
coverage: ## Generate coverage report
|
||||
@echo "coverage report"
|
||||
@cd arcade && coverage report
|
||||
@echo "Generating coverage report"
|
||||
@cd arcade && coverage html
|
||||
|
||||
.PHONY: set-version
|
||||
set-version: ## Set the version in the pyproject.toml file
|
||||
@echo "🚀 Setting version in pyproject.toml"
|
||||
|
|
|
|||
263
arcade/arcade/cli/display.py
Normal file
263
arcade/arcade/cli/display.py
Normal file
|
|
@ -0,0 +1,263 @@
|
|||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
from arcade.core.config_model import Config
|
||||
from arcade.core.schema import ToolDefinition
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from arcade.sdk.eval.eval import EvaluationResult
|
||||
console = Console()
|
||||
|
||||
|
||||
def display_tools_table(tools: list[ToolDefinition]) -> None:
|
||||
"""
|
||||
Display a table of tools with their name, description, package, and version.
|
||||
"""
|
||||
table = Table(show_header=True, header_style="bold magenta")
|
||||
table.add_column("Name")
|
||||
table.add_column("Description")
|
||||
table.add_column("Package")
|
||||
table.add_column("Version")
|
||||
|
||||
for tool in sorted(tools, key=lambda x: x.toolkit.name):
|
||||
table.add_row(
|
||||
str(tool.get_fully_qualified_name()),
|
||||
tool.description.split("\n")[0] if tool.description else "",
|
||||
tool.toolkit.name,
|
||||
tool.toolkit.version,
|
||||
)
|
||||
console.print(table)
|
||||
|
||||
|
||||
def display_tool_details(tool: ToolDefinition) -> None:
|
||||
"""
|
||||
Display detailed information about a specific tool using multiple panels.
|
||||
"""
|
||||
# Description Panel
|
||||
description_panel = Panel(
|
||||
tool.description or "No description available.",
|
||||
title=f"Tool: {tool.name}",
|
||||
border_style="cyan",
|
||||
)
|
||||
|
||||
# Inputs Panel
|
||||
inputs = tool.inputs.parameters
|
||||
if inputs:
|
||||
inputs_table = Table(show_header=True, header_style="bold green")
|
||||
inputs_table.add_column("Name", style="cyan")
|
||||
inputs_table.add_column("Type", style="magenta")
|
||||
inputs_table.add_column("Required", style="yellow")
|
||||
inputs_table.add_column("Description", style="white")
|
||||
inputs_table.add_column("Default", style="blue")
|
||||
for param in inputs:
|
||||
# Since InputParameter does not have a default field, we use "N/A"
|
||||
default_value = "N/A"
|
||||
if param.value_schema.enum:
|
||||
default_value = f"One of {param.value_schema.enum}"
|
||||
inputs_table.add_row(
|
||||
param.name,
|
||||
param.value_schema.val_type,
|
||||
str(param.required),
|
||||
param.description or "",
|
||||
default_value,
|
||||
)
|
||||
inputs_panel = Panel(
|
||||
inputs_table,
|
||||
title="Input Parameters",
|
||||
border_style="green",
|
||||
)
|
||||
else:
|
||||
inputs_panel = Panel(
|
||||
"No input parameters.",
|
||||
title="Input Parameters",
|
||||
border_style="green",
|
||||
)
|
||||
|
||||
# Output Panel
|
||||
output = tool.output
|
||||
if output:
|
||||
output_description = output.description or "No description available."
|
||||
output_types = ", ".join(output.available_modes)
|
||||
output_val_type = output.value_schema.val_type if output.value_schema else "N/A"
|
||||
output_details = Text.assemble(
|
||||
("Description: ", "bold"),
|
||||
(output_description, ""),
|
||||
"\n",
|
||||
("Available Modes: ", "bold"),
|
||||
(output_types, ""),
|
||||
"\n",
|
||||
("Value Type: ", "bold"),
|
||||
(output_val_type, ""),
|
||||
)
|
||||
output_panel = Panel(
|
||||
output_details,
|
||||
title="Expected Output",
|
||||
border_style="blue",
|
||||
)
|
||||
else:
|
||||
output_panel = Panel(
|
||||
"No output information available.",
|
||||
title="Expected Output",
|
||||
border_style="blue",
|
||||
)
|
||||
|
||||
# Combine all panels vertically
|
||||
console.print(description_panel)
|
||||
console.print(inputs_panel)
|
||||
console.print(output_panel)
|
||||
|
||||
|
||||
def display_tool_messages(tool_messages: list[dict]) -> None:
|
||||
for message in tool_messages:
|
||||
if message["role"] == "assistant":
|
||||
for tool_call in message.get("tool_calls", []):
|
||||
console.print(
|
||||
f"[bright_black][bold]Called tool '{tool_call['function']['name']}'[/bold]\n[bold]Parameters:[/bold]{tool_call['function']['arguments']}[/bright_black]"
|
||||
)
|
||||
elif message["role"] == "tool":
|
||||
console.print(
|
||||
f"[bright_black][bold]'{message['name']}' tool returned:[/bold]{message['content']}[/bright_black]"
|
||||
)
|
||||
|
||||
|
||||
def display_eval_results(results: list[list[dict[str, Any]]], show_details: bool = False) -> None:
|
||||
"""
|
||||
Display evaluation results in a format inspired by pytest's output.
|
||||
|
||||
Args:
|
||||
results: List of dictionaries containing evaluation results for each model.
|
||||
show_details: Whether to show detailed results for each case.
|
||||
"""
|
||||
total_passed = 0
|
||||
total_failed = 0
|
||||
total_warned = 0
|
||||
total_cases = 0
|
||||
|
||||
for eval_suite in results:
|
||||
for model_results in eval_suite:
|
||||
model = model_results.get("model", "Unknown Model")
|
||||
rubric = model_results.get("rubric", "Unknown Rubric")
|
||||
cases = model_results.get("cases", [])
|
||||
total_cases += len(cases)
|
||||
|
||||
console.print(f"[bold]Model:[/bold] [bold magenta]{model}[/bold magenta]")
|
||||
if show_details:
|
||||
console.print(f"[bold magenta]{rubric}[/bold magenta]")
|
||||
|
||||
for case in cases:
|
||||
evaluation = case["evaluation"]
|
||||
status = (
|
||||
"[green]PASSED[/green]"
|
||||
if evaluation.passed
|
||||
else "[yellow]WARNED[/yellow]"
|
||||
if evaluation.warning
|
||||
else "[red]FAILED[/red]"
|
||||
)
|
||||
if evaluation.passed:
|
||||
total_passed += 1
|
||||
elif evaluation.warning:
|
||||
total_warned += 1
|
||||
else:
|
||||
total_failed += 1
|
||||
|
||||
# Display one-line summary for each case with score as a percentage
|
||||
score_percentage = evaluation.score * 100
|
||||
console.print(f"{status} {case['name']} -- Score: {score_percentage:.2f}%")
|
||||
|
||||
if show_details:
|
||||
# Show detailed information for each case
|
||||
console.print(f"[bold]User Input:[/bold] {case['input']}\n")
|
||||
console.print("[bold]Details:[/bold]")
|
||||
console.print(_format_evaluation(evaluation))
|
||||
console.print("-" * 80)
|
||||
|
||||
# Summary
|
||||
summary = (
|
||||
f"[bold]Summary -- [/bold]Total: {total_cases} -- [green]Passed: {total_passed}[/green]"
|
||||
)
|
||||
if total_warned > 0:
|
||||
summary += f" -- [yellow]Warnings: {total_warned}[/yellow]"
|
||||
if total_failed > 0:
|
||||
summary += f" -- [red]Failed: {total_failed}[/red]"
|
||||
console.print(summary + "\n")
|
||||
|
||||
|
||||
def _format_evaluation(evaluation: "EvaluationResult") -> str:
|
||||
"""
|
||||
Format evaluation results with color-coded matches and scores.
|
||||
|
||||
Args:
|
||||
evaluation: An EvaluationResult object containing the evaluation results.
|
||||
|
||||
Returns:
|
||||
A formatted string representation of the evaluation details.
|
||||
"""
|
||||
result_lines = []
|
||||
if evaluation.failure_reason:
|
||||
result_lines.append(f"[bold red]Failure Reason:[/bold red] {evaluation.failure_reason}")
|
||||
else:
|
||||
for critic_result in evaluation.results:
|
||||
match_color = "green" if critic_result["match"] else "red"
|
||||
field = critic_result["field"]
|
||||
score = critic_result["score"]
|
||||
weight = critic_result["weight"]
|
||||
expected = critic_result["expected"]
|
||||
actual = critic_result["actual"]
|
||||
|
||||
result_lines.append(
|
||||
f"[bold]{field}:[/bold] "
|
||||
f"[{match_color}]Match: {critic_result['match']}"
|
||||
f"\n Score: {score:.2f}/{weight:.2f}[/{match_color}]"
|
||||
f"\n Expected: {expected}"
|
||||
f"\n Actual: {actual}"
|
||||
)
|
||||
return "\n".join(result_lines)
|
||||
|
||||
|
||||
def display_arcade_chat_header(config: Config, stream: bool) -> None:
|
||||
chat_header = Text.assemble(
|
||||
"\n",
|
||||
(
|
||||
"=== Arcade AI Chat ===",
|
||||
"bold magenta underline",
|
||||
),
|
||||
"\n",
|
||||
"\n",
|
||||
"Chatting with Arcade Engine at ",
|
||||
(
|
||||
config.engine_url,
|
||||
"bold blue",
|
||||
),
|
||||
)
|
||||
if stream:
|
||||
chat_header.append(" (streaming)")
|
||||
console.print(chat_header)
|
||||
|
||||
|
||||
def display_config_as_table(config) -> None: # type: ignore[no-untyped-def]
|
||||
"""
|
||||
Display the configuration details as a table using Rich library.
|
||||
"""
|
||||
table = Table(show_header=True, header_style="bold magenta")
|
||||
table.add_column("Section")
|
||||
table.add_column("Name")
|
||||
table.add_column("Value")
|
||||
|
||||
for section_name in config.model_dump():
|
||||
section = getattr(config, section_name)
|
||||
if section:
|
||||
section = section.dict()
|
||||
first = True
|
||||
for name, value in section.items():
|
||||
if first:
|
||||
table.add_row(section_name, name, str(value))
|
||||
first = False
|
||||
else:
|
||||
table.add_row("", name, str(value))
|
||||
table.add_row("", "", "")
|
||||
|
||||
console.print(table)
|
||||
|
|
@ -11,27 +11,32 @@ import typer
|
|||
from openai import OpenAIError
|
||||
from rich.console import Console
|
||||
from rich.markup import escape
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
from arcade.cli.authn import LocalAuthCallbackServer, check_existing_login
|
||||
from arcade.cli.display import (
|
||||
display_arcade_chat_header,
|
||||
display_config_as_table,
|
||||
display_eval_results,
|
||||
display_tool_details,
|
||||
display_tool_messages,
|
||||
display_tools_table,
|
||||
)
|
||||
from arcade.cli.launcher import start_servers
|
||||
from arcade.cli.utils import (
|
||||
OrderCommands,
|
||||
apply_config_overrides,
|
||||
create_cli_catalog,
|
||||
display_eval_results,
|
||||
display_tool_messages,
|
||||
get_config_with_overrides,
|
||||
get_eval_files,
|
||||
get_tools_from_engine,
|
||||
handle_chat_interaction,
|
||||
is_authorization_pending,
|
||||
load_eval_suites, # Import the new function
|
||||
load_eval_suites,
|
||||
log_engine_health,
|
||||
validate_and_get_config,
|
||||
wait_for_authorization_completion,
|
||||
)
|
||||
from arcade.client import Arcade
|
||||
from arcade.client.errors import EngineNotHealthyError, EngineOfflineError
|
||||
from arcade.core.config_model import Config
|
||||
|
||||
cli = typer.Typer(
|
||||
cls=OrderCommands,
|
||||
|
|
@ -44,27 +49,6 @@ cli = typer.Typer(
|
|||
console = Console()
|
||||
|
||||
|
||||
def _get_config_with_overrides(
|
||||
force_tls: bool,
|
||||
force_no_tls: bool,
|
||||
host_input: str | None = None,
|
||||
port_input: int | None = None,
|
||||
) -> Config:
|
||||
"""
|
||||
Get the config with CLI-specific optional overrides applied.
|
||||
"""
|
||||
config = validate_and_get_config()
|
||||
|
||||
if not force_tls and not force_no_tls:
|
||||
tls_input = None
|
||||
elif force_no_tls:
|
||||
tls_input = False
|
||||
else:
|
||||
tls_input = True
|
||||
apply_config_overrides(config, host_input, port_input, tls_input)
|
||||
return config
|
||||
|
||||
|
||||
@cli.command(help="Log in to Arcade Cloud", rich_help_panel="User")
|
||||
def login(
|
||||
host: str = typer.Option(
|
||||
|
|
@ -136,44 +120,75 @@ def new(
|
|||
|
||||
|
||||
@cli.command(
|
||||
help="Show the installed toolkits",
|
||||
help="Show the installed toolkits or details of a specific tool",
|
||||
rich_help_panel="Tool Development",
|
||||
)
|
||||
def show(
|
||||
toolkit: Optional[str] = typer.Option(
|
||||
None, "-t", "--toolkit", help="The toolkit to show the tools of"
|
||||
),
|
||||
tool: Optional[str] = typer.Option(
|
||||
None, "-T", "--tool", help="The specific tool to show details for"
|
||||
),
|
||||
host: Optional[str] = typer.Option(
|
||||
None,
|
||||
"-h",
|
||||
"--host",
|
||||
help="The Arcade Engine address to send chat requests to.",
|
||||
),
|
||||
port: Optional[int] = typer.Option(
|
||||
None,
|
||||
"-p",
|
||||
"--port",
|
||||
help="The port of the Arcade Engine.",
|
||||
),
|
||||
force_tls: bool = typer.Option(
|
||||
False,
|
||||
"--tls",
|
||||
help="Whether to force TLS for the connection to the Arcade Engine. If not specified, the connection will use TLS if the engine URL uses a 'https' scheme.",
|
||||
),
|
||||
force_no_tls: bool = typer.Option(
|
||||
False,
|
||||
"--no-tls",
|
||||
help="Whether to disable TLS for the connection to the Arcade Engine.",
|
||||
),
|
||||
debug: bool = typer.Option(False, "--debug", "-d", help="Show debug information"),
|
||||
) -> None:
|
||||
"""
|
||||
Show the available tools in an actor or toolkit
|
||||
Show the available toolkits or detailed information about a specific tool.
|
||||
"""
|
||||
|
||||
try:
|
||||
catalog = create_cli_catalog(toolkit=toolkit)
|
||||
if not host:
|
||||
catalog = create_cli_catalog(toolkit=toolkit)
|
||||
tools = [t.definition for t in list(catalog)]
|
||||
else:
|
||||
tools = get_tools_from_engine(host, port, force_tls, force_no_tls, toolkit)
|
||||
|
||||
# Create a table with Rich library
|
||||
table = Table(show_header=True, header_style="bold magenta")
|
||||
table.add_column("Name")
|
||||
table.add_column("Description")
|
||||
table.add_column("Package")
|
||||
table.add_column("Version")
|
||||
if tool:
|
||||
# Display detailed information for the specified tool
|
||||
tool_def = next(
|
||||
(
|
||||
t
|
||||
for t in tools
|
||||
if t.get_fully_qualified_name().name == tool
|
||||
or str(t.get_fully_qualified_name()) == tool
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not tool_def:
|
||||
console.print(f"❌ Tool '{tool}' not found.", style="bold red")
|
||||
typer.Exit(code=1)
|
||||
else:
|
||||
display_tool_details(tool_def)
|
||||
else:
|
||||
# Display the list of tools as a table
|
||||
display_tools_table(tools)
|
||||
|
||||
tool_names = catalog.get_tool_names()
|
||||
for tool_name in tool_names:
|
||||
tool = catalog.get_tool(tool_name)
|
||||
package = tool.meta.package if tool.meta.package else tool.meta.toolkit
|
||||
table.add_row(str(tool_name), tool.description, package, tool.version)
|
||||
|
||||
console.print(table)
|
||||
|
||||
# used when debugging a broken package on import.
|
||||
# `arcade show` is the first command used after
|
||||
# a toolkit package is created.
|
||||
except Exception as e:
|
||||
if debug:
|
||||
raise
|
||||
error_message = f"❌ Failed to List tools: {escape(str(e))}"
|
||||
error_message = f"❌ Failed to list tools: {escape(str(e))}"
|
||||
console.print(error_message, style="bold red")
|
||||
|
||||
|
||||
|
|
@ -211,7 +226,7 @@ def chat(
|
|||
"""
|
||||
Chat with a language model.
|
||||
"""
|
||||
config = _get_config_with_overrides(force_tls, force_no_tls, host, port)
|
||||
config = get_config_with_overrides(force_tls, force_no_tls, host, port)
|
||||
|
||||
client = Arcade(api_key=config.api.key, base_url=config.engine_url)
|
||||
user_email = config.user.email if config.user else None
|
||||
|
|
@ -321,73 +336,6 @@ def config(
|
|||
raise typer.Exit(code=1)
|
||||
|
||||
|
||||
def display_arcade_chat_header(config: Config, stream: bool) -> None:
|
||||
chat_header = Text.assemble(
|
||||
"\n",
|
||||
(
|
||||
"=== Arcade AI Chat ===",
|
||||
"bold magenta underline",
|
||||
),
|
||||
"\n",
|
||||
"\n",
|
||||
"Chatting with Arcade Engine at ",
|
||||
(
|
||||
config.engine_url,
|
||||
"bold blue",
|
||||
),
|
||||
)
|
||||
if stream:
|
||||
chat_header.append(" (streaming)")
|
||||
console.print(chat_header)
|
||||
|
||||
|
||||
def log_engine_health(client: Arcade) -> None:
|
||||
try:
|
||||
client.health.check()
|
||||
|
||||
except EngineNotHealthyError as e:
|
||||
console.print(
|
||||
"[bold][yellow]⚠️ Warning: "
|
||||
+ str(e)
|
||||
+ " ("
|
||||
+ "[/yellow]"
|
||||
+ "[red]"
|
||||
+ str(e.status_code)
|
||||
+ "[/red]"
|
||||
+ "[yellow])[/yellow][/bold]"
|
||||
)
|
||||
except EngineOfflineError:
|
||||
console.print(
|
||||
"⚠️ Warning: Arcade Engine was unreachable. (Is it running?)",
|
||||
style="bold yellow",
|
||||
)
|
||||
|
||||
|
||||
def display_config_as_table(config) -> None: # type: ignore[no-untyped-def]
|
||||
"""
|
||||
Display the configuration details as a table using Rich library.
|
||||
"""
|
||||
table = Table(show_header=True, header_style="bold magenta")
|
||||
table.add_column("Section")
|
||||
table.add_column("Name")
|
||||
table.add_column("Value")
|
||||
|
||||
for section_name in config.model_dump():
|
||||
section = getattr(config, section_name)
|
||||
if section:
|
||||
section = section.dict()
|
||||
first = True
|
||||
for name, value in section.items():
|
||||
if first:
|
||||
table.add_row(section_name, name, str(value))
|
||||
first = False
|
||||
else:
|
||||
table.add_row("", name, str(value))
|
||||
table.add_row("", "", "")
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
@cli.command(help="Run tool calling evaluations", rich_help_panel="Tool Development")
|
||||
def evals(
|
||||
directory: str = typer.Argument(".", help="Directory containing evaluation files"),
|
||||
|
|
@ -428,7 +376,7 @@ def evals(
|
|||
Find all files starting with 'eval_' in the given directory,
|
||||
execute any functions decorated with @tool_eval, and display the results.
|
||||
"""
|
||||
config = _get_config_with_overrides(force_tls, force_no_tls, host, port)
|
||||
config = get_config_with_overrides(force_tls, force_no_tls, host, port)
|
||||
|
||||
models_list = models.split(",") # Use 'models_list' to avoid shadowing
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import importlib.util
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Union
|
||||
from typing import Callable, Union
|
||||
|
||||
import typer
|
||||
from openai.resources.chat.completions import ChatCompletionChunk, Stream
|
||||
|
|
@ -13,16 +13,14 @@ from typer.core import TyperGroup
|
|||
from typer.models import Context
|
||||
|
||||
from arcade.client.client import Arcade
|
||||
from arcade.client.errors import APITimeoutError
|
||||
from arcade.client.errors import APITimeoutError, EngineNotHealthyError, EngineOfflineError
|
||||
from arcade.client.schema import AuthResponse
|
||||
from arcade.core.catalog import ToolCatalog
|
||||
from arcade.core.config_model import Config
|
||||
from arcade.core.errors import ToolkitLoadError
|
||||
from arcade.core.schema import ToolDefinition
|
||||
from arcade.core.toolkit import Toolkit
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from arcade.sdk.eval.eval import EvaluationResult
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
|
|
@ -64,17 +62,37 @@ def create_cli_catalog(
|
|||
return catalog
|
||||
|
||||
|
||||
def display_tool_messages(tool_messages: list[dict]) -> None:
|
||||
for message in tool_messages:
|
||||
if message["role"] == "assistant":
|
||||
for tool_call in message.get("tool_calls", []):
|
||||
console.print(
|
||||
f"[bright_black][bold]Called tool '{tool_call['function']['name']}'[/bold]\n[bold]Parameters:[/bold]{tool_call['function']['arguments']}[/bright_black]"
|
||||
)
|
||||
elif message["role"] == "tool":
|
||||
console.print(
|
||||
f"[bright_black][bold]'{message['name']}' tool returned:[/bold]{message['content']}[/bright_black]"
|
||||
)
|
||||
def get_config_with_overrides(
|
||||
force_tls: bool,
|
||||
force_no_tls: bool,
|
||||
host_input: str | None = None,
|
||||
port_input: int | None = None,
|
||||
) -> Config:
|
||||
"""
|
||||
Get the config with CLI-specific optional overrides applied.
|
||||
"""
|
||||
config = validate_and_get_config()
|
||||
|
||||
if not force_tls and not force_no_tls:
|
||||
tls_input = None
|
||||
elif force_no_tls:
|
||||
tls_input = False
|
||||
else:
|
||||
tls_input = True
|
||||
apply_config_overrides(config, host_input, port_input, tls_input)
|
||||
return config
|
||||
|
||||
|
||||
def get_tools_from_engine(
|
||||
host: str,
|
||||
port: int | None = None,
|
||||
force_tls: bool = False,
|
||||
force_no_tls: bool = False,
|
||||
toolkit: str | None = None,
|
||||
) -> list[ToolDefinition]:
|
||||
config = get_config_with_overrides(force_tls, force_no_tls, host, port)
|
||||
client = Arcade(api_key=config.api.key, base_url=config.engine_url)
|
||||
return client.tools.list_tools(toolkit=toolkit)
|
||||
|
||||
|
||||
def get_tool_messages(choice: dict) -> list[dict]:
|
||||
|
|
@ -199,96 +217,26 @@ def apply_config_overrides(
|
|||
config.engine.tls = tls_input
|
||||
|
||||
|
||||
def display_eval_results(results: list[list[dict[str, Any]]], show_details: bool = False) -> None:
|
||||
"""
|
||||
Display evaluation results in a format inspired by pytest's output.
|
||||
def log_engine_health(client: Arcade) -> None:
|
||||
try:
|
||||
client.health.check()
|
||||
|
||||
Args:
|
||||
results: List of dictionaries containing evaluation results for each model.
|
||||
show_details: Whether to show detailed results for each case.
|
||||
"""
|
||||
total_passed = 0
|
||||
total_failed = 0
|
||||
total_warned = 0
|
||||
total_cases = 0
|
||||
|
||||
for eval_suite in results:
|
||||
for model_results in eval_suite:
|
||||
model = model_results.get("model", "Unknown Model")
|
||||
rubric = model_results.get("rubric", "Unknown Rubric")
|
||||
cases = model_results.get("cases", [])
|
||||
total_cases += len(cases)
|
||||
|
||||
console.print(f"[bold]Model:[/bold] [bold magenta]{model}[/bold magenta]")
|
||||
if show_details:
|
||||
console.print(f"[bold magenta]{rubric}[/bold magenta]")
|
||||
|
||||
for case in cases:
|
||||
evaluation = case["evaluation"]
|
||||
status = (
|
||||
"[green]PASSED[/green]"
|
||||
if evaluation.passed
|
||||
else "[yellow]WARNED[/yellow]"
|
||||
if evaluation.warning
|
||||
else "[red]FAILED[/red]"
|
||||
)
|
||||
if evaluation.passed:
|
||||
total_passed += 1
|
||||
elif evaluation.warning:
|
||||
total_warned += 1
|
||||
else:
|
||||
total_failed += 1
|
||||
|
||||
# Display one-line summary for each case
|
||||
console.print(f"{status} {case['name']} -- Score: {evaluation.score:.2f}")
|
||||
|
||||
if show_details:
|
||||
# Show detailed information for each case
|
||||
console.print(f"[bold]User Input:[/bold] {case['input']}\n")
|
||||
console.print("[bold]Details:[/bold]")
|
||||
console.print(_format_evaluation(evaluation))
|
||||
console.print("-" * 80)
|
||||
|
||||
# Summary
|
||||
summary = (
|
||||
f"[bold]Summary -- [/bold]Total: {total_cases} -- [green]Passed: {total_passed}[/green]"
|
||||
)
|
||||
if total_warned > 0:
|
||||
summary += f" -- [yellow]Warnings: {total_warned}[/yellow]"
|
||||
if total_failed > 0:
|
||||
summary += f" -- [red]Failed: {total_failed}[/red]"
|
||||
console.print(summary + "\n")
|
||||
|
||||
|
||||
def _format_evaluation(evaluation: "EvaluationResult") -> str:
|
||||
"""
|
||||
Format evaluation results with color-coded matches and scores.
|
||||
|
||||
Args:
|
||||
evaluation: An EvaluationResult object containing the evaluation results.
|
||||
|
||||
Returns:
|
||||
A formatted string representation of the evaluation details.
|
||||
"""
|
||||
result_lines = []
|
||||
if evaluation.failure_reason:
|
||||
result_lines.append(f"[bold red]Failure Reason:[/bold red] {evaluation.failure_reason}")
|
||||
else:
|
||||
for critic_result in evaluation.results:
|
||||
match_color = "green" if critic_result["match"] else "red"
|
||||
field = critic_result["field"]
|
||||
score = critic_result["score"]
|
||||
weight = critic_result["weight"]
|
||||
expected = critic_result["expected"]
|
||||
actual = critic_result["actual"]
|
||||
result_lines.append(
|
||||
f"[bold]{field}:[/bold] "
|
||||
f"[{match_color}]Match: {critic_result['match']}, "
|
||||
f"Score: {score:.2f}/{weight:.2f}[/{match_color}]"
|
||||
f"\n Expected: {expected}"
|
||||
f"\n Actual: {actual}"
|
||||
)
|
||||
return "\n".join(result_lines)
|
||||
except EngineNotHealthyError as e:
|
||||
console.print(
|
||||
"[bold][yellow]⚠️ Warning: "
|
||||
+ str(e)
|
||||
+ " ("
|
||||
+ "[/yellow]"
|
||||
+ "[red]"
|
||||
+ str(e.status_code)
|
||||
+ "[/red]"
|
||||
+ "[yellow])[/yellow][/bold]"
|
||||
)
|
||||
except EngineOfflineError:
|
||||
console.print(
|
||||
"⚠️ Warning: Arcade Engine was unreachable. (Is it running?)",
|
||||
style="bold yellow",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
|
|
@ -166,6 +166,17 @@ class ToolResource(BaseResource[ClientT]):
|
|||
)
|
||||
return AuthResponse(**data)
|
||||
|
||||
def list_tools(self, toolkit: str | None = None) -> list[ToolDefinition]:
|
||||
"""
|
||||
List the tools available for a given toolkit and provider.
|
||||
"""
|
||||
data = self._client._execute_request( # type: ignore[attr-defined]
|
||||
"GET",
|
||||
f"{self._resource_path}/list",
|
||||
params={"toolkit": toolkit},
|
||||
)
|
||||
return [ToolDefinition(**tool) for tool in data]
|
||||
|
||||
|
||||
class HealthResource(BaseResource[ClientT]):
|
||||
"""Health check resource."""
|
||||
|
|
@ -331,6 +342,17 @@ class AsyncToolResource(BaseResource[AsyncArcadeClient]):
|
|||
)
|
||||
return AuthResponse(**data)
|
||||
|
||||
async def list_tools(self, toolkit: str | None = None) -> list[ToolDefinition]:
|
||||
"""
|
||||
List the tools available for a given toolkit and provider.
|
||||
"""
|
||||
data = await self._client._execute_request( # type: ignore[attr-defined]
|
||||
"GET",
|
||||
f"{self._resource_path}/list",
|
||||
params={"toolkit": toolkit},
|
||||
)
|
||||
return [ToolDefinition(**tool) for tool in data]
|
||||
|
||||
|
||||
class AsyncHealthResource(BaseResource[AsyncArcadeClient]):
|
||||
"""Asynchronous Health check resource."""
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from pydantic_core import PydanticUndefined
|
|||
|
||||
from arcade.core.errors import ToolDefinitionError
|
||||
from arcade.core.schema import (
|
||||
TOOL_NAME_SEPARATOR,
|
||||
FullyQualifiedName,
|
||||
InputParameter,
|
||||
OAuth2Requirement,
|
||||
|
|
@ -215,6 +216,45 @@ class ToolCatalog(BaseModel):
|
|||
return tool.definition
|
||||
raise ValueError(f"Tool {func} not found in the catalog.")
|
||||
|
||||
def get_tool_by_name(
|
||||
self, name: str, version: Optional[str] = None, separator: str = TOOL_NAME_SEPARATOR
|
||||
) -> MaterializedTool:
|
||||
"""
|
||||
Get a tool from the catalog by name, optionally including the toolkit name.
|
||||
|
||||
Args:
|
||||
name: The name of the tool, potentially including the toolkit name separated by the `separator`.
|
||||
version: The version of the toolkit. Defaults to None.
|
||||
separator: The separator between toolkit and tool names. Defaults to `TOOL_NAME_SEPARATOR`.
|
||||
|
||||
Returns:
|
||||
MaterializedTool: The matching tool from the catalog.
|
||||
|
||||
Raises:
|
||||
ValueError: If the tool is not found in the catalog.
|
||||
"""
|
||||
if separator in name:
|
||||
toolkit_name, tool_name = name.split(separator, 1)
|
||||
fq_name = FullyQualifiedName(
|
||||
name=tool_name, toolkit_name=toolkit_name, toolkit_version=version
|
||||
)
|
||||
return self.get_tool(fq_name)
|
||||
else:
|
||||
# No toolkit name provided, search tools with matching tool name
|
||||
matching_tools = [
|
||||
tool
|
||||
for fq_name, tool in self._tools.items()
|
||||
if fq_name.name.lower() == name.lower()
|
||||
and (
|
||||
version is None
|
||||
or (fq_name.toolkit_version or "").lower() == (version or "").lower()
|
||||
)
|
||||
]
|
||||
if matching_tools:
|
||||
return matching_tools[0]
|
||||
|
||||
raise ValueError(f"Tool {name} not found in the catalog.")
|
||||
|
||||
def get_tool(self, name: FullyQualifiedName) -> MaterializedTool:
|
||||
"""
|
||||
Get a tool from the catalog by fully-qualified name and version.
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
from .critic import BinaryCritic, NumericCritic, SimilarityCritic
|
||||
from .critic import BinaryCritic, DatetimeCritic, NumericCritic, SimilarityCritic
|
||||
from .eval import EvalRubric, EvalSuite, ExpectedToolCall, tool_eval
|
||||
|
||||
__all__ = [
|
||||
"BinaryCritic",
|
||||
"SimilarityCritic",
|
||||
"NumericCritic",
|
||||
"DatetimeCritic",
|
||||
"EvalRubric",
|
||||
"EvalSuite",
|
||||
"ExpectedToolCall",
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import pytz
|
||||
from dateutil import parser
|
||||
|
||||
from arcade.sdk.error import WeightError
|
||||
|
||||
|
||||
|
|
@ -47,6 +51,17 @@ class BinaryCritic(Critic):
|
|||
Raises:
|
||||
TypeError: If the casting is not possible.
|
||||
"""
|
||||
# In case both are strings.
|
||||
if actual == "None":
|
||||
actual = None
|
||||
if expected == "None":
|
||||
expected = None
|
||||
if expected is None:
|
||||
# No need to cast; return actual as is
|
||||
return actual
|
||||
if actual is None:
|
||||
# No need to cast; return None
|
||||
return None
|
||||
expected_type = type(expected)
|
||||
try:
|
||||
return expected_type(actual)
|
||||
|
|
@ -60,14 +75,18 @@ class BinaryCritic(Critic):
|
|||
Evaluates whether the expected and actual values are exactly equal after casting.
|
||||
|
||||
Args:
|
||||
expected (Any): The expected value.
|
||||
actual (Any): The actual value to compare, cast to the type of expected.
|
||||
expected: The expected value.
|
||||
actual: The actual value to compare, cast to the type of expected.
|
||||
|
||||
Returns:
|
||||
dict[str, float | bool]: A dictionary containing the match status and score.
|
||||
dict: A dictionary containing the match status and score.
|
||||
"""
|
||||
# Cast actual to the type of expected
|
||||
actual_casted = self.cast_actual(expected, actual)
|
||||
try:
|
||||
actual_casted = self.cast_actual(expected, actual)
|
||||
# TODO log or something better here
|
||||
except TypeError:
|
||||
actual_casted = actual
|
||||
|
||||
match = expected == actual_casted
|
||||
return {"match": match, "score": self.weight if match else 0.0}
|
||||
|
|
@ -187,3 +206,68 @@ class SimilarityCritic(Critic):
|
|||
"match": similarity >= self.similarity_threshold,
|
||||
"score": min(similarity * self.weight, self.weight),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass
|
||||
class DatetimeCritic(Critic):
|
||||
"""
|
||||
A critic that evaluates the closeness of datetime values within a specified tolerance.
|
||||
|
||||
Attributes:
|
||||
tolerance: Acceptable timedelta between expected and actual datetimes.
|
||||
max_difference: Maximum timedelta for a partial score.
|
||||
"""
|
||||
|
||||
critic_field: str
|
||||
weight: float
|
||||
tolerance: timedelta = timedelta(seconds=500)
|
||||
max_difference: timedelta = timedelta(hours=2)
|
||||
|
||||
def evaluate(self, expected: str, actual: str) -> dict[str, float | bool]:
|
||||
"""Evaluates the closeness of datetime values within a specified tolerance."""
|
||||
|
||||
# Attempt to parse expected and actual datetime strings
|
||||
try:
|
||||
expected_dt = parser.parse(expected)
|
||||
actual_dt = parser.parse(actual)
|
||||
except (ValueError, TypeError):
|
||||
# If parsing fails, return score 0
|
||||
return {"match": False, "score": 0.0}
|
||||
|
||||
# Handle cases based on presence of tzinfo
|
||||
if expected_dt.tzinfo is None and actual_dt.tzinfo is None:
|
||||
# Both datetimes are naive, compare directly
|
||||
time_diff_seconds = abs((expected_dt - actual_dt).total_seconds())
|
||||
elif expected_dt.tzinfo is not None and actual_dt.tzinfo is not None:
|
||||
# Both datetimes have tzinfo, compare in UTC
|
||||
expected_utc = expected_dt.astimezone(pytz.utc)
|
||||
actual_utc = actual_dt.astimezone(pytz.utc)
|
||||
time_diff_seconds = abs((expected_utc - actual_utc).total_seconds())
|
||||
else:
|
||||
# One datetime has tzinfo and the other doesn't
|
||||
# Compare naive datetime with the other's naive equivalent
|
||||
if expected_dt.tzinfo is not None:
|
||||
expected_naive = expected_dt.replace(tzinfo=None)
|
||||
time_diff_seconds = abs((expected_naive - actual_dt).total_seconds())
|
||||
else:
|
||||
actual_naive = actual_dt.replace(tzinfo=None)
|
||||
time_diff_seconds = abs((expected_dt - actual_naive).total_seconds())
|
||||
|
||||
# Convert tolerances to seconds
|
||||
tolerance_seconds = self.tolerance.total_seconds()
|
||||
max_difference_seconds = self.max_difference.total_seconds()
|
||||
|
||||
if time_diff_seconds <= tolerance_seconds:
|
||||
# Full score if within tolerance
|
||||
return {"match": True, "score": self.weight}
|
||||
elif time_diff_seconds >= max_difference_seconds:
|
||||
# No score if beyond max_difference
|
||||
return {"match": False, "score": 0.0}
|
||||
else:
|
||||
# Partial score based on time difference
|
||||
ratio = 1 - (time_diff_seconds / max_difference_seconds)
|
||||
# Ensure ratio is not negative
|
||||
ratio = max(ratio, 0)
|
||||
score = self.weight * ratio
|
||||
return {"match": False, "score": score}
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from arcade.core.config_model import Config
|
||||
from arcade.core.schema import FullyQualifiedName
|
||||
from arcade.core.schema import TOOL_NAME_SEPARATOR
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
|
|
@ -218,7 +219,10 @@ class EvalCase:
|
|||
expected_count = len(self.expected_tool_calls)
|
||||
return self.rubric.fail_on_tool_call_quantity and expected_count != actual_count
|
||||
|
||||
def evaluate(self, actual_tool_calls: list[tuple[str, dict[str, Any]]]) -> EvaluationResult:
|
||||
def evaluate(
|
||||
self,
|
||||
actual_tool_calls: list[tuple[str, dict[str, Any]]],
|
||||
) -> EvaluationResult:
|
||||
"""
|
||||
Evaluate the actual tool calls against the expected tool calls and critics.
|
||||
|
||||
|
|
@ -229,13 +233,13 @@ class EvalCase:
|
|||
An EvaluationResult object containing the evaluation results.
|
||||
"""
|
||||
evaluation_result = EvaluationResult()
|
||||
actual_tools = [tool for tool, _ in actual_tool_calls]
|
||||
|
||||
actual_tools = [tool_name for tool_name, _ in actual_tool_calls]
|
||||
actual_count = len(actual_tool_calls)
|
||||
|
||||
if self.check_tool_call_quantity_failure(actual_count):
|
||||
evaluation_result.score = 0.0
|
||||
evaluation_result.passed = False
|
||||
evaluation_result.warning = False
|
||||
expected_count = len(self.expected_tool_calls)
|
||||
expected_tool_names = ", ".join(
|
||||
tool_call.name for tool_call in self.expected_tool_calls
|
||||
|
|
@ -246,35 +250,27 @@ class EvalCase:
|
|||
)
|
||||
return evaluation_result
|
||||
|
||||
# check if no tools should be called and none were called
|
||||
if not self.expected_tool_calls and not actual_tools:
|
||||
evaluation_result.score = 1.0
|
||||
evaluation_result.passed = True
|
||||
evaluation_result.warning = False
|
||||
return evaluation_result
|
||||
|
||||
if self.check_tool_selection_failure(actual_tools):
|
||||
evaluation_result.score = 0.0
|
||||
evaluation_result.passed = False
|
||||
evaluation_result.warning = False
|
||||
expected_tools = [tc.name for tc in self.expected_tool_calls]
|
||||
evaluation_result.failure_reason = f"Tool selection mismatch. Expected tools: {expected_tools}, but got: {actual_tools}"
|
||||
return evaluation_result
|
||||
|
||||
# if no critics for tool call arguments, then return
|
||||
# passing score as only tool selection and quantity is checked
|
||||
if not self.critics or len(self.critics) == 0:
|
||||
if not self.critics:
|
||||
evaluation_result.score = 1.0
|
||||
evaluation_result.passed = True
|
||||
evaluation_result.warning = False
|
||||
# TODO passing reason should be added
|
||||
return evaluation_result
|
||||
|
||||
# Create a cost matrix for the assignment problem
|
||||
cost_matrix = self._create_cost_matrix(actual_tool_calls)
|
||||
cost_matrix = self._create_cost_matrix(actual_tool_calls, self.expected_tool_calls)
|
||||
|
||||
# Use the Linear Sum Assignment (LSA) algorithm to find the optimal assignment
|
||||
# The algorithm maximizes the total score of the assignment
|
||||
# Use the Linear Sum Assignment algorithm to find the optimal assignment
|
||||
row_ind, col_ind = linear_sum_assignment(cost_matrix, maximize=True)
|
||||
|
||||
total_score = 0.0
|
||||
|
|
@ -283,10 +279,11 @@ class EvalCase:
|
|||
for i, j in zip(row_ind, col_ind):
|
||||
if i < len(self.expected_tool_calls) and j < len(actual_tool_calls):
|
||||
expected = self.expected_tool_calls[i]
|
||||
actual_tool, actual_args = actual_tool_calls[j]
|
||||
actual_name, actual_args = actual_tool_calls[j]
|
||||
|
||||
# Tool selection
|
||||
tool_selection_score = evaluation_result.score_tool_selection(
|
||||
expected.name, actual_tool, self.rubric.tool_selection_weight
|
||||
expected.name, actual_name, self.rubric.tool_selection_weight
|
||||
)
|
||||
total_score += tool_selection_score
|
||||
total_weight += self.rubric.tool_selection_weight
|
||||
|
|
@ -295,32 +292,35 @@ class EvalCase:
|
|||
for critic in self.critics:
|
||||
expected_value = expected.args.get(critic.critic_field)
|
||||
actual_value = actual_args.get(critic.critic_field)
|
||||
if expected_value is not None and actual_value is not None:
|
||||
try:
|
||||
result = critic.evaluate(expected_value, actual_value)
|
||||
total_score += result["score"]
|
||||
total_weight += critic.weight
|
||||
evaluation_result.add(
|
||||
critic.critic_field,
|
||||
result,
|
||||
critic.weight,
|
||||
expected_value,
|
||||
actual_value,
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Critic evaluation failed for field '{critic.critic_field}': {e}"
|
||||
)
|
||||
# Depending on requirements, you might want to continue or handle differently
|
||||
continue
|
||||
|
||||
# Compute the final score using the method from EvaluationResult
|
||||
try:
|
||||
result = critic.evaluate(expected_value, actual_value)
|
||||
total_score += result["score"]
|
||||
total_weight += critic.weight
|
||||
evaluation_result.add(
|
||||
critic.critic_field,
|
||||
result,
|
||||
critic.weight,
|
||||
expected_value,
|
||||
actual_value,
|
||||
)
|
||||
except Exception as e:
|
||||
# TODO: log or console
|
||||
print(f"Critic evaluation failed for field '{critic.critic_field}': {e}")
|
||||
evaluation_result.add(
|
||||
critic.critic_field,
|
||||
{"match": False, "score": 0.0},
|
||||
critic.weight,
|
||||
expected_value,
|
||||
actual_value,
|
||||
)
|
||||
continue
|
||||
|
||||
# Compute the final score
|
||||
evaluation_result.compute_final_score(total_weight)
|
||||
|
||||
# Set the pass/fail status based on the fail_threshold
|
||||
# Set pass/fail and warning status
|
||||
evaluation_result.passed = evaluation_result.score >= self.rubric.fail_threshold
|
||||
|
||||
# Set the warning status based on the warn_threshold
|
||||
evaluation_result.warning = (
|
||||
not evaluation_result.passed and evaluation_result.score >= self.rubric.warn_threshold
|
||||
)
|
||||
|
|
@ -328,103 +328,52 @@ class EvalCase:
|
|||
return evaluation_result
|
||||
|
||||
def _create_cost_matrix(
|
||||
self, actual_tool_calls: list[tuple[str, dict[str, Any]]]
|
||||
self,
|
||||
actual_tool_calls: list[tuple[str, dict[str, Any]]],
|
||||
expected_tool_calls: list[ExpectedToolCall],
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Create a cost matrix for the Hungarian algorithm.
|
||||
|
||||
This method computes the score for each possible pairing of expected and actual tool calls.
|
||||
The resulting matrix is used by the Hungarian algorithm to find the optimal assignment.
|
||||
Create a cost matrix for the assignment problem.
|
||||
|
||||
Args:
|
||||
actual_tool_calls: A list of tuples containing the actual tool calls and their arguments.
|
||||
actual_tool_calls: A list of tuples of actual tool calls.
|
||||
expected_tool_calls: A list of ExpectedToolCall instances.
|
||||
|
||||
Returns:
|
||||
A numpy array representing the cost matrix.
|
||||
"""
|
||||
num_expected = len(self.expected_tool_calls)
|
||||
num_expected = len(expected_tool_calls)
|
||||
num_actual = len(actual_tool_calls)
|
||||
n = max(num_expected, num_actual)
|
||||
|
||||
# Initialize a score matrix with zeros
|
||||
score_matrix = np.zeros((n, n))
|
||||
cost_matrix = np.zeros((n, n))
|
||||
|
||||
for i in range(n):
|
||||
for j in range(n):
|
||||
if i < num_expected and j < num_actual:
|
||||
expected = self.expected_tool_calls[i]
|
||||
expected_tool = expected.name
|
||||
expected_args = expected.args
|
||||
actual_tool, actual_args = actual_tool_calls[j]
|
||||
expected = expected_tool_calls[i]
|
||||
actual_name, actual_args = actual_tool_calls[j]
|
||||
score = 0.0
|
||||
|
||||
# Tool selection
|
||||
if compare_tool_name(expected_tool, actual_tool):
|
||||
if compare_tool_name(expected.name, actual_name):
|
||||
score += self.rubric.tool_selection_weight
|
||||
|
||||
# Critics evaluation
|
||||
if self.critics:
|
||||
for critic in self.critics:
|
||||
expected_value = expected_args.get(critic.critic_field)
|
||||
actual_value = actual_args.get(critic.critic_field)
|
||||
if expected_value is not None and actual_value is not None:
|
||||
try:
|
||||
result = critic.evaluate(expected_value, actual_value)
|
||||
score += result.get("score", 0.0)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Critic evaluation failed for field '{critic.critic_field}': {e}"
|
||||
)
|
||||
continue
|
||||
for critic in self.critics: # type: ignore[union-attr]
|
||||
expected_value = expected.args.get(critic.critic_field)
|
||||
actual_value = actual_args.get(critic.critic_field)
|
||||
if expected_value is not None and actual_value is not None:
|
||||
try:
|
||||
result = critic.evaluate(expected_value, actual_value)
|
||||
score += result.get("score", 0.0)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Critic evaluation failed for field '{critic.critic_field}': {e}"
|
||||
)
|
||||
cost_matrix[i, j] = score
|
||||
|
||||
score_matrix[i, j] = score
|
||||
else:
|
||||
# Assign a score of 0 for dummy assignments
|
||||
score_matrix[i, j] = 0.0
|
||||
|
||||
return score_matrix
|
||||
|
||||
async def run(
|
||||
self, client: AsyncArcade, model: str, tool_names: list[FullyQualifiedName]
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Run the evaluation case asynchronously.
|
||||
|
||||
Args:
|
||||
client: The AsyncArcade client instance.
|
||||
model: The model to evaluate.
|
||||
tool_names: The list of tool names to use for the evaluation.
|
||||
Returns:
|
||||
A dictionary containing the evaluation result for the case.
|
||||
"""
|
||||
messages = [{"role": "system", "content": self.system_message}]
|
||||
messages.extend(list(self.additional_messages))
|
||||
messages.append({"role": "user", "content": self.user_message})
|
||||
|
||||
response = await client.chat.completions.create( # type: ignore[call-overload]
|
||||
model=model,
|
||||
messages=messages,
|
||||
tool_choice="auto",
|
||||
tools=(str(name) for name in tool_names),
|
||||
user="eval_user",
|
||||
stream=False,
|
||||
)
|
||||
|
||||
predicted_args = get_tool_args(response)
|
||||
|
||||
evaluation = self.evaluate(predicted_args)
|
||||
|
||||
result = {
|
||||
"name": self.name,
|
||||
"input": self.user_message,
|
||||
"expected_tool_calls": [
|
||||
{"name": tc.name, "args": tc.args} for tc in self.expected_tool_calls
|
||||
],
|
||||
"predicted_tool_calls": [{"name": tool, "args": args} for tool, args in predicted_args],
|
||||
"evaluation": evaluation,
|
||||
}
|
||||
|
||||
return result
|
||||
return cost_matrix
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -467,19 +416,19 @@ class EvalSuite:
|
|||
Args:
|
||||
name: The name of the evaluation case.
|
||||
user_message: The user's input message.
|
||||
system_message: The system message to be sent to the AI model.
|
||||
expected_tool_calls: A list of expected tool calls.
|
||||
expected_tool_calls: A list of expected tool calls as tuples of (function, args).
|
||||
critics: List of critics to evaluate the tool arguments.
|
||||
system_message: The system message to be used.
|
||||
rubric: The evaluation rubric for this case.
|
||||
additional_messages: Optional list of additional messages for context.
|
||||
"""
|
||||
expected = [
|
||||
ExpectedToolCall(
|
||||
name=str(self.catalog.find_tool_by_func(func).get_fully_qualified_name()),
|
||||
args=args,
|
||||
)
|
||||
for func, args in expected_tool_calls
|
||||
]
|
||||
expected = []
|
||||
for func, args in expected_tool_calls:
|
||||
# Fill in default arguments here
|
||||
args_with_defaults = self._fill_args_with_defaults(func, args)
|
||||
tool_name = str(self.catalog.find_tool_by_func(func).get_fully_qualified_name())
|
||||
expected.append(ExpectedToolCall(name=tool_name, args=args_with_defaults))
|
||||
|
||||
case = EvalCase(
|
||||
name=name,
|
||||
system_message=system_message or self.system_message,
|
||||
|
|
@ -491,6 +440,30 @@ class EvalSuite:
|
|||
)
|
||||
self.cases.append(case)
|
||||
|
||||
def _fill_args_with_defaults(
|
||||
self, func: Callable, provided_args: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Fill in default arguments for a tool function.
|
||||
|
||||
Args:
|
||||
func: The tool function.
|
||||
provided_args: The provided arguments.
|
||||
|
||||
Returns:
|
||||
A dictionary with default arguments filled in.
|
||||
"""
|
||||
sig = inspect.signature(func)
|
||||
args_with_defaults = {}
|
||||
for param in sig.parameters.values():
|
||||
if param.name in provided_args:
|
||||
args_with_defaults[param.name] = provided_args[param.name]
|
||||
elif param.default is not inspect.Parameter.empty:
|
||||
args_with_defaults[param.name] = param.default
|
||||
else:
|
||||
args_with_defaults[param.name] = None # or raise an error
|
||||
return args_with_defaults
|
||||
|
||||
def extend_case(
|
||||
self,
|
||||
name: str,
|
||||
|
|
@ -528,13 +501,12 @@ class EvalSuite:
|
|||
|
||||
expected = last_case.expected_tool_calls
|
||||
if expected_tool_calls:
|
||||
expected = [
|
||||
ExpectedToolCall(
|
||||
name=str(self.catalog.find_tool_by_func(func).get_fully_qualified_name()),
|
||||
args=args,
|
||||
)
|
||||
for func, args in expected_tool_calls
|
||||
]
|
||||
expected = []
|
||||
for func, args in expected_tool_calls:
|
||||
# Fill in default arguments here
|
||||
args_with_defaults = self._fill_args_with_defaults(func, args)
|
||||
tool_name = str(self.catalog.find_tool_by_func(func).get_fully_qualified_name())
|
||||
expected.append(ExpectedToolCall(name=tool_name, args=args_with_defaults))
|
||||
|
||||
# Create a new case, copying from the last one and updating fields
|
||||
new_case = EvalCase(
|
||||
|
|
@ -550,9 +522,10 @@ class EvalSuite:
|
|||
|
||||
async def run(self, client: AsyncArcade, model: str) -> dict[str, Any]:
|
||||
"""
|
||||
Run the evaluation suite asynchronously.
|
||||
Run the evaluation suite.
|
||||
|
||||
Args:
|
||||
client: The AsyncArcade client instance.
|
||||
model: The model to evaluate.
|
||||
|
||||
Returns:
|
||||
|
|
@ -565,7 +538,48 @@ class EvalSuite:
|
|||
|
||||
async def sem_task(case: EvalCase) -> dict[str, Any]:
|
||||
async with semaphore:
|
||||
return await case.run(client, model, tool_names)
|
||||
# Prepare messages
|
||||
messages = [{"role": "system", "content": case.system_message}]
|
||||
messages.extend(case.additional_messages)
|
||||
messages.append({"role": "user", "content": case.user_message})
|
||||
|
||||
# Get the model response
|
||||
response = await client.chat.completions.create( # type: ignore[call-overload]
|
||||
model=model,
|
||||
messages=messages,
|
||||
tool_choice="auto",
|
||||
tools=(str(name) for name in tool_names),
|
||||
user="eval_user",
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Extract and fill default arguments for actual tool calls
|
||||
predicted_args = get_tool_args(response)
|
||||
filled_actual_tool_calls = []
|
||||
for tool_name, args in predicted_args:
|
||||
tool = self.catalog.get_tool_by_name(tool_name)
|
||||
if tool is None:
|
||||
raise ValueError(f"Tool '{tool_name}' not found in catalog.")
|
||||
func = tool.tool
|
||||
args_with_defaults = self._fill_args_with_defaults(func, args)
|
||||
filled_actual_tool_calls.append((tool_name, args_with_defaults))
|
||||
|
||||
# Evaluate the case
|
||||
evaluation = case.evaluate(filled_actual_tool_calls)
|
||||
|
||||
# Prepare the result
|
||||
result = {
|
||||
"name": case.name,
|
||||
"input": case.user_message,
|
||||
"expected_tool_calls": [
|
||||
{"name": tc.name, "args": tc.args} for tc in case.expected_tool_calls
|
||||
],
|
||||
"predicted_tool_calls": [
|
||||
{"name": name, "args": args} for name, args in filled_actual_tool_calls
|
||||
],
|
||||
"evaluation": evaluation,
|
||||
}
|
||||
return result
|
||||
|
||||
tasks = [sem_task(case) for case in self.cases]
|
||||
case_results = await asyncio.gather(*tasks)
|
||||
|
|
@ -589,7 +603,7 @@ def get_tool_args(chat_completion: Any) -> list[tuple[str, dict[str, Any]]]:
|
|||
if message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
tool_args_list.append((
|
||||
tool_call.function.name,
|
||||
normalize_name(tool_call.function.name),
|
||||
json.loads(tool_call.function.arguments),
|
||||
))
|
||||
return tool_args_list
|
||||
|
|
@ -597,17 +611,31 @@ def get_tool_args(chat_completion: Any) -> list[tuple[str, dict[str, Any]]]:
|
|||
|
||||
def compare_tool_name(expected: str, actual: str) -> bool:
|
||||
"""
|
||||
Compare the tool name without penalizing for mismatch in separators
|
||||
between module names and tool names ex. '-' vs '_' vs '.' vs ' '
|
||||
"""
|
||||
# TODO optimize this
|
||||
# Remove all separators from both names
|
||||
separators = "-_."
|
||||
expected_clean = "".join(char for char in expected if char not in separators)
|
||||
actual_clean = "".join(char for char in actual if char not in separators)
|
||||
Compare the tool names by replacing all separators with the TOOL_NAME_SEPARATOR
|
||||
and comparing the normalized names.
|
||||
|
||||
# Compare the cleaned names
|
||||
return expected_clean.lower() == actual_clean.lower()
|
||||
Converts names like 'Google_ListEmails' to 'Google.ListEmails' if
|
||||
TOOL_NAME_SEPARATOR is '.'.
|
||||
|
||||
Args:
|
||||
expected: The expected tool name.
|
||||
actual: The actual tool name.
|
||||
|
||||
Returns:
|
||||
True if the normalized tool names match, False otherwise.
|
||||
"""
|
||||
separators = "-_."
|
||||
expected_normalized = normalize_name(expected, separators)
|
||||
actual_normalized = normalize_name(actual, separators)
|
||||
|
||||
return expected_normalized.lower() == actual_normalized.lower()
|
||||
|
||||
|
||||
def normalize_name(name: str, separators: str = "-_.") -> str:
|
||||
for sep in separators:
|
||||
if sep != TOOL_NAME_SEPARATOR:
|
||||
name = name.replace(sep, TOOL_NAME_SEPARATOR)
|
||||
return name
|
||||
|
||||
|
||||
def tool_eval() -> Callable[[Callable], Callable]:
|
||||
|
|
|
|||
|
|
@ -7,3 +7,5 @@ coverage:
|
|||
default:
|
||||
target: 90%
|
||||
threshold: 0.5%
|
||||
exclude:
|
||||
- arcade/cli/**
|
||||
|
|
|
|||
|
|
@ -30,10 +30,12 @@ uvicorn = {version = "^0.30.0", optional = true}
|
|||
scipy = {version = "^1.14.0", optional = true}
|
||||
numpy = {version = "^2.0.0", optional = true}
|
||||
scikit-learn = {version = "^1.5.0", optional = true}
|
||||
pytz = {version = "^2024.1", optional = true}
|
||||
python-dateutil = {version = "^2.8.2", optional = true}
|
||||
|
||||
[tool.poetry.extras]
|
||||
fastapi = ["fastapi", "uvicorn", "opentelemetry-instrumentation-fastapi", "opentelemetry-exporter-otlp-proto-http", "opentelemetry-exporter-otlp-proto-common"]
|
||||
evals = ["scipy", "numpy", "scikit-learn"]
|
||||
evals = ["scipy", "numpy", "scikit-learn", "pytz", "python-dateutil"]
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest = "^8.1.2"
|
||||
|
|
@ -43,6 +45,8 @@ pre-commit = "^3.4.0"
|
|||
tox = "^4.11.1"
|
||||
pytest-asyncio = "^0.23.7"
|
||||
types-toml = "^0.10.8"
|
||||
types-pytz = "^2024.1"
|
||||
types-python-dateutil = "^2.8.2"
|
||||
poetry-plugin-export = "^1.7.0"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
|
|
@ -63,9 +67,11 @@ ignore_missing_imports = "True"
|
|||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
|
||||
[tool.coverage.report]
|
||||
skip_empty = true
|
||||
|
||||
[tool.coverage.run]
|
||||
branch = true
|
||||
source = ["arcade"]
|
||||
omit = ["arcade/cli/*"]
|
||||
|
||||
[tool.coverage.report]
|
||||
skip_empty = true
|
||||
|
|
|
|||
|
|
@ -373,3 +373,24 @@ async def test_async_arcade_health_check_raises_error(
|
|||
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
|
||||
with pytest.raises(EngineNotHealthyError):
|
||||
await test_async_client.health.check()
|
||||
|
||||
|
||||
def test_arcade_tool_list_tools(test_sync_client, mock_response, monkeypatch):
|
||||
"""Test Arcade.tools.list_tools method."""
|
||||
data = [TOOL_DEFINITION_DATA]
|
||||
monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: data)
|
||||
tool_definitions = test_sync_client.tools.list_tools(toolkit="TestToolkit")
|
||||
assert tool_definitions == [ToolDefinition(**TOOL_DEFINITION_DATA)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_arcade_tool_list_tools(test_async_client, mock_async_response, monkeypatch):
|
||||
"""Test AsyncArcade.tools.list_tools method."""
|
||||
data = [TOOL_DEFINITION_DATA]
|
||||
|
||||
async def mock_execute_request(*args, **kwargs):
|
||||
return data
|
||||
|
||||
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
|
||||
tool_definitions = await test_async_client.tools.list_tools(toolkit="TestToolkit")
|
||||
assert tool_definitions == [ToolDefinition(**TOOL_DEFINITION_DATA)]
|
||||
|
|
|
|||
|
|
@ -68,7 +68,6 @@ def test_get_tool(toolkit_version: str | None, expected_tool):
|
|||
name="SampleTool", toolkit_name="SampleToolkit", toolkit_version=toolkit_version
|
||||
)
|
||||
tool = catalog.get_tool(fq_name)
|
||||
|
||||
assert tool.tool == expected_tool
|
||||
|
||||
|
||||
|
|
@ -102,3 +101,38 @@ def test_add_toolkit_type_error():
|
|||
assert "Type error encountered while adding tool invalid_tool from mock_module" in str(
|
||||
exc_info.value
|
||||
)
|
||||
|
||||
|
||||
def test_get_tool_by_name():
|
||||
catalog = ToolCatalog()
|
||||
catalog.add_tool(sample_tool, "sample_toolkit")
|
||||
|
||||
tool = catalog.get_tool_by_name("SampleToolkit.SampleTool")
|
||||
assert tool.tool == sample_tool
|
||||
assert tool.name == "SampleTool"
|
||||
assert tool.meta.toolkit == "sample_toolkit"
|
||||
assert tool.version is None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
catalog.get_tool_by_name("nonexistent_toolkit.SampleTool")
|
||||
|
||||
|
||||
def test_get_tool_by_name_with_version():
|
||||
catalog = ToolCatalog()
|
||||
catalog.add_tool(sample_tool, "sample_toolkit")
|
||||
|
||||
tool = catalog.get_tool_by_name("SampleToolkit.SampleTool")
|
||||
assert tool.tool == sample_tool
|
||||
assert tool.name == "SampleTool"
|
||||
assert tool.meta.toolkit == "sample_toolkit"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
catalog.get_tool_by_name("SampleToolkit.SampleTool", version="2.0.0")
|
||||
|
||||
|
||||
def test_get_tool_by_name_with_invalid_version():
|
||||
catalog = ToolCatalog()
|
||||
catalog.add_tool(sample_tool, "SampleToolkit")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
catalog.get_tool_by_name("SampleToolkit.SampleTool", version="2.0.0")
|
||||
|
|
|
|||
|
|
@ -1,8 +1,13 @@
|
|||
from datetime import timedelta
|
||||
|
||||
import pytest
|
||||
import pytz
|
||||
from dateutil import parser
|
||||
|
||||
from arcade.sdk.error import WeightError
|
||||
from arcade.sdk.eval import (
|
||||
BinaryCritic,
|
||||
DatetimeCritic,
|
||||
EvalRubric,
|
||||
ExpectedToolCall,
|
||||
NumericCritic,
|
||||
|
|
@ -255,19 +260,13 @@ def test_eval_case_multiple_critics():
|
|||
# Test EvalCase with missing expected and actual values in args
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"expected_args, actual_args, expected_score",
|
||||
[
|
||||
({"param": "value"}, {}, 1.0), # Missing actual value
|
||||
({}, {"param": "value"}, 1.0), # Missing expected value
|
||||
({"param": "value"}, {"param": "value"}, 2.0), # Both values present
|
||||
],
|
||||
)
|
||||
def test_eval_case_missing_values(expected_args, actual_args, expected_score):
|
||||
def test_eval_case_with_none_values():
|
||||
"""
|
||||
Test that when either expected or actual values are missing for a critic,
|
||||
the critic evaluation is skipped, and the total score is computed accordingly.
|
||||
Test that when expected or actual values are None, the critic evaluates them appropriately.
|
||||
"""
|
||||
expected_args = {"param": None}
|
||||
actual_args = {"param": None}
|
||||
|
||||
expected_tool_calls = [ExpectedToolCall(name="ToolA", args=expected_args)]
|
||||
actual_tool_calls = [("ToolA", actual_args)]
|
||||
|
||||
|
|
@ -284,15 +283,8 @@ def test_eval_case_missing_values(expected_args, actual_args, expected_score):
|
|||
|
||||
result = case.evaluate(actual_tool_calls)
|
||||
|
||||
# If critic is skipped, only tool selection score is counted
|
||||
# Otherwise, tool selection + critic score
|
||||
total_weight = 1.0 # At least tool selection weight
|
||||
if "param" in expected_args and "param" in actual_args:
|
||||
total_weight += 1.0 # Critic weight
|
||||
|
||||
expected_total_score = expected_score / total_weight
|
||||
|
||||
assert result.score == expected_total_score
|
||||
# Both values are None, so the critic should return a match
|
||||
assert result.score == 2.0 / 2.0 # Full score (tool selection + critic score)
|
||||
|
||||
|
||||
# Test that WeightError is raised for invalid critic weights
|
||||
|
|
@ -340,3 +332,136 @@ def test_similarity_critic_unsupported_metric():
|
|||
"""
|
||||
with pytest.raises(ValueError):
|
||||
SimilarityCritic(critic_field="text", weight=1.0, metric="unsupported_metric")
|
||||
|
||||
|
||||
# Test DatetimeCritic
|
||||
|
||||
|
||||
# Parameterized tests for DatetimeCritic with various datetime formats and default timezones
|
||||
@pytest.mark.parametrize(
|
||||
"critic_params, expected, actual, expected_match, expected_score",
|
||||
[
|
||||
# Test with time component and timezone
|
||||
(
|
||||
{"critic_field": "start_datetime", "weight": 1.0},
|
||||
"2024-09-26T12:00:00-07:00",
|
||||
"2024-09-26T12:00:00-07:00",
|
||||
True,
|
||||
1.0,
|
||||
),
|
||||
# Test without time component (dates only)
|
||||
(
|
||||
{"critic_field": "start_datetime", "weight": 1.0},
|
||||
"2024-09-26",
|
||||
"2024-09-26",
|
||||
True,
|
||||
1.0,
|
||||
),
|
||||
# Test with and without timezone (assumes UTC)
|
||||
(
|
||||
{"critic_field": "start_datetime", "weight": 1.0},
|
||||
"2024-09-26T12:00:00Z",
|
||||
"2024-09-26T12:00:00",
|
||||
True,
|
||||
1.0,
|
||||
),
|
||||
# Test naive datetimes
|
||||
(
|
||||
{"critic_field": "start_datetime", "weight": 1.0},
|
||||
"2024-09-26T12:00:00",
|
||||
"2024-09-26T12:00:00",
|
||||
True,
|
||||
1.0,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_datetime_critic_basic(critic_params, expected, actual, expected_match, expected_score):
|
||||
"""
|
||||
Test DatetimeCritic with various datetime formats and default timezones.
|
||||
"""
|
||||
critic = DatetimeCritic(**critic_params)
|
||||
result = critic.evaluate(expected, actual)
|
||||
assert result["match"] == expected_match
|
||||
assert result["score"] == expected_score
|
||||
|
||||
|
||||
# Parameterized tests for DatetimeCritic's handling of tolerances and max differences
|
||||
@pytest.mark.parametrize(
|
||||
"critic_params, expected, actual, expected_match, expected_score_func",
|
||||
[
|
||||
# Test time difference within tolerance
|
||||
(
|
||||
{"critic_field": "start_datetime", "weight": 1.0, "tolerance": timedelta(seconds=60)},
|
||||
"2024-09-26T12:00:00",
|
||||
"2024-09-26T12:00:30",
|
||||
True,
|
||||
lambda critic: critic.weight,
|
||||
),
|
||||
# Test time difference outside tolerance but within max_difference
|
||||
(
|
||||
{
|
||||
"critic_field": "start_datetime",
|
||||
"weight": 1.0,
|
||||
"tolerance": timedelta(seconds=60),
|
||||
"max_difference": timedelta(minutes=5),
|
||||
},
|
||||
"2024-09-26T12:00:00",
|
||||
"2024-09-26T12:04:00",
|
||||
False,
|
||||
lambda critic: critic.weight * (1 - (240 / 300)),
|
||||
),
|
||||
# Test time difference exceeds max_difference
|
||||
(
|
||||
{
|
||||
"critic_field": "start_datetime",
|
||||
"weight": 1.0,
|
||||
"max_difference": timedelta(minutes=5),
|
||||
},
|
||||
"2024-09-26T12:00:00",
|
||||
"2024-09-26T12:10:00",
|
||||
False,
|
||||
lambda critic: 0.0,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_datetime_critic_tolerances(
|
||||
critic_params, expected, actual, expected_match, expected_score_func
|
||||
):
|
||||
"""
|
||||
Test DatetimeCritic's handling of tolerances and max differences.
|
||||
"""
|
||||
critic = DatetimeCritic(**critic_params)
|
||||
result = critic.evaluate(expected, actual)
|
||||
assert result["match"] == expected_match
|
||||
expected_score = expected_score_func(critic)
|
||||
assert pytest.approx(result["score"], abs=1e-6) == expected_score
|
||||
|
||||
|
||||
def test_datetime_critic_naive_and_timezone_aware():
|
||||
"""
|
||||
Test DatetimeCritic when comparing naive and timezone-aware datetimes.
|
||||
"""
|
||||
critic = DatetimeCritic(critic_field="start_datetime", weight=1.0)
|
||||
expected = "2024-09-26T12:00:00Z"
|
||||
actual = "2024-09-26T07:00:00"
|
||||
result = critic.evaluate(expected, actual)
|
||||
assert result["match"] is False
|
||||
|
||||
# Compute expected score based on time difference
|
||||
expected_dt = parser.parse(expected)
|
||||
actual_dt = parser.parse(actual)
|
||||
if actual_dt.tzinfo is None:
|
||||
actual_dt = pytz.utc.localize(actual_dt)
|
||||
if expected_dt.tzinfo is None:
|
||||
expected_dt = pytz.utc.localize(expected_dt)
|
||||
|
||||
time_diff_seconds = abs((expected_dt - actual_dt).total_seconds())
|
||||
if time_diff_seconds <= critic.tolerance.total_seconds():
|
||||
expected_score = critic.weight
|
||||
elif time_diff_seconds >= critic.max_difference.total_seconds():
|
||||
expected_score = 0.0
|
||||
else:
|
||||
ratio = 1 - (time_diff_seconds / critic.max_difference.total_seconds())
|
||||
expected_score = critic.weight * ratio
|
||||
|
||||
assert pytest.approx(result["score"], abs=1e-6) == expected_score
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ async def test_error_responses(
|
|||
if status_code == 422:
|
||||
await list_org_repositories(mock_context, "org", repo_type=RepoType.ALL)
|
||||
elif status_code == 301:
|
||||
await count_stargazers("owner", "repo")
|
||||
await count_stargazers(mock_context, "owner", "repo")
|
||||
elif status_code == 404:
|
||||
await list_org_repositories(mock_context, "non_existent_org")
|
||||
elif status_code == 503:
|
||||
|
|
@ -66,8 +66,8 @@ async def test_list_repository_activities_invalid_cursor(mock_context, mock_clie
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_stargazers_success(mock_client):
|
||||
async def test_count_stargazers_success(mock_context, mock_client):
|
||||
mock_client.get.return_value = Response(200, json={"stargazers_count": 42})
|
||||
|
||||
result = await count_stargazers("owner", "repo")
|
||||
assert result == "The repository owner/repo has 42 stargazers."
|
||||
result = await count_stargazers(mock_context, "owner", "repo")
|
||||
assert result == 42
|
||||
|
|
|
|||
|
|
@ -1,8 +1,12 @@
|
|||
import arcade_github
|
||||
from arcade_github.tools.models import DiffSide, ReviewCommentSubjectType # Add these imports
|
||||
from arcade_github.tools.models import (
|
||||
DiffSide,
|
||||
ReviewCommentSubjectType,
|
||||
SortDirection,
|
||||
)
|
||||
from arcade_github.tools.pull_requests import (
|
||||
create_reply_for_review_comment,
|
||||
create_review_comment, # Add this import
|
||||
create_review_comment,
|
||||
get_pull_request,
|
||||
list_pull_request_commits,
|
||||
list_pull_requests,
|
||||
|
|
@ -169,7 +173,7 @@ def github_pull_requests_eval_suite() -> EvalSuite:
|
|||
"repo": "test",
|
||||
"pull_number": 72,
|
||||
"sort": "updated",
|
||||
"direction": "asc",
|
||||
"direction": SortDirection.ASC,
|
||||
},
|
||||
)
|
||||
],
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import arcade_github
|
||||
from arcade_github.tools.models import SortDirection
|
||||
from arcade_github.tools.repositories import (
|
||||
count_stargazers,
|
||||
get_repository,
|
||||
|
|
@ -66,7 +67,7 @@ def github_repositories_eval_suite() -> EvalSuite:
|
|||
"org": "ArcadeAI",
|
||||
"repo_type": "all",
|
||||
"sort": "created",
|
||||
"sort_direction": "desc",
|
||||
"sort_direction": SortDirection.DESC,
|
||||
},
|
||||
)
|
||||
],
|
||||
|
|
@ -108,7 +109,7 @@ def github_repositories_eval_suite() -> EvalSuite:
|
|||
{
|
||||
"owner": "ArcadeAI",
|
||||
"repo": "test",
|
||||
"direction": "desc",
|
||||
"direction": SortDirection.DESC,
|
||||
"per_page": 30,
|
||||
"actor": "TestUser",
|
||||
"time_period": "month",
|
||||
|
|
@ -138,7 +139,7 @@ def github_repositories_eval_suite() -> EvalSuite:
|
|||
"owner": "ArcadeAI",
|
||||
"repo": "test",
|
||||
"sort": "created",
|
||||
"direction": "desc",
|
||||
"direction": SortDirection.DESC,
|
||||
"per_page": 30,
|
||||
"page": 1,
|
||||
"include_extra_data": False,
|
||||
|
|
|
|||
|
|
@ -1,17 +1,16 @@
|
|||
from datetime import datetime, timedelta
|
||||
from typing import Annotated
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from arcade.core.errors import RetryableToolError, ToolExecutionError
|
||||
from arcade.core.errors import RetryableToolError
|
||||
from arcade.core.schema import ToolContext
|
||||
from arcade.sdk import tool
|
||||
from arcade.sdk.auth import Google
|
||||
from arcade_google.tools.models import Day, EventVisibility, SendUpdatesOptions, TimeSlot
|
||||
from arcade_google.tools.utils import _update_datetime
|
||||
from arcade_google.tools.models import EventVisibility, SendUpdatesOptions
|
||||
from arcade_google.tools.utils import parse_datetime
|
||||
|
||||
|
||||
@tool(
|
||||
|
|
@ -25,97 +24,87 @@ from arcade_google.tools.utils import _update_datetime
|
|||
async def create_event(
|
||||
context: ToolContext,
|
||||
summary: Annotated[str, "The title of the event"],
|
||||
start_date: Annotated[Day, "The day that the event starts"],
|
||||
start_time: Annotated[TimeSlot, "The time of the day that the event starts"],
|
||||
end_date: Annotated[Day, "The day that the event ends"],
|
||||
end_time: Annotated[TimeSlot, "The time of the day that the event ends"],
|
||||
start_datetime: Annotated[
|
||||
str,
|
||||
"The datetime when the event starts in ISO 8601 format, e.g., '2024-12-31T15:30:00'.",
|
||||
],
|
||||
end_datetime: Annotated[
|
||||
str,
|
||||
"The datetime when the event ends in ISO 8601 format, e.g., '2024-12-31T17:30:00'.",
|
||||
],
|
||||
calendar_id: Annotated[
|
||||
str, "The ID of the calendar to create the event in, usually 'primary'"
|
||||
str, "The ID of the calendar to create the event in, usually 'primary'."
|
||||
] = "primary",
|
||||
description: Annotated[str | None, "The description of the event"] = None,
|
||||
location: Annotated[str | None, "The location of the event"] = None,
|
||||
visibility: Annotated[EventVisibility, "The visibility of the event"] = EventVisibility.DEFAULT,
|
||||
attendee_emails: Annotated[
|
||||
list[str] | None,
|
||||
"The list of attendee emails. Must be valid email addresses e.g., username@domain.com",
|
||||
"The list of attendee emails. Must be valid email addresses e.g., username@domain.com.",
|
||||
] = None,
|
||||
) -> Annotated[dict, "A dictionary containing the created event details"]:
|
||||
"""Create a new event/meeting/sync/meetup in the specified calendar."""
|
||||
|
||||
service = build("calendar", "v3", credentials=Credentials(context.authorization.token))
|
||||
|
||||
try:
|
||||
# Get the calendar's time zone
|
||||
calendar = service.calendars().get(calendarId=calendar_id).execute()
|
||||
time_zone = calendar["timeZone"]
|
||||
# Get the calendar's time zone
|
||||
calendar = service.calendars().get(calendarId=calendar_id).execute()
|
||||
time_zone = calendar["timeZone"]
|
||||
|
||||
# Convert enum values to datetime objects
|
||||
start_datetime = datetime.combine(start_date.to_date(time_zone), start_time.to_time())
|
||||
end_datetime = datetime.combine(end_date.to_date(time_zone), end_time.to_time())
|
||||
# Parse datetime strings
|
||||
start_dt = parse_datetime(start_datetime, time_zone)
|
||||
end_dt = parse_datetime(end_datetime, time_zone)
|
||||
|
||||
event = {
|
||||
"summary": summary,
|
||||
"description": description,
|
||||
"location": location,
|
||||
"start": {"dateTime": start_datetime.isoformat(), "timeZone": time_zone},
|
||||
"end": {"dateTime": end_datetime.isoformat(), "timeZone": time_zone},
|
||||
"visibility": visibility.value,
|
||||
}
|
||||
event = {
|
||||
"summary": summary,
|
||||
"description": description,
|
||||
"location": location,
|
||||
"start": {"dateTime": start_dt.isoformat(), "timeZone": time_zone},
|
||||
"end": {"dateTime": end_dt.isoformat(), "timeZone": time_zone},
|
||||
"visibility": visibility.value,
|
||||
}
|
||||
|
||||
if attendee_emails:
|
||||
event["attendees"] = [{"email": email} for email in attendee_emails]
|
||||
if attendee_emails:
|
||||
event["attendees"] = [{"email": email} for email in attendee_emails]
|
||||
|
||||
created_event = service.events().insert(calendarId=calendar_id, body=event).execute()
|
||||
|
||||
except HttpError as e:
|
||||
raise ToolExecutionError(
|
||||
f"HttpError during execution of '{create_event.__name__}' tool.", str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolExecutionError(
|
||||
f"Unexpected Error encountered during execution of '{create_event.__name__}' tool.",
|
||||
str(e),
|
||||
)
|
||||
else:
|
||||
return {"event": created_event}
|
||||
created_event = service.events().insert(calendarId=calendar_id, body=event).execute()
|
||||
return {"event": created_event}
|
||||
|
||||
|
||||
@tool(
|
||||
requires_auth=Google(
|
||||
scopes=["https://www.googleapis.com/auth/calendar.events.readonly"],
|
||||
scopes=[
|
||||
"https://www.googleapis.com/auth/calendar.readonly",
|
||||
"https://www.googleapis.com/auth/calendar.events",
|
||||
],
|
||||
)
|
||||
)
|
||||
async def list_events(
|
||||
context: ToolContext,
|
||||
min_day: Annotated[
|
||||
Day, "Filter by events that end on or after this day. Combined with min_time_slot"
|
||||
min_end_datetime: Annotated[
|
||||
str,
|
||||
"Filter by events that end on or after this datetime in ISO 8601 format, e.g., '2024-09-15T09:00:00'.",
|
||||
],
|
||||
min_time_slot: Annotated[
|
||||
TimeSlot, "Filter by events that end after this time. Combined with min_day"
|
||||
],
|
||||
max_day: Annotated[
|
||||
Day, "Filter by events that start on or before this day. Combined with max_time_slot"
|
||||
],
|
||||
max_time_slot: Annotated[
|
||||
TimeSlot, "Filter by events that start before this time. Combined with max_day"
|
||||
max_start_datetime: Annotated[
|
||||
str,
|
||||
"Filter by events that start before this datetime in ISO 8601 format, e.g., '2024-09-16T17:00:00'.",
|
||||
],
|
||||
calendar_id: Annotated[str, "The ID of the calendar to list events from"] = "primary",
|
||||
max_results: Annotated[int, "The maximum number of events to return"] = 10,
|
||||
) -> Annotated[dict, "A dictionary containing the list of events"]:
|
||||
"""
|
||||
List events from the specified calendar within the given date range.
|
||||
List events from the specified calendar within the given datetime range.
|
||||
|
||||
min_day and min_time_slot are combined to form the lower bound (exclusive) for an event's end time to filter by
|
||||
max_day and max_time_slot are combined to form the upper bound (exclusive) for an event's start time to filter by
|
||||
min_end_datetime serves as the lower bound (exclusive) for an event's end time.
|
||||
max_start_datetime serves as the upper bound (exclusive) for an event's start time.
|
||||
|
||||
For example:
|
||||
If min_day is set to Day.TODAY and min_time_slot is set to TimeSlot._09:00,
|
||||
and max_day is set to Day.TOMORROW and max_time_slot is set to TimeSlot._17:00,
|
||||
If min_end_datetime is set to 2024-09-15T09:00:00 and max_start_datetime is set to 2024-09-16T17:00:00,
|
||||
the function will return events that:
|
||||
1. End after 09:00 today (exclusive)
|
||||
2. Start before 17:00 tomorrow (exclusive)
|
||||
This means an event starting at 08:00 today and ending at 10:00 today would be included,
|
||||
but an event starting at 17:00 tomorrow would not be included.
|
||||
1. End after 09:00 on September 15, 2024 (exclusive)
|
||||
2. Start before 17:00 on September 16, 2024 (exclusive)
|
||||
This means an event starting at 08:00 on September 15 and ending at 10:00 on September 15 would be included,
|
||||
but an event starting at 17:00 on September 16 would not be included.
|
||||
"""
|
||||
service = build("calendar", "v3", credentials=Credentials(context.authorization.token))
|
||||
|
||||
|
|
@ -123,23 +112,19 @@ async def list_events(
|
|||
calendar = service.calendars().get(calendarId=calendar_id).execute()
|
||||
time_zone = calendar["timeZone"]
|
||||
|
||||
# Convert enum values to datetime with timezone offset
|
||||
start_datetime = datetime.combine(
|
||||
min_day.to_date(time_zone), min_time_slot.to_time()
|
||||
).astimezone(ZoneInfo(time_zone))
|
||||
end_datetime = datetime.combine(max_day.to_date(time_zone), max_time_slot.to_time()).astimezone(
|
||||
ZoneInfo(time_zone)
|
||||
)
|
||||
# Parse datetime strings
|
||||
min_end_dt = parse_datetime(min_end_datetime, time_zone)
|
||||
max_start_dt = parse_datetime(max_start_datetime, time_zone)
|
||||
|
||||
if start_datetime > end_datetime:
|
||||
start_datetime, end_datetime = end_datetime, start_datetime
|
||||
if min_end_dt > max_start_dt:
|
||||
min_end_dt, max_start_dt = max_start_dt, min_end_dt
|
||||
|
||||
events_result = (
|
||||
service.events()
|
||||
.list(
|
||||
calendarId=calendar_id,
|
||||
timeMin=start_datetime.isoformat(),
|
||||
timeMax=end_datetime.isoformat(),
|
||||
timeMin=min_end_dt.isoformat(),
|
||||
timeMax=max_start_dt.isoformat(),
|
||||
maxResults=max_results,
|
||||
singleEvents=True,
|
||||
orderBy="startTime",
|
||||
|
|
@ -179,21 +164,16 @@ async def list_events(
|
|||
async def update_event(
|
||||
context: ToolContext,
|
||||
event_id: Annotated[str, "The ID of the event to update"],
|
||||
updated_start_day: Annotated[
|
||||
Day | None,
|
||||
"The updated day that the event starts. Combined with updated_start_time to form the new start time",
|
||||
updated_start_datetime: Annotated[
|
||||
str | None,
|
||||
"The updated datetime that the event starts in ISO 8601 format, e.g., '2024-12-31T15:30:00'.",
|
||||
] = None,
|
||||
updated_start_time: Annotated[
|
||||
TimeSlot | None,
|
||||
"The updated time that the event starts. Combined with updated_start_day to form the new start time",
|
||||
updated_end_datetime: Annotated[
|
||||
str | None,
|
||||
"The updated datetime that the event ends in ISO 8601 format, e.g., '2024-12-31T17:30:00'.",
|
||||
] = None,
|
||||
updated_end_day: Annotated[
|
||||
Day | None,
|
||||
"The updated day that the event ends. Combined with updated_end_time to form the new end time",
|
||||
] = None,
|
||||
updated_end_time: Annotated[TimeSlot | None, "The updated time that the event ends"] = None,
|
||||
updated_calendar_id: Annotated[
|
||||
str | None, "The updated ID of the calendar containing the event"
|
||||
str | None, "The updated ID of the calendar containing the event."
|
||||
] = None,
|
||||
updated_summary: Annotated[str | None, "The updated title of the event"] = None,
|
||||
updated_description: Annotated[str | None, "The updated description of the event"] = None,
|
||||
|
|
@ -201,25 +181,24 @@ async def update_event(
|
|||
updated_visibility: Annotated[EventVisibility | None, "The visibility of the event"] = None,
|
||||
attendee_emails_to_add: Annotated[
|
||||
list[str] | None,
|
||||
"The list of updated attendee emails to add. Must be valid email addresses e.g., username@domain.com",
|
||||
"The list of attendee emails to add. Must be valid email addresses e.g., username@domain.com.",
|
||||
] = None,
|
||||
attendee_emails_to_remove: Annotated[
|
||||
list[str] | None,
|
||||
"The list of attendee emails to remove. Must be valid email addresses e.g., username@domain.com",
|
||||
"The list of attendee emails to remove. Must be valid email addresses e.g., username@domain.com.",
|
||||
] = None,
|
||||
send_updates: Annotated[
|
||||
SendUpdatesOptions, "Guests who should receive notifications about the event update"
|
||||
SendUpdatesOptions, "Should attendees be notified of the update? (none, all, external_only)"
|
||||
] = SendUpdatesOptions.ALL,
|
||||
) -> Annotated[
|
||||
str,
|
||||
"A string containing the updated event details, including the event ID, update timestamp, and a link to view the updated event",
|
||||
"A string containing the updated event details, including the event ID, update timestamp, and a link to view the updated event.",
|
||||
]:
|
||||
"""
|
||||
Update an existing event in the specified calendar with the provided details.
|
||||
Only the provided fields will be updated; others will remain unchanged.
|
||||
|
||||
`updated_start_day` and `updated_start_time` must be provided together.
|
||||
`updated_end_day` and `updated_end_time` must be provided together.
|
||||
`updated_start_datetime` and `updated_end_datetime` are independent and can be provided separately.
|
||||
"""
|
||||
service = build("calendar", "v3", credentials=Credentials(context.authorization.token))
|
||||
|
||||
|
|
@ -228,13 +207,13 @@ async def update_event(
|
|||
|
||||
try:
|
||||
event = service.events().get(calendarId="primary", eventId=event_id).execute()
|
||||
except HttpError: # TODO: This is a first pass. We should do better.
|
||||
except HttpError:
|
||||
valid_events_with_id = (
|
||||
service.events()
|
||||
.list(
|
||||
calendarId="primary",
|
||||
timeMin=(datetime.now() - timedelta(days=2)).isoformat(),
|
||||
timeMax=(datetime.now() - timedelta(days=2)).isoformat(),
|
||||
timeMax=(datetime.now() + timedelta(days=365)).isoformat(),
|
||||
maxResults=50,
|
||||
singleEvents=True,
|
||||
orderBy="startTime",
|
||||
|
|
@ -243,14 +222,18 @@ async def update_event(
|
|||
)
|
||||
raise RetryableToolError(
|
||||
f"Event with ID {event_id} not found.",
|
||||
additional_prompt_content=f"Here is list of valid events. The event_id parameter must match one of these: {valid_events_with_id}",
|
||||
additional_prompt_content=f"Here is a list of valid events. The event_id parameter must match one of these: {valid_events_with_id}",
|
||||
retry_after_ms=1000,
|
||||
developer_message=f"Event with ID {event_id} not found. Please try again with a valid event ID.",
|
||||
)
|
||||
|
||||
update_fields = {
|
||||
"start": _update_datetime(updated_start_day, updated_start_time, time_zone),
|
||||
"end": _update_datetime(updated_end_day, updated_end_time, time_zone),
|
||||
"start": {"dateTime": updated_start_datetime.isoformat(), "timeZone": time_zone}
|
||||
if updated_start_datetime
|
||||
else None,
|
||||
"end": {"dateTime": updated_end_datetime.isoformat(), "timeZone": time_zone}
|
||||
if updated_end_datetime
|
||||
else None,
|
||||
"calendarId": updated_calendar_id,
|
||||
"sendUpdates": send_updates.value if send_updates else None,
|
||||
"summary": updated_summary,
|
||||
|
|
@ -265,12 +248,20 @@ async def update_event(
|
|||
event["attendees"] = [
|
||||
attendee
|
||||
for attendee in event.get("attendees", [])
|
||||
if attendee.get("email", "") not in attendee_emails_to_remove
|
||||
if attendee.get("email", "").lower()
|
||||
not in [email.lower() for email in attendee_emails_to_remove]
|
||||
]
|
||||
|
||||
if attendee_emails_to_add:
|
||||
event["attendees"] = event.get("attendees", []) + [
|
||||
{"email": email} for email in attendee_emails_to_add
|
||||
existing_emails = {
|
||||
attendee.get("email", "").lower() for attendee in event.get("attendees", [])
|
||||
}
|
||||
new_attendees = [
|
||||
{"email": email}
|
||||
for email in attendee_emails_to_add
|
||||
if email.lower() not in existing_emails
|
||||
]
|
||||
event["attendees"] = event.get("attendees", []) + new_attendees
|
||||
|
||||
updated_event = (
|
||||
service.events()
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import datetime
|
||||
import re
|
||||
from base64 import urlsafe_b64decode
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
|
@ -11,6 +12,31 @@ from googleapiclient.discovery import build
|
|||
from arcade_google.tools.models import Day, TimeSlot
|
||||
|
||||
|
||||
def parse_datetime(datetime_str: str, time_zone: str) -> datetime:
|
||||
"""
|
||||
Parse a datetime string in ISO 8601 format and ensure it is timezone-aware.
|
||||
|
||||
Args:
|
||||
datetime_str (str): The datetime string to parse. Expected format: 'YYYY-MM-DDTHH:MM:SS'.
|
||||
time_zone (str): The timezone to apply if the datetime string is naive.
|
||||
|
||||
Returns:
|
||||
datetime: A timezone-aware datetime object.
|
||||
|
||||
Raises:
|
||||
ValueError: If the datetime string is not in the correct format.
|
||||
"""
|
||||
try:
|
||||
dt = datetime.fromisoformat(datetime_str)
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=ZoneInfo(time_zone))
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"Invalid datetime format: '{datetime_str}'. Expected ISO 8601 format, e.g., '2024-12-31T15:30:00'."
|
||||
) from e
|
||||
return dt
|
||||
|
||||
|
||||
class DateRange(Enum):
|
||||
TODAY = "today"
|
||||
YESTERDAY = "yesterday"
|
||||
|
|
@ -21,24 +47,24 @@ class DateRange(Enum):
|
|||
THIS_YEAR = "this_year"
|
||||
|
||||
def to_date_query(self):
|
||||
today = datetime.datetime.now()
|
||||
today = datetime.now()
|
||||
result = "after:"
|
||||
comparison_date = today
|
||||
|
||||
if self == DateRange.YESTERDAY:
|
||||
comparison_date = today - datetime.timedelta(days=1)
|
||||
comparison_date = today - timedelta(days=1)
|
||||
elif self == DateRange.LAST_7_DAYS:
|
||||
comparison_date = today - datetime.timedelta(days=7)
|
||||
comparison_date = today - timedelta(days=7)
|
||||
elif self == DateRange.LAST_30_DAYS:
|
||||
comparison_date = today - datetime.timedelta(days=30)
|
||||
comparison_date = today - timedelta(days=30)
|
||||
elif self == DateRange.THIS_MONTH:
|
||||
comparison_date = today.replace(day=1)
|
||||
elif self == DateRange.LAST_MONTH:
|
||||
comparison_date = (today.replace(day=1) - datetime.timedelta(days=1)).replace(day=1)
|
||||
comparison_date = (today.replace(day=1) - timedelta(days=1)).replace(day=1)
|
||||
elif self == DateRange.THIS_YEAR:
|
||||
comparison_date = today.replace(month=1, day=1)
|
||||
elif self == DateRange.LAST_MONTH:
|
||||
comparison_date = (today.replace(month=1, day=1) - datetime.timedelta(days=1)).replace(
|
||||
comparison_date = (today.replace(month=1, day=1) - timedelta(days=1)).replace(
|
||||
month=1, day=1
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,19 @@
|
|||
from datetime import timedelta
|
||||
|
||||
import arcade_google
|
||||
from arcade_google.tools.calendar import create_event, delete_event, list_events, update_event
|
||||
from arcade_google.tools.models import Day, EventVisibility, TimeSlot
|
||||
from arcade_google.tools.calendar import (
|
||||
EventVisibility,
|
||||
SendUpdatesOptions,
|
||||
create_event,
|
||||
delete_event,
|
||||
list_events,
|
||||
update_event,
|
||||
)
|
||||
|
||||
from arcade.core.catalog import ToolCatalog
|
||||
from arcade.sdk.eval import (
|
||||
BinaryCritic,
|
||||
DatetimeCritic,
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
tool_eval,
|
||||
|
|
@ -16,11 +25,9 @@ rubric = EvalRubric(
|
|||
warn_threshold=0.95,
|
||||
)
|
||||
|
||||
|
||||
catalog = ToolCatalog()
|
||||
catalog.add_module(arcade_google)
|
||||
|
||||
|
||||
history_after_list_events = [
|
||||
{"role": "user", "content": "do i have any events on my calendar for today?"},
|
||||
{
|
||||
|
|
@ -36,7 +43,7 @@ history_after_list_events = [
|
|||
"type": "function",
|
||||
"function": {
|
||||
"name": "Google_ListEvents",
|
||||
"arguments": '{"max_day":"today","max_time_slot":"23:45","min_day":"today","min_time_slot":"00:00"}',
|
||||
"arguments": '{"min_end_datetime":"2024-09-26T00:00:00-07:00","max_start_datetime":"2024-09-27T00:00:00-07:00"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
|
|
@ -59,7 +66,9 @@ def calendar_eval_suite() -> EvalSuite:
|
|||
"""Create an evaluation suite for Calendar tools."""
|
||||
suite = EvalSuite(
|
||||
name="Calendar Tools Evaluation",
|
||||
system_message="You are an AI assistant that can create and list events using the provided tools.",
|
||||
system_message=(
|
||||
"You are an AI assistant that can create, list, update, and delete events using the provided tools. Today is 2024-09-26"
|
||||
),
|
||||
catalog=catalog,
|
||||
rubric=rubric,
|
||||
)
|
||||
|
|
@ -67,32 +76,34 @@ def calendar_eval_suite() -> EvalSuite:
|
|||
# Cases for create_event
|
||||
suite.add_case(
|
||||
name="Create calendar event",
|
||||
user_message="Create a meeting for 'Team Meeting' starting next thursday from 11:45pm to 12:15am. Invite johndoe@example.com",
|
||||
user_message=(
|
||||
"Create a meeting for 'Team Meeting' starting on September 26, 2024, from 11:45pm to 12:15am. Invite johndoe@example.com"
|
||||
),
|
||||
expected_tool_calls=[
|
||||
(
|
||||
create_event,
|
||||
{
|
||||
"summary": "Team Meeting",
|
||||
"start_date": Day.NEXT_THURSDAY.value,
|
||||
"start_time": TimeSlot._2345.value,
|
||||
"end_date": Day.NEXT_FRIDAY.value,
|
||||
"end_time": TimeSlot._0015.value,
|
||||
"start_datetime": "2024-09-26T23:45:00",
|
||||
"end_datetime": "2024-09-27T00:15:00",
|
||||
"calendar_id": "primary",
|
||||
"attendee_emails": ["johndoe@example.com"],
|
||||
"description": None,
|
||||
"location": None,
|
||||
"visibility": EventVisibility.DEFAULT,
|
||||
"description": "Team Meeting",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="summary", weight=0.15),
|
||||
BinaryCritic(critic_field="start_date", weight=0.15),
|
||||
BinaryCritic(critic_field="start_time", weight=0.15),
|
||||
BinaryCritic(critic_field="end_date", weight=0.15),
|
||||
BinaryCritic(critic_field="end_time", weight=0.15),
|
||||
BinaryCritic(critic_field="attendee_emails", weight=0.15),
|
||||
BinaryCritic(critic_field="summary", weight=0.2),
|
||||
DatetimeCritic(
|
||||
critic_field="start_datetime", weight=0.2, tolerance=timedelta(seconds=10)
|
||||
),
|
||||
DatetimeCritic(
|
||||
critic_field="end_datetime", weight=0.2, tolerance=timedelta(seconds=10)
|
||||
),
|
||||
BinaryCritic(critic_field="attendee_emails", weight=0.2),
|
||||
BinaryCritic(critic_field="description", weight=0.1),
|
||||
BinaryCritic(critic_field="location", weight=0.1),
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -104,49 +115,61 @@ def calendar_eval_suite() -> EvalSuite:
|
|||
(
|
||||
list_events,
|
||||
{
|
||||
"min_day": Day.TODAY.value,
|
||||
"min_time_slot": TimeSlot._0000.value,
|
||||
"max_day": Day.TOMORROW.value,
|
||||
"max_time_slot": TimeSlot._0000.value,
|
||||
"min_end_datetime": "2024-09-26T00:00:00",
|
||||
"max_start_datetime": "2024-09-27T00:00:00",
|
||||
"calendar_id": "primary",
|
||||
"event_types": None,
|
||||
"max_results": 10,
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="min_day", weight=0.1),
|
||||
BinaryCritic(critic_field="min_time_slot", weight=0.1),
|
||||
BinaryCritic(critic_field="max_day", weight=0.1),
|
||||
BinaryCritic(critic_field="max_time_slot", weight=0.1),
|
||||
BinaryCritic(critic_field="calendar_id", weight=0.1),
|
||||
BinaryCritic(critic_field="event_types", weight=0.1),
|
||||
BinaryCritic(critic_field="max_results", weight=0.1),
|
||||
DatetimeCritic(
|
||||
critic_field="min_end_datetime", weight=0.3, tolerance=timedelta(hours=1)
|
||||
),
|
||||
DatetimeCritic(
|
||||
critic_field="max_start_datetime", weight=0.3, tolerance=timedelta(hours=1)
|
||||
),
|
||||
BinaryCritic(critic_field="calendar_id", weight=0.2),
|
||||
BinaryCritic(critic_field="max_results", weight=0.2),
|
||||
],
|
||||
)
|
||||
|
||||
# Cases for update_event
|
||||
suite.add_case(
|
||||
name="Update a calendar event",
|
||||
user_message="Oh no! I cant make it to the API Test since i have lunch with an old friend at that time. Change the meeting to 3pm to 4pm please.",
|
||||
user_message=(
|
||||
"Oh no! I can't make it to the API Test since I have lunch with an old friend at that time. "
|
||||
"Change the meeting my meeting tomorrow at 3pm to 4pm. Let everyone know."
|
||||
),
|
||||
expected_tool_calls=[
|
||||
(
|
||||
update_event,
|
||||
{
|
||||
"event_id": "00099992228181818181",
|
||||
"updated_start_day": Day.TODAY.value,
|
||||
"updated_start_time": TimeSlot._1500.value,
|
||||
"updated_end_day": Day.TODAY.value,
|
||||
"updated_end_time": TimeSlot._1600.value,
|
||||
"updated_start_datetime": "2024-09-27T16:00:00",
|
||||
"updated_end_datetime": "2024-09-27T18:00:00",
|
||||
"updated_calendar_id": "primary",
|
||||
"updated_summary": "API Test",
|
||||
"updated_description": "API Test",
|
||||
"updated_location": "611 Gateway Blvd",
|
||||
"updated_visibility": EventVisibility.DEFAULT,
|
||||
"attendee_emails_to_add": None,
|
||||
"attendee_emails_to_remove": None,
|
||||
"send_updates": SendUpdatesOptions.ALL,
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="event_id", weight=0.2),
|
||||
BinaryCritic(critic_field="updated_start_day", weight=0.1),
|
||||
BinaryCritic(critic_field="updated_start_time", weight=0.1),
|
||||
BinaryCritic(critic_field="updated_end_day", weight=0.1),
|
||||
BinaryCritic(critic_field="updated_end_time", weight=0.1),
|
||||
BinaryCritic(critic_field="event_id", weight=0.4),
|
||||
DatetimeCritic(
|
||||
critic_field="updated_start_datetime", weight=0.2, tolerance=timedelta(minutes=15)
|
||||
),
|
||||
DatetimeCritic(
|
||||
critic_field="updated_end_datetime",
|
||||
weight=0.2,
|
||||
tolerance=timedelta(minutes=15),
|
||||
),
|
||||
BinaryCritic(critic_field="send_updates", weight=0.2),
|
||||
],
|
||||
additional_messages=history_after_list_events,
|
||||
)
|
||||
|
|
@ -154,14 +177,16 @@ def calendar_eval_suite() -> EvalSuite:
|
|||
# Cases for delete_event
|
||||
suite.add_case(
|
||||
name="Delete a calendar event",
|
||||
user_message="I don't need to have focus time today. Please delete it from my calendar. Don't send any notifications.",
|
||||
user_message=(
|
||||
"I don't need to have focus time today. Please delete it from my calendar. Don't send any notifications."
|
||||
),
|
||||
expected_tool_calls=[
|
||||
(
|
||||
delete_event,
|
||||
{
|
||||
"event_id": "gr5g18lf88tfpp3vkareukkc7g",
|
||||
"calendar_id": "primary",
|
||||
"send_updates": "none",
|
||||
"send_updates": SendUpdatesOptions.NONE,
|
||||
},
|
||||
)
|
||||
],
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import pytest
|
||||
from arcade_google.tools.calendar import create_event, delete_event, list_events, update_event
|
||||
from arcade_google.tools.models import Day, EventVisibility, SendUpdatesOptions, TimeSlot
|
||||
from arcade_google.tools.models import EventVisibility, SendUpdatesOptions
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from arcade.core.errors import ToolExecutionError
|
||||
|
|
@ -24,7 +24,7 @@ async def test_create_event(mock_build, mock_context):
|
|||
# Mock the calendar's time zone
|
||||
mock_service.calendars().get().execute.return_value = {"timeZone": "America/Los_Angeles"}
|
||||
|
||||
# Case: HttpError
|
||||
# Case: HttpError during event creation
|
||||
mock_service.events().insert().execute.side_effect = HttpError(
|
||||
resp=MagicMock(status=400),
|
||||
content=b'{"error": {"message": "Invalid request"}}',
|
||||
|
|
@ -34,10 +34,8 @@ async def test_create_event(mock_build, mock_context):
|
|||
await create_event(
|
||||
context=mock_context,
|
||||
summary="Test Event",
|
||||
start_date=Day.TODAY,
|
||||
start_time=TimeSlot._1615,
|
||||
end_date=Day.TODAY,
|
||||
end_time=TimeSlot._1715,
|
||||
start_datetime="2024-12-31T15:30:00",
|
||||
end_datetime="2024-12-31T17:30:00",
|
||||
description="Test Description",
|
||||
location="Test Location",
|
||||
visibility=EventVisibility.PRIVATE,
|
||||
|
|
@ -53,34 +51,34 @@ async def test_list_events(mock_build, mock_context):
|
|||
# Mock the calendar's time zone
|
||||
mock_service.calendars().get().execute.return_value = {"timeZone": "America/Los_Angeles"}
|
||||
|
||||
# Case: min time is after max time. list_events tool should swap the times and still return the events
|
||||
# Mock the events list response
|
||||
mock_events_list_response = {
|
||||
"items": [
|
||||
{
|
||||
"creator": {"email": "example@arcade-ai.com", "self": True},
|
||||
"end": {"dateTime": "2024-09-27T01:00:00-07:00", "timeZone": "America/Los_Angeles"},
|
||||
"eventType": "default",
|
||||
"htmlLink": "https://www.google.com/calendar/event?eid=N2pmYjZ0ZmNnMGNydG5scmhkY2JvZWc4OGIgZXJpY0BhcmNhZGUtYWku",
|
||||
"id": "7jfb6tfcg0crtnlrhdcboeg88b",
|
||||
"htmlLink": "https://www.google.com/calendar/event?eid=event1",
|
||||
"id": "event1",
|
||||
"organizer": {"email": "example@arcade-ai.com", "self": True},
|
||||
"start": {
|
||||
"dateTime": "2024-09-27T00:00:00-07:00",
|
||||
"timeZone": "America/Los_Angeles",
|
||||
},
|
||||
"summary": "teST",
|
||||
"summary": "Event 1",
|
||||
},
|
||||
{
|
||||
"creator": {"email": "example@arcade-ai.com", "self": True},
|
||||
"end": {"dateTime": "2024-09-27T17:00:00-07:00", "timeZone": "America/Los_Angeles"},
|
||||
"eventType": "default",
|
||||
"htmlLink": "https://www.google.com/calendar/event?eid=MjZvYnRoc2xtMWMzbG5mdG10bzk4cDcxaGMgZXJpY0BhcmNhZGUtYWku",
|
||||
"id": "26obthslm1c3lnftmto98p71hc",
|
||||
"htmlLink": "https://www.google.com/calendar/event?eid=event2",
|
||||
"id": "event2",
|
||||
"organizer": {"email": "example@arcade-ai.com", "self": True},
|
||||
"start": {
|
||||
"dateTime": "2024-09-27T14:00:00-07:00",
|
||||
"timeZone": "America/Los_Angeles",
|
||||
},
|
||||
"summary": "New Event",
|
||||
"summary": "Event 2",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
|
@ -89,16 +87,14 @@ async def test_list_events(mock_build, mock_context):
|
|||
"events": mock_events_list_response["items"],
|
||||
}
|
||||
mock_service.events().list().execute.return_value = mock_events_list_response
|
||||
message = await list_events(
|
||||
response = await list_events(
|
||||
context=mock_context,
|
||||
min_day=Day.TODAY,
|
||||
min_time_slot=TimeSlot._1615,
|
||||
max_day=Day.TODAY,
|
||||
max_time_slot=TimeSlot._1515,
|
||||
min_end_datetime="2024-09-15T09:00:00",
|
||||
max_start_datetime="2024-09-16T17:00:00",
|
||||
)
|
||||
assert message == expected_tool_response
|
||||
assert response == expected_tool_response
|
||||
|
||||
# Case: HttpError
|
||||
# Case: HttpError during events listing
|
||||
mock_service.events().list().execute.side_effect = HttpError(
|
||||
resp=MagicMock(status=400),
|
||||
content=b'{"error": {"message": "Invalid request"}}',
|
||||
|
|
@ -107,10 +103,8 @@ async def test_list_events(mock_build, mock_context):
|
|||
with pytest.raises(ToolExecutionError):
|
||||
await list_events(
|
||||
context=mock_context,
|
||||
min_day=Day.TODAY,
|
||||
min_time_slot=TimeSlot._1615,
|
||||
max_day=Day.TOMORROW,
|
||||
max_time_slot=TimeSlot._1815,
|
||||
min_end_datetime="2024-09-15T09:00:00",
|
||||
max_start_datetime="2024-09-16T17:00:00",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -119,8 +113,10 @@ async def test_list_events(mock_build, mock_context):
|
|||
async def test_update_event(mock_build, mock_context):
|
||||
mock_service = MagicMock()
|
||||
mock_build.return_value = mock_service
|
||||
mock_service.events().update().execute.side_effect = HttpError(
|
||||
resp=MagicMock(status=400),
|
||||
|
||||
# Mock retrieval of the event
|
||||
mock_service.events().get().execute.side_effect = HttpError(
|
||||
resp=MagicMock(status=404),
|
||||
content=b'{"error": {"message": "Event not found"}}',
|
||||
)
|
||||
|
||||
|
|
@ -128,10 +124,8 @@ async def test_update_event(mock_build, mock_context):
|
|||
await update_event(
|
||||
context=mock_context,
|
||||
event_id="1234567890",
|
||||
updated_start_day=Day.NEXT_FRIDAY,
|
||||
updated_start_time=TimeSlot._0015,
|
||||
updated_end_day=Day.NEXT_FRIDAY,
|
||||
updated_end_time=TimeSlot._0115,
|
||||
updated_start_datetime="2024-12-31T00:15:00",
|
||||
updated_end_datetime="2024-12-31T01:15:00",
|
||||
updated_summary="Updated Event",
|
||||
updated_description="Updated Description",
|
||||
updated_location="Updated Location",
|
||||
|
|
@ -148,7 +142,7 @@ async def test_delete_event(mock_build, mock_context):
|
|||
mock_service = MagicMock()
|
||||
mock_build.return_value = mock_service
|
||||
mock_service.events().delete().execute.side_effect = HttpError(
|
||||
resp=MagicMock(status=400),
|
||||
resp=MagicMock(status=404),
|
||||
content=b'{"error": {"message": "Event not found"}}',
|
||||
)
|
||||
|
||||
|
|
@ -156,4 +150,5 @@ async def test_delete_event(mock_build, mock_context):
|
|||
await delete_event(
|
||||
context=mock_context,
|
||||
event_id="nonexistent_event",
|
||||
send_updates=SendUpdatesOptions.ALL,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ from arcade_math.tools.arithmetic import (
|
|||
sum_range,
|
||||
)
|
||||
|
||||
from arcade.sdk.error import ToolExecutionError
|
||||
|
||||
|
||||
def test_add():
|
||||
assert add(1, 2) == 3
|
||||
|
|
@ -29,7 +31,7 @@ def test_multiply():
|
|||
def test_divide():
|
||||
assert divide(6, 3) == 2.0
|
||||
assert divide(5, 2) == 2.5
|
||||
with pytest.raises(ZeroDivisionError):
|
||||
with pytest.raises(ToolExecutionError):
|
||||
divide(1, 0)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue