Tool Evalulation SDK (#35)
1. New Eval SDK (`arcade/sdk/eval.py`): - Introduces `EvalSuite`, `EvalCase`, and `EvalRubric` classes for structured evaluation. - Implements various Critic classes (Binary, Numeric, Similarity) for flexible scoring. - Adds a `tool_eval` decorator for easy integration with existing tools. 2. CLI Integration (`arcade/cli/main.py` and `arcade/cli/utils.py`): - Adds an `evals` command to run evaluation suites from the CLI. - Implements result display functionality for evaluation outcomes. 3. Toolkit Updates: - Adds evaluation scripts for Gmail ([toolkits/gmail/evals/eval_gmail_tools.py](file:///Users/spartee/Dropbox/Arcade/platform/Team/arcade-ai/toolkits/gmail/evals/eval_gmail_tools.py#1%2C1-1%2C1)) and Slack ([toolkits/slack/evals/eval_slack_messaging.py](file:///Users/spartee/Dropbox/Arcade/platform/Team/arcade-ai/toolkits/slack/evals/eval_slack_messaging.py#1%2C1-1%2C1)) toolkits. - Demonstrates practical usage of the Eval SDK with real-world scenarios. 4. Miscellaneous: - Updates `arcade/cli/new.py` to optionally generate an `evals` directory for new toolkits. --------- Co-authored-by: Nate Barbettini <nate@arcade-ai.com>
This commit is contained in:
parent
e4839195d7
commit
db948125d5
17 changed files with 1768 additions and 12 deletions
32
Makefile
32
Makefile
|
|
@ -3,9 +3,8 @@
|
|||
.PHONY: install
|
||||
install: ## Install the poetry environment and install the pre-commit hooks
|
||||
@echo "🚀 Creating virtual environment using pyenv and poetry"
|
||||
@cd arcade && poetry install
|
||||
@cd arcade && poetry install --all-extras
|
||||
@cd arcade && poetry run pre-commit install
|
||||
@cd arcade && poetry shell
|
||||
|
||||
.PHONY: check
|
||||
check: ## Run code quality tools.
|
||||
|
|
@ -54,6 +53,35 @@ docker: ## Build and run the Docker container
|
|||
@cd docker && make docker-build
|
||||
@cd docker && make docker-run
|
||||
|
||||
.PHONY: full-dist
|
||||
full-dist: clean-dist ## Build all projects and copy wheels to arcade/dist
|
||||
@echo "🚀 Building all projects and copying wheels to arcade/dist"
|
||||
|
||||
# Build the main arcade project
|
||||
@echo "Building arcade project..."
|
||||
@cd arcade && poetry build
|
||||
|
||||
# Create the arcade/dist directory if it doesn't exist
|
||||
@mkdir -p arcade/dist
|
||||
|
||||
# Build and copy wheels for each toolkit
|
||||
@for toolkit_dir in toolkits/*; do \
|
||||
if [ -d "$$toolkit_dir" ]; then \
|
||||
toolkit_name=$$(basename "$$toolkit_dir"); \
|
||||
echo "Building $$toolkit_name project..."; \
|
||||
cd "$$toolkit_dir" && poetry build; \
|
||||
cp dist/*.whl ../../arcade/dist; \
|
||||
cd -; \
|
||||
fi; \
|
||||
done
|
||||
|
||||
@echo "✅ All projects built and wheels copied to arcade/dist"
|
||||
|
||||
.PHONY: clean-dist
|
||||
clean-dist: ## Clean the arcade/dist directory
|
||||
@echo "🗑️ Cleaning arcade/dist directory"
|
||||
@rm -rf arcade/dist
|
||||
|
||||
.PHONY: help
|
||||
help:
|
||||
@echo "🛠️ Arcade AI Dev Commands:\n"
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import importlib.util
|
||||
import os
|
||||
import readline
|
||||
import threading
|
||||
|
|
@ -18,6 +19,7 @@ from arcade.cli.utils import (
|
|||
OrderCommands,
|
||||
apply_config_overrides,
|
||||
create_cli_catalog,
|
||||
display_eval_results,
|
||||
display_streamed_markdown,
|
||||
markdownify_urls,
|
||||
validate_and_get_config,
|
||||
|
|
@ -107,6 +109,7 @@ def show(
|
|||
None, "-t", "--toolkit", help="The toolkit to show the tools of"
|
||||
),
|
||||
actor: Optional[str] = typer.Option(None, help="A running actor address to list tools from"),
|
||||
debug: bool = typer.Option(False, "--debug", "-d", help="Show debug information"),
|
||||
) -> None:
|
||||
"""
|
||||
Show the available tools in an actor or toolkit
|
||||
|
|
@ -128,7 +131,8 @@ def show(
|
|||
console.print(table)
|
||||
|
||||
except Exception as e:
|
||||
# better error message here
|
||||
if debug:
|
||||
raise
|
||||
error_message = f"❌ Failed to List tools: {escape(str(e))}"
|
||||
console.print(error_message, style="bold red")
|
||||
|
||||
|
|
@ -380,3 +384,55 @@ def display_config_as_table(config) -> None: # type: ignore[no-untyped-def]
|
|||
table.add_row("", "", "")
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
@cli.command(help="Run evaluation suites in a directory")
|
||||
def evals(
|
||||
directory: str = typer.Argument(".", help="Directory containing evaluation files"),
|
||||
show_details: bool = typer.Option(False, "--details", "-d", help="Show detailed results"),
|
||||
max_concurrent: int = typer.Option(
|
||||
1,
|
||||
"--max-concurrent",
|
||||
"-c",
|
||||
help="Maximum number of concurrent evaluations (default: 1)",
|
||||
),
|
||||
models: str = typer.Option(
|
||||
"gpt-4o", "--models", "-m", help="The models to use for evaluation (default: gpt-4o)"
|
||||
),
|
||||
) -> None:
|
||||
"""
|
||||
Find all files starting with 'eval_' in the given directory,
|
||||
execute any functions decorated with @tool_eval, and display the results.
|
||||
"""
|
||||
models = models.split(",") # type: ignore[assignment]
|
||||
eval_files = [f for f in os.listdir(directory) if f.startswith("eval_") and f.endswith(".py")]
|
||||
|
||||
if not eval_files:
|
||||
console.print("No evaluation files found.", style="bold yellow")
|
||||
return
|
||||
|
||||
for file in eval_files:
|
||||
file_path = os.path.join(directory, file)
|
||||
module_name = file[:-3] # Remove .py extension
|
||||
|
||||
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||
if spec is None:
|
||||
console.print(f"Failed to load {file}", style="bold red")
|
||||
continue
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module) # type: ignore[union-attr]
|
||||
|
||||
eval_functions = [
|
||||
obj
|
||||
for name, obj in module.__dict__.items()
|
||||
if callable(obj) and hasattr(obj, "__tool_eval__")
|
||||
]
|
||||
|
||||
if not eval_functions:
|
||||
console.print(f"No @tool_eval functions found in {file}", style="bold yellow")
|
||||
continue
|
||||
|
||||
for func in eval_functions:
|
||||
console.print(f"\nRunning evaluation from {file}: {func.__name__}", style="bold blue")
|
||||
results = func(models=models, max_concurrency=max_concurrent)
|
||||
display_eval_results(results, show_details=show_details)
|
||||
|
|
|
|||
|
|
@ -97,6 +97,7 @@ def create_new_toolkit(directory: str) -> None:
|
|||
author = f"{author_name} <{author_email}>"
|
||||
|
||||
generate_test_dir = ask_question("Generate test directory? (yes/no)", "yes") == "yes"
|
||||
generate_eval_dir = ask_question("Generate eval directory? (yes/no)", "yes") == "yes"
|
||||
|
||||
top_level_dir = os.path.join(directory, name)
|
||||
toolkit_dir = os.path.join(directory, name, toolkit_name)
|
||||
|
|
@ -140,4 +141,8 @@ def create_new_toolkit(directory: str) -> None:
|
|||
if generate_test_dir:
|
||||
create_directory(os.path.join(top_level_dir, "tests"))
|
||||
|
||||
# If the user wants to generate an eval directory
|
||||
if generate_eval_dir:
|
||||
create_directory(os.path.join(top_level_dir, "evals"))
|
||||
|
||||
console.print(f"[green]Toolkit {toolkit_name} has been created.[/green]")
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import typer
|
||||
from openai.resources.chat.completions import ChatCompletionChunk, Stream
|
||||
from rich.console import Console
|
||||
|
|
@ -9,6 +11,9 @@ from arcade.core.catalog import ToolCatalog
|
|||
from arcade.core.config_model import Config
|
||||
from arcade.core.toolkit import Toolkit
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from arcade.sdk.eval.eval import EvaluationResult
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
|
|
@ -150,3 +155,91 @@ def apply_config_overrides(
|
|||
|
||||
if tls_input is not None:
|
||||
config.engine.tls = tls_input
|
||||
|
||||
|
||||
def display_eval_results(results: list[dict[str, Any]], show_details: bool = False) -> None:
|
||||
"""
|
||||
Display evaluation results in a format inspired by pytest's output.
|
||||
|
||||
Args:
|
||||
results: List of dictionaries containing evaluation results for each model.
|
||||
show_details: Whether to show detailed results for each case.
|
||||
"""
|
||||
total_passed = 0
|
||||
total_failed = 0
|
||||
total_warned = 0
|
||||
total_cases = 0
|
||||
|
||||
for model_results in results:
|
||||
model = model_results.get("model", "Unknown Model")
|
||||
rubric = model_results.get("rubric", "Unknown Rubric")
|
||||
cases = model_results.get("cases", [])
|
||||
total_cases += len(cases)
|
||||
|
||||
console.print(f"\n[bold magenta]Model: {model}[/bold magenta]\n")
|
||||
console.print(f"[bold magenta]{rubric}[/bold magenta]\n")
|
||||
|
||||
for case in cases:
|
||||
evaluation = case["evaluation"]
|
||||
status = (
|
||||
"[green]PASSED[/green]"
|
||||
if evaluation.passed
|
||||
else "[yellow]WARNED[/yellow]"
|
||||
if evaluation.warning
|
||||
else "[red]FAILED[/red]"
|
||||
)
|
||||
if evaluation.passed:
|
||||
total_passed += 1
|
||||
elif evaluation.warning:
|
||||
total_warned += 1
|
||||
else:
|
||||
total_failed += 1
|
||||
|
||||
# Display one-line summary for each case
|
||||
console.print(f"{status} {case['name']} -- Score: {evaluation.score:.2f}")
|
||||
|
||||
if show_details:
|
||||
# Show detailed information for each case
|
||||
console.print(f"[bold]User Input:[/bold] {case['input']}\n")
|
||||
console.print("[bold]Details:[/bold]")
|
||||
console.print(_format_evaluation(evaluation))
|
||||
console.print("-" * 80)
|
||||
|
||||
# Summary
|
||||
console.print("\n[bold]Summary:[/bold]")
|
||||
console.print(f"Total Cases: {total_cases}")
|
||||
console.print(f"[green]Passed: {total_passed}[/green]")
|
||||
console.print(f"[yellow]Warnings: {total_warned}[/yellow]")
|
||||
console.print(f"[red]Failed: {total_failed}[/red]\n")
|
||||
|
||||
|
||||
def _format_evaluation(evaluation: "EvaluationResult") -> str:
|
||||
"""
|
||||
Format evaluation results with color-coded matches and scores.
|
||||
|
||||
Args:
|
||||
evaluation: An EvaluationResult object containing the evaluation results.
|
||||
|
||||
Returns:
|
||||
A formatted string representation of the evaluation details.
|
||||
"""
|
||||
result_lines = []
|
||||
|
||||
# Include overall final score
|
||||
result_lines.append(f"[bold]Final Score:[/bold] {evaluation.score:.2f}\n")
|
||||
|
||||
for critic_result in evaluation.results:
|
||||
match_color = "green" if critic_result["match"] else "red"
|
||||
field = critic_result["field"]
|
||||
score = critic_result["score"]
|
||||
weight = critic_result["weight"]
|
||||
expected = critic_result["expected"]
|
||||
actual = critic_result["actual"]
|
||||
result_lines.append(
|
||||
f"[bold]{field}:[/bold] "
|
||||
f"[{match_color}]Match: {critic_result['match']}, "
|
||||
f"Score: {score:.2f}/{weight:.2f}[/{match_color}]"
|
||||
f"\n Expected: {expected}"
|
||||
f"\n Actual: {actual}"
|
||||
)
|
||||
return "\n".join(result_lines)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
import inspect
|
||||
import typing
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
|
@ -501,6 +502,7 @@ def get_wire_type(
|
|||
"""
|
||||
Mapping between Python types and HTTP/JSON types
|
||||
"""
|
||||
# TODO ensure Any is not allowed
|
||||
type_mapping: dict[type, WireType] = {
|
||||
str: "string",
|
||||
bool: "boolean",
|
||||
|
|
@ -513,7 +515,6 @@ def get_wire_type(
|
|||
list: "array",
|
||||
dict: "json",
|
||||
}
|
||||
|
||||
wire_type = type_mapping.get(_type)
|
||||
if wire_type:
|
||||
return wire_type
|
||||
|
|
@ -580,6 +581,17 @@ def determine_output_model(func: Callable) -> type[BaseModel]:
|
|||
output_model_name,
|
||||
result=(field_type, Field(description=str(description))),
|
||||
)
|
||||
# Handle Union types
|
||||
origin = return_annotation.__origin__
|
||||
if origin is typing.Union:
|
||||
# For union types, create a model with the first non-None argument
|
||||
# TODO handle multiple non-None arguments. Raise error?
|
||||
for arg in get_args(return_annotation):
|
||||
if arg is not type(None):
|
||||
return create_model(
|
||||
output_model_name,
|
||||
result=(arg, Field(description="No description provided.")),
|
||||
)
|
||||
# when the return_annotation has an __origin__ attribute
|
||||
# and does not have a __metadata__ attribute.
|
||||
return create_model(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,21 @@
|
|||
from .eval import (
|
||||
BinaryCritic,
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
ExpectedToolCall,
|
||||
NumericCritic,
|
||||
SimilarityCritic,
|
||||
tool_eval,
|
||||
)
|
||||
from .tool import tool
|
||||
|
||||
__all__ = [
|
||||
"tool",
|
||||
"EvalRubric",
|
||||
"EvalSuite",
|
||||
"ExpectedToolCall",
|
||||
"tool_eval",
|
||||
"BinaryCritic",
|
||||
"SimilarityCritic",
|
||||
"NumericCritic",
|
||||
]
|
||||
|
|
|
|||
6
arcade/arcade/sdk/error.py
Normal file
6
arcade/arcade/sdk/error.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
class SDKError(Exception):
|
||||
"""Base class for all SDK errors."""
|
||||
|
||||
|
||||
class WeightError(SDKError):
|
||||
"""Raised when the critic weights do not abide by SDK weight constraints."""
|
||||
12
arcade/arcade/sdk/eval/__init__.py
Normal file
12
arcade/arcade/sdk/eval/__init__.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
from .critic import BinaryCritic, NumericCritic, SimilarityCritic
|
||||
from .eval import EvalRubric, EvalSuite, ExpectedToolCall, tool_eval
|
||||
|
||||
__all__ = [
|
||||
"BinaryCritic",
|
||||
"SimilarityCritic",
|
||||
"NumericCritic",
|
||||
"EvalRubric",
|
||||
"EvalSuite",
|
||||
"ExpectedToolCall",
|
||||
"tool_eval",
|
||||
]
|
||||
154
arcade/arcade/sdk/eval/critic.py
Normal file
154
arcade/arcade/sdk/eval/critic.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from arcade.sdk.error import WeightError
|
||||
|
||||
|
||||
@dataclass
|
||||
class Critic(ABC):
|
||||
critic_field: str
|
||||
weight: float
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.weight < 0 or self.weight > 1:
|
||||
raise WeightError(f"Critic weight must be between 0 and 1, got {self.weight}")
|
||||
|
||||
@abstractmethod
|
||||
def evaluate(self, expected: Any, actual: Any) -> dict[str, Any]:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BinaryCritic(Critic):
|
||||
"""
|
||||
A critic for performing exact equality comparisons between expected and actual values.
|
||||
|
||||
This critic evaluates whether the expected and actual values are exactly equal.
|
||||
It's useful for scenarios where only an exact match is acceptable.
|
||||
|
||||
Returns:
|
||||
A dict with:
|
||||
- "match": True if expected == actual, otherwise False.
|
||||
- "score": The full weight if there's a match, otherwise 0.0.
|
||||
"""
|
||||
|
||||
def evaluate(self, expected: Any, actual: Any) -> dict[str, float | bool]:
|
||||
match = expected == actual
|
||||
return {"match": match, "score": self.weight if match else 0.0}
|
||||
|
||||
|
||||
@dataclass
|
||||
class NumericCritic(Critic):
|
||||
"""
|
||||
A critic for evaluating numeric values within a specified range.
|
||||
|
||||
This critic performs a "fuzzy" comparison of numeric values, where values closer
|
||||
to each other (relative to the specified range) result in higher scores. It's
|
||||
useful for scenarios where exact matches aren't necessary, but closeness within
|
||||
a certain tolerance is rewarded.
|
||||
|
||||
Attributes:
|
||||
value_range: The min and max values of the expected range.
|
||||
match_threshold: The threshold for considering a match (default 0.8).
|
||||
|
||||
The evaluation process:
|
||||
1. Normalizes both expected and actual values to a 0-1 scale based on value_range.
|
||||
2. Calculates the absolute difference between these normalized values.
|
||||
3. Subtracts this difference from 1 to get a similarity score (closer to 1 is more similar).
|
||||
4. Multiplies the similarity by the critic's weight for the final score.
|
||||
|
||||
Returns:
|
||||
A dict with:
|
||||
- "match": True if the score >= match_threshold, otherwise False.
|
||||
- "score": The calculated score (similarity * weight).
|
||||
"""
|
||||
|
||||
value_range: tuple[float, float]
|
||||
match_threshold: float = 0.8
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
critic_field: str,
|
||||
weight: float,
|
||||
value_range: tuple[float, float],
|
||||
match_threshold: float = 0.8,
|
||||
):
|
||||
super().__init__(critic_field, weight)
|
||||
if value_range[0] >= value_range[1]:
|
||||
raise ValueError("Invalid value_range: minimum must be less than maximum.")
|
||||
self.value_range = value_range
|
||||
self.match_threshold = match_threshold
|
||||
|
||||
def evaluate(self, expected: Any, actual: Any) -> dict[str, Any]:
|
||||
min_val, max_val = self.value_range
|
||||
normalized_expected = float((float(expected) - min_val) / (max_val - min_val))
|
||||
normalized_actual = float((float(actual) - min_val) / (max_val - min_val))
|
||||
score = float(1 - abs(normalized_expected - normalized_actual))
|
||||
return {"match": bool(score >= self.match_threshold), "score": float(score * self.weight)}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimilarityCritic(Critic):
|
||||
"""
|
||||
A critic for evaluating the similarity between two strings.
|
||||
|
||||
This critic uses a specified similarity metric to compare the expected and actual
|
||||
string values. Currently, it supports cosine similarity using TF-IDF vectorization.
|
||||
|
||||
Args:
|
||||
metric: The similarity metric to use (default is "cosine").
|
||||
similarity_threshold: The threshold for considering a match (default 0.8).
|
||||
|
||||
The evaluation process:
|
||||
1. Converts both expected and actual values to strings.
|
||||
2. Calculates the similarity score using the specified metric.
|
||||
3. Determines a match based on the similarity_threshold.
|
||||
4. Calculates the final score by multiplying the similarity by the critic's weight.
|
||||
|
||||
Returns:
|
||||
A dict with:
|
||||
- "match": True if similarity >= similarity_threshold, otherwise False.
|
||||
- "score": The calculated score (similarity * weight).
|
||||
|
||||
Raises:
|
||||
ImportError: If scikit-learn is not installed (required for cosine similarity).
|
||||
ValueError: If an unsupported similarity metric is specified.
|
||||
"""
|
||||
|
||||
metric: str = "cosine"
|
||||
similarity_threshold: float = 0.8
|
||||
|
||||
SUPPORTED_METRICS: ClassVar[list[str]] = ["cosine"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
critic_field: str,
|
||||
weight: float,
|
||||
similarity_threshold: float = 0.8,
|
||||
metric: str = "cosine",
|
||||
):
|
||||
super().__init__(critic_field, weight)
|
||||
if metric not in self.SUPPORTED_METRICS:
|
||||
raise ValueError(f"Unsupported similarity metric: {metric}")
|
||||
self.similarity_threshold = similarity_threshold
|
||||
self.metric = metric
|
||||
|
||||
def evaluate(self, expected: str, actual: str) -> dict[str, float | bool]:
|
||||
if self.metric == "cosine":
|
||||
try:
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Use `pip install arcade[evals]` to install the required dependencies for similarity metrics."
|
||||
)
|
||||
vectorizer = TfidfVectorizer()
|
||||
tfidf_matrix = vectorizer.fit_transform([expected, actual])
|
||||
similarity = cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]
|
||||
else:
|
||||
raise ValueError(f"Unsupported similarity metric: {self.metric}")
|
||||
return {
|
||||
"match": similarity >= self.similarity_threshold,
|
||||
"score": min(similarity * self.weight, self.weight),
|
||||
}
|
||||
632
arcade/arcade/sdk/eval/eval.py
Normal file
632
arcade/arcade/sdk/eval/eval.py
Normal file
|
|
@ -0,0 +1,632 @@
|
|||
import asyncio
|
||||
import functools
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Use `pip install arcade[evals]` to install the required dependencies for evaluation."
|
||||
)
|
||||
|
||||
from arcade.client.client import Arcade, AsyncArcade
|
||||
from arcade.core.config import config
|
||||
from arcade.sdk.error import WeightError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from arcade.core.catalog import ToolCatalog
|
||||
from arcade.sdk.eval.critic import Critic
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExpectedToolCall:
|
||||
"""
|
||||
Represents an expected tool call with its name and arguments.
|
||||
|
||||
Attributes:
|
||||
name: The name of the tool.
|
||||
args: A dictionary containing the expected arguments for the tool.
|
||||
"""
|
||||
|
||||
name: str
|
||||
args: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalRubric:
|
||||
"""
|
||||
Defines the rubric for evaluating an AI model's performance on a task.
|
||||
|
||||
Attributes:
|
||||
fail_threshold: The minimum score required to pass the evaluation (between 0.0 and 1.0).
|
||||
warn_threshold: The score threshold for issuing a warning (between 0.0 and 1.0).
|
||||
fail_on_tool_selection: Whether to fail the evaluation if the tool selection is incorrect.
|
||||
fail_on_tool_call_quantity: Whether to fail the evaluation if the number of tool calls is incorrect.
|
||||
tool_selection_weight: The weight assigned to the tool selection score (between 0.0 and 1.0).
|
||||
"""
|
||||
|
||||
fail_threshold: float = 0.8
|
||||
warn_threshold: float = 0.9
|
||||
fail_on_tool_selection: bool = True
|
||||
fail_on_tool_call_quantity: bool = True
|
||||
tool_selection_weight: float = 1.0
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Fail threshold: {self.fail_threshold}\nWarn threshold: {self.warn_threshold}\n"
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaluationResult:
|
||||
"""
|
||||
Represents the result of an evaluation case.
|
||||
|
||||
Attributes:
|
||||
score: The normalized evaluation score (0.0-1.0).
|
||||
passed: Whether the evaluation passed based on the fail_threshold.
|
||||
warning: Whether the evaluation issued a warning based on the warn_threshold.
|
||||
results: A list of dictionaries containing the results for each critic.
|
||||
"""
|
||||
|
||||
score: float = 0.0
|
||||
passed: bool = False
|
||||
warning: bool = False
|
||||
results: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def fail(self) -> bool:
|
||||
return not self.passed and not self.warning
|
||||
|
||||
def add(
|
||||
self,
|
||||
field: str,
|
||||
result: dict[str, Any],
|
||||
weight: float,
|
||||
expected: Any,
|
||||
actual: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Add a critic result to the list of critic results.
|
||||
|
||||
Args:
|
||||
field: The field name for the critic result.
|
||||
result: A dictionary containing the critic result.
|
||||
weight: The weight of the critic.
|
||||
expected: The expected value for the critic.
|
||||
actual: The actual value for the critic.
|
||||
"""
|
||||
self.results.append(
|
||||
{
|
||||
"field": field,
|
||||
**result,
|
||||
"weight": weight,
|
||||
"expected": expected,
|
||||
"actual": actual,
|
||||
}
|
||||
)
|
||||
|
||||
def score_tool_selection(self, expected: str, actual: str, weight: float) -> float:
|
||||
"""
|
||||
Score and record tool selection in results.
|
||||
|
||||
Args:
|
||||
expected: The expected tool name.
|
||||
actual: The actual tool name.
|
||||
weight: The weight for tool selection.
|
||||
|
||||
Returns:
|
||||
The score for the tool selection.
|
||||
"""
|
||||
score = weight if expected == actual else 0.0
|
||||
self.add(
|
||||
"tool_selection",
|
||||
{"match": expected == actual, "score": score},
|
||||
weight,
|
||||
expected,
|
||||
actual,
|
||||
)
|
||||
return score
|
||||
|
||||
def compute_final_score(self, total_weight: float) -> None:
|
||||
"""
|
||||
Compute the final score by normalizing the total score with the total weight.
|
||||
"""
|
||||
total_score = sum(result["score"] for result in self.results)
|
||||
self.score = total_score / total_weight if total_weight > 0 else 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalCase:
|
||||
"""
|
||||
Represents a single evaluation case within an EvalSuite.
|
||||
|
||||
Attributes:
|
||||
name: A descriptive name for this evaluation case.
|
||||
system_message: The system message to be sent to the AI model.
|
||||
user_message: The user input to be sent to the AI model.
|
||||
expected_tool_calls: A list of ExpectedToolCall objects representing the expected tool calls.
|
||||
critics: A list of Critic objects used to evaluate tool arguments.
|
||||
additional_messages: Optional list of additional context messages.
|
||||
rubric: An EvalRubric object defining pass/fail criteria and tool selection behavior.
|
||||
"""
|
||||
|
||||
name: str
|
||||
system_message: str
|
||||
user_message: str
|
||||
expected_tool_calls: list[ExpectedToolCall]
|
||||
critics: list["Critic"]
|
||||
additional_messages: list[dict[str, str]] = field(default_factory=list)
|
||||
rubric: EvalRubric = field(default_factory=EvalRubric)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._validate_critics()
|
||||
|
||||
def _validate_critics(self) -> None:
|
||||
"""
|
||||
Validate the sum of critic weights.
|
||||
|
||||
Raises:
|
||||
WeightError: If the sum of critic weights exceeds 1.0.
|
||||
"""
|
||||
total_weight = sum(critic.weight for critic in self.critics)
|
||||
if total_weight > 1.0:
|
||||
raise WeightError(f"Sum of critic weights must not exceed 1.0, got {total_weight}")
|
||||
|
||||
for critic in self.critics:
|
||||
if critic.weight < 0.1:
|
||||
raise WeightError(f"Critic weights should be at least 0.1, got {critic.weight}")
|
||||
|
||||
def check_tool_selection_failure(self, actual_tools: list[str]) -> bool:
|
||||
"""
|
||||
Check if tool selection failure should occur.
|
||||
|
||||
Args:
|
||||
actual_tools: The list of actual tool names used.
|
||||
|
||||
Returns:
|
||||
True if tool selection failure should occur, False otherwise.
|
||||
"""
|
||||
expected_tools = [tc.name for tc in self.expected_tool_calls]
|
||||
return self.rubric.fail_on_tool_selection and set(expected_tools) != set(actual_tools)
|
||||
|
||||
def check_tool_call_quantity_failure(self, actual_count: int) -> bool:
|
||||
"""
|
||||
Check if tool call quantity failure should occur.
|
||||
|
||||
Args:
|
||||
actual_count: The number of actual tool calls made.
|
||||
|
||||
Returns:
|
||||
True if tool call quantity failure should occur, False otherwise.
|
||||
"""
|
||||
expected_count = len(self.expected_tool_calls)
|
||||
return self.rubric.fail_on_tool_call_quantity and expected_count != actual_count
|
||||
|
||||
def evaluate(self, actual_tool_calls: list[tuple[str, dict[str, Any]]]) -> EvaluationResult:
|
||||
"""
|
||||
Evaluate the actual tool calls against the expected tool calls and critics.
|
||||
|
||||
Args:
|
||||
actual_tool_calls: A list of tuples containing the actual tool name and arguments.
|
||||
|
||||
Returns:
|
||||
An EvaluationResult object containing the evaluation results.
|
||||
"""
|
||||
evaluation_result = EvaluationResult()
|
||||
actual_tools = [tool for tool, _ in actual_tool_calls]
|
||||
|
||||
if self.check_tool_selection_failure(actual_tools):
|
||||
evaluation_result.score = 0.0
|
||||
evaluation_result.passed = False
|
||||
evaluation_result.warning = False
|
||||
return evaluation_result
|
||||
|
||||
actual_count = len(actual_tool_calls)
|
||||
if self.check_tool_call_quantity_failure(actual_count):
|
||||
evaluation_result.score = 0.0
|
||||
evaluation_result.passed = False
|
||||
evaluation_result.warning = False
|
||||
return evaluation_result
|
||||
|
||||
# Create a cost matrix for the assignment problem
|
||||
cost_matrix = self._create_cost_matrix(actual_tool_calls)
|
||||
|
||||
# Use the Linear Sum Assignment (LSA) algorithm to find the optimal assignment
|
||||
# The algorithm minimizes the cost of assigning each expected tool call to an actual tool call
|
||||
row_ind, col_ind = linear_sum_assignment(cost_matrix, maximize=True)
|
||||
|
||||
total_score = 0.0
|
||||
total_weight = 0.0
|
||||
|
||||
for i, j in zip(row_ind, col_ind):
|
||||
if i < len(self.expected_tool_calls) and j < len(actual_tool_calls):
|
||||
expected = self.expected_tool_calls[i]
|
||||
actual_tool, actual_args = actual_tool_calls[j]
|
||||
|
||||
tool_selection_score = evaluation_result.score_tool_selection(
|
||||
expected.name, actual_tool, self.rubric.tool_selection_weight
|
||||
)
|
||||
total_score += tool_selection_score
|
||||
total_weight += self.rubric.tool_selection_weight
|
||||
|
||||
# Evaluate arguments using critics
|
||||
for critic in self.critics:
|
||||
expected_value = expected.args.get(critic.critic_field)
|
||||
actual_value = actual_args.get(critic.critic_field)
|
||||
if expected_value is not None and actual_value is not None:
|
||||
result = critic.evaluate(expected_value, actual_value)
|
||||
total_score += result["score"]
|
||||
total_weight += critic.weight
|
||||
evaluation_result.add(
|
||||
critic.critic_field, result, critic.weight, expected_value, actual_value
|
||||
)
|
||||
|
||||
# Compute the final score using the method from EvaluationResult
|
||||
evaluation_result.compute_final_score(total_weight)
|
||||
|
||||
# Set the pass/fail status based on the fail_threshold
|
||||
evaluation_result.passed = evaluation_result.score >= self.rubric.fail_threshold
|
||||
|
||||
# Set the warning status based on the warn_threshold
|
||||
evaluation_result.warning = (
|
||||
not evaluation_result.passed and evaluation_result.score >= self.rubric.warn_threshold
|
||||
)
|
||||
|
||||
return evaluation_result
|
||||
|
||||
def _create_cost_matrix(
|
||||
self, actual_tool_calls: list[tuple[str, dict[str, Any]]]
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Create a cost matrix for the Hungarian algorithm.
|
||||
|
||||
This method computes the score for each possible pairing of expected and actual tool calls.
|
||||
The resulting matrix is used by the Hungarian algorithm to find the optimal assignment.
|
||||
|
||||
Args:
|
||||
actual_tool_calls: A list of tuples containing the actual tool calls and their arguments.
|
||||
|
||||
Returns:
|
||||
A numpy array representing the cost matrix.
|
||||
"""
|
||||
n = max(len(self.expected_tool_calls), len(actual_tool_calls))
|
||||
cost_matrix = np.zeros((n, n))
|
||||
|
||||
for i, expected in enumerate(self.expected_tool_calls):
|
||||
for j, (actual_tool, actual_args) in enumerate(actual_tool_calls):
|
||||
score = 0.0
|
||||
if expected.name == actual_tool:
|
||||
score += self.rubric.tool_selection_weight
|
||||
|
||||
for critic in self.critics:
|
||||
expected_value = expected.args.get(critic.critic_field)
|
||||
actual_value = actual_args.get(critic.critic_field)
|
||||
if expected_value is not None and actual_value is not None:
|
||||
result = critic.evaluate(expected_value, actual_value)
|
||||
score += result["score"]
|
||||
cost_matrix[i, j] = score
|
||||
|
||||
return cost_matrix
|
||||
|
||||
async def run_async(
|
||||
self, client: AsyncArcade, model: str, tool_names: list[str]
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Run the evaluation case asynchronously.
|
||||
|
||||
Args:
|
||||
client: The AsyncArcade client instance.
|
||||
model: The model to evaluate.
|
||||
tool_names: The list of tool names to use for the evaluation.
|
||||
Returns:
|
||||
A dictionary containing the evaluation result for the case.
|
||||
"""
|
||||
messages = [{"role": "system", "content": self.system_message}]
|
||||
messages.extend(list(self.additional_messages))
|
||||
messages.append({"role": "user", "content": self.user_message})
|
||||
|
||||
response = await client.chat.completions.create( # type: ignore[call-overload]
|
||||
model=model,
|
||||
messages=messages,
|
||||
tool_choice="auto",
|
||||
tools=tool_names,
|
||||
user="eval_user",
|
||||
stream=False,
|
||||
)
|
||||
|
||||
predicted_args = get_tool_args(response)
|
||||
|
||||
evaluation = self.evaluate(predicted_args)
|
||||
|
||||
result = {
|
||||
"name": self.name,
|
||||
"input": self.user_message,
|
||||
"expected_tool_calls": [
|
||||
{"name": tc.name, "args": tc.args} for tc in self.expected_tool_calls
|
||||
],
|
||||
"predicted_tool_calls": [{"name": tool, "args": args} for tool, args in predicted_args],
|
||||
"evaluation": evaluation,
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def run_sync(self, client: Arcade, model: str, tool_names: list[str]) -> dict[str, Any]:
|
||||
"""
|
||||
Run the evaluation case synchronously.
|
||||
|
||||
Args:
|
||||
client: The Arcade client instance.
|
||||
model: The model to evaluate.
|
||||
tool_names: The list of tool names to use for the evaluation.
|
||||
Returns:
|
||||
A dictionary containing the evaluation result for the case.
|
||||
"""
|
||||
messages = [{"role": "system", "content": self.system_message}]
|
||||
messages.extend(list(self.additional_messages))
|
||||
messages.append({"role": "user", "content": self.user_message})
|
||||
|
||||
response = client.chat.completions.create( # type: ignore[call-overload]
|
||||
model=model,
|
||||
messages=messages,
|
||||
tool_choice="auto",
|
||||
tools=tool_names,
|
||||
user="eval_user",
|
||||
stream=False,
|
||||
)
|
||||
|
||||
predicted_args = get_tool_args(response)
|
||||
|
||||
evaluation = self.evaluate(predicted_args)
|
||||
|
||||
result = {
|
||||
"name": self.name,
|
||||
"input": self.user_message,
|
||||
"expected_tool_calls": [
|
||||
{"name": tc.name, "args": tc.args} for tc in self.expected_tool_calls
|
||||
],
|
||||
"predicted_tool_calls": [{"name": tool, "args": args} for tool, args in predicted_args],
|
||||
"evaluation": evaluation,
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalSuite:
|
||||
"""
|
||||
A suite for evaluating AI model performance on specific tasks or scenarios.
|
||||
|
||||
EvalSuite manages a collection of EvalCases, each representing a specific test scenario.
|
||||
It provides methods to add cases, register tools, and run evaluations against specified models.
|
||||
|
||||
Attributes:
|
||||
name: The name of the evaluation suite.
|
||||
system_message: The system message to be used for all cases in this suite.
|
||||
catalog: A ToolCatalog object containing registered tools.
|
||||
cases: A list of EvalCase objects representing individual test scenarios.
|
||||
tool_choice: The tool choice mode for the AI model ("auto" or "function").
|
||||
rubric: The evaluation rubric for this case.
|
||||
max_concurrent: Maximum number of concurrent evaluations.
|
||||
"""
|
||||
|
||||
name: str
|
||||
system_message: str
|
||||
catalog: "ToolCatalog"
|
||||
cases: list[EvalCase] = field(default_factory=list)
|
||||
rubric: EvalRubric = field(default_factory=EvalRubric)
|
||||
max_concurrent: int = 1 # Default to sequential execution
|
||||
_client: AsyncArcade | Arcade | None = None
|
||||
|
||||
def initialize_client(self) -> None:
|
||||
"""
|
||||
Initialize the client instance for the EvalSuite.
|
||||
"""
|
||||
if self.max_concurrent > 1:
|
||||
self._client = AsyncArcade(
|
||||
api_key=config.api.key,
|
||||
base_url=config.engine_url,
|
||||
)
|
||||
else:
|
||||
self._client = Arcade(
|
||||
api_key=config.api.key,
|
||||
base_url=config.engine_url,
|
||||
)
|
||||
|
||||
def add_case(
|
||||
self,
|
||||
name: str,
|
||||
user_message: str,
|
||||
expected_tool_calls: list[ExpectedToolCall],
|
||||
critics: list["Critic"],
|
||||
system_message: str | None = None,
|
||||
rubric: EvalRubric | None = None,
|
||||
additional_messages: list[dict[str, str]] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Add a new evaluation case to the suite.
|
||||
|
||||
Args:
|
||||
name: The name of the evaluation case.
|
||||
user_message: The user's input message.
|
||||
system_message: The system message to be sent to the AI model.
|
||||
expected_tool_calls: A list of expected tool calls.
|
||||
critics: List of critics to evaluate the tool arguments.
|
||||
rubric: The evaluation rubric for this case.
|
||||
additional_messages: Optional list of additional messages for context.
|
||||
"""
|
||||
case = EvalCase(
|
||||
name=name,
|
||||
system_message=system_message or self.system_message,
|
||||
user_message=user_message,
|
||||
expected_tool_calls=expected_tool_calls,
|
||||
rubric=rubric or self.rubric,
|
||||
critics=critics,
|
||||
additional_messages=additional_messages or [],
|
||||
)
|
||||
self.cases.append(case)
|
||||
|
||||
def extend_case(
|
||||
self,
|
||||
name: str,
|
||||
user_message: str,
|
||||
system_message: str | None = None,
|
||||
expected_tool_calls: list[ExpectedToolCall] | None = None,
|
||||
rubric: EvalRubric | None = None,
|
||||
critics: list["Critic"] | None = None,
|
||||
additional_messages: list[dict[str, str]] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Extend the last added case with new information.
|
||||
|
||||
Args:
|
||||
name: The name of the extended case.
|
||||
user_message: The new user message for this extended case.
|
||||
system_message: The new system message for this extended case.
|
||||
expected_tool_calls: New or updated expected tool calls.
|
||||
rubric: A new rubric (if different from the last case).
|
||||
critics: New critics (if different from the last case).
|
||||
additional_messages: New additional messages (if different from the last case).
|
||||
to be added before the new user message.
|
||||
"""
|
||||
if not self.cases:
|
||||
raise ValueError("No cases to extend. Add a case first.")
|
||||
|
||||
last_case = self.cases[-1]
|
||||
|
||||
# Create a new message list with the previous case's messages and user message
|
||||
new_additional_messages = [
|
||||
*last_case.additional_messages,
|
||||
]
|
||||
if additional_messages:
|
||||
new_additional_messages.extend(additional_messages)
|
||||
|
||||
# Create a new case, copying from the last one and updating fields
|
||||
new_case = EvalCase(
|
||||
name=name,
|
||||
system_message=system_message or last_case.system_message,
|
||||
user_message=user_message,
|
||||
expected_tool_calls=expected_tool_calls or last_case.expected_tool_calls,
|
||||
rubric=rubric or self.rubric,
|
||||
critics=critics or last_case.critics.copy(),
|
||||
additional_messages=new_additional_messages,
|
||||
)
|
||||
|
||||
self.cases.append(new_case)
|
||||
|
||||
async def run_async(self, model: str) -> dict[str, Any]:
|
||||
"""
|
||||
Run the evaluation suite asynchronously.
|
||||
|
||||
Args:
|
||||
model: The model to evaluate.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the evaluation results.
|
||||
"""
|
||||
if not self._client:
|
||||
raise ValueError("Client not initialized. Call initialize_client() first.")
|
||||
|
||||
results: dict[str, Any] = {"model": model, "rubric": self.rubric, "cases": []}
|
||||
|
||||
semaphore = asyncio.Semaphore(self.max_concurrent)
|
||||
tool_names = list(self.catalog.tools.keys())
|
||||
|
||||
async def sem_task(case: EvalCase) -> dict[str, Any]:
|
||||
async with semaphore:
|
||||
return await case.run_async(self._client, model, tool_names) # type: ignore[arg-type]
|
||||
|
||||
tasks = [sem_task(case) for case in self.cases]
|
||||
case_results = await asyncio.gather(*tasks)
|
||||
|
||||
results["cases"] = case_results
|
||||
return results
|
||||
|
||||
def run_sync(self, model: str) -> dict[str, Any]:
|
||||
"""
|
||||
Run the evaluation suite synchronously.
|
||||
|
||||
Args:
|
||||
model: The model to evaluate.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the evaluation results.
|
||||
"""
|
||||
if not self._client:
|
||||
raise ValueError("Client not initialized. Call initialize_client() first.")
|
||||
|
||||
cases: list[dict[str, Any]] = []
|
||||
results = {"model": model, "rubric": self.rubric, "cases": cases}
|
||||
tool_names = list(self.catalog.tools.keys())
|
||||
for case in self.cases:
|
||||
result = case.run_sync(self._client, model, tool_names) # type: ignore[arg-type]
|
||||
cases.append(result)
|
||||
|
||||
return results
|
||||
|
||||
def run(self, model: str) -> dict[str, Any]:
|
||||
"""
|
||||
Run the evaluation suite.
|
||||
|
||||
Args:
|
||||
model: The model to evaluate.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the evaluation results.
|
||||
"""
|
||||
if not self._client:
|
||||
self.initialize_client()
|
||||
|
||||
if self.max_concurrent > 1:
|
||||
# Run asynchronously with concurrency
|
||||
return asyncio.run(self.run_async(model))
|
||||
else:
|
||||
# Run synchronously
|
||||
return self.run_sync(model)
|
||||
|
||||
|
||||
def get_tool_args(chat_completion: Any) -> list[tuple[str, dict[str, Any]]]:
|
||||
"""
|
||||
Returns the tool arguments from the chat completion object.
|
||||
|
||||
Args:
|
||||
chat_completion: The chat completion object.
|
||||
|
||||
Returns:
|
||||
A list of tuples containing the tool name and arguments.
|
||||
"""
|
||||
tool_args_list: list[tuple[str, dict[str, Any]]] = []
|
||||
message = chat_completion.choices[0].message
|
||||
if message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
tool_args_list.append(
|
||||
(
|
||||
tool_call.function.name,
|
||||
json.loads(tool_call.function.arguments),
|
||||
)
|
||||
)
|
||||
return tool_args_list
|
||||
|
||||
|
||||
def tool_eval() -> Callable[[Callable], Callable]:
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
def wrapper(
|
||||
models: list[str],
|
||||
max_concurrency: int = 1,
|
||||
) -> list[dict[str, Any]]:
|
||||
suite = func()
|
||||
if not isinstance(suite, EvalSuite):
|
||||
raise TypeError("Eval function must return an EvalSuite")
|
||||
suite.max_concurrent = max_concurrency
|
||||
results = []
|
||||
for model in models:
|
||||
result = suite.run(model)
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
wrapper.__tool_eval__ = True # type: ignore[attr-defined]
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
|
@ -24,9 +24,12 @@ requests = "^2.26.0" # TODO: is this really needed?
|
|||
openai = "^1.36.0" # TODO: relax to an earlier version that still has what we need
|
||||
pyjwt = "^2.8.0"
|
||||
|
||||
fastapi = { version = "^0.110.0", optional = true }
|
||||
flask = { version = "^3.0.3", optional = true }
|
||||
|
||||
[tool.poetry.group.fastapi.dependencies]
|
||||
fastapi = "^0.110.0"
|
||||
|
||||
[tool.poetry.group.flask.dependencies]
|
||||
flask = "^3.0.3"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest = "^8.1.1"
|
||||
|
|
@ -41,16 +44,17 @@ mkdocs = ">=1.5.2"
|
|||
mkdocs-material = ">=9.3.0"
|
||||
mkdocstrings = {extras = ["python"], version = ">=0.23.1"}
|
||||
|
||||
[tool.poetry.extras]
|
||||
fastapi = ["fastapi"]
|
||||
flask = ["flask"]
|
||||
|
||||
[tool.poetry.group.evals.dependencies]
|
||||
scipy = "^1.14.0"
|
||||
numpy = "^2.0.0"
|
||||
scikit-learn = "^1.5.0"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
arcade = "arcade.cli.main:cli"
|
||||
|
||||
[tool.mypy]
|
||||
files = ["arcade"]
|
||||
python_version = "3.10"
|
||||
disallow_untyped_defs = "True"
|
||||
disallow_any_unimported = "True"
|
||||
no_implicit_optional = "True"
|
||||
|
|
@ -58,6 +62,7 @@ check_untyped_defs = "True"
|
|||
warn_return_any = "True"
|
||||
warn_unused_ignores = "True"
|
||||
show_error_codes = "True"
|
||||
ignore_missing_imports = "True"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
|
|
|
|||
|
|
@ -40,9 +40,9 @@ TOOL_DEFINITION_DATA = {
|
|||
"input_schema": {"type": "object", "properties": {"n_emails": {"type": "integer"}}},
|
||||
"output_schema": {"type": "array", "items": {"type": "string"}},
|
||||
"version": "0.1.0",
|
||||
"inputs": {"parameters": []}, # Update this line
|
||||
"inputs": {"parameters": []},
|
||||
"output": {},
|
||||
"requirements": {"auth_requirements": []}, # Update this line
|
||||
"requirements": {"auth_requirements": []},
|
||||
}
|
||||
|
||||
TOOL_AUTHORIZE_RESPONSE_DATA = {
|
||||
|
|
|
|||
342
arcade/tests/sdk/test_eval.py
Normal file
342
arcade/tests/sdk/test_eval.py
Normal file
|
|
@ -0,0 +1,342 @@
|
|||
import pytest
|
||||
|
||||
from arcade.sdk.error import WeightError
|
||||
from arcade.sdk.eval import (
|
||||
BinaryCritic,
|
||||
EvalRubric,
|
||||
ExpectedToolCall,
|
||||
NumericCritic,
|
||||
SimilarityCritic,
|
||||
)
|
||||
from arcade.sdk.eval.eval import EvalCase, EvaluationResult
|
||||
|
||||
# Test BinaryCritic.evaluate()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"expected, actual, weight, expected_match, expected_score",
|
||||
[
|
||||
("value", "value", 1.0, True, 1.0),
|
||||
("value", "different", 1.0, False, 0.0),
|
||||
(10, 10, 0.5, True, 0.5),
|
||||
(10, 20, 0.5, False, 0.0),
|
||||
],
|
||||
)
|
||||
def test_binary_critic_evaluate(expected, actual, weight, expected_match, expected_score):
|
||||
"""
|
||||
Test the BinaryCritic's evaluate method to ensure it correctly computes
|
||||
the match and score based on expected and actual values.
|
||||
"""
|
||||
critic = BinaryCritic(critic_field="test_field", weight=weight)
|
||||
result = critic.evaluate(expected=expected, actual=actual)
|
||||
assert result["match"] == expected_match
|
||||
assert result["score"] == expected_score
|
||||
|
||||
|
||||
# Test NumericCritic.evaluate()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"expected, actual, value_range, weight, match_threshold, expected_match, expected_score",
|
||||
[
|
||||
(5, 5, (0, 10), 1.0, 0.8, True, 1.0),
|
||||
(5, 6, (0, 10), 1.0, 0.8, True, 0.9),
|
||||
(0, 10, (0, 10), 1.0, 0.8, False, 0.0),
|
||||
(2, 8, (0, 10), 1.0, 0.5, False, 0.4),
|
||||
(50, 60, (0, 100), 0.5, 0.9, True, 0.45),
|
||||
],
|
||||
)
|
||||
def test_numeric_critic_evaluate(
|
||||
expected, actual, value_range, weight, match_threshold, expected_match, expected_score
|
||||
):
|
||||
"""
|
||||
Test the NumericCritic's evaluate method to ensure it calculates
|
||||
the correct score based on the proportion of the difference between
|
||||
expected and actual values within a specified range.
|
||||
"""
|
||||
critic = NumericCritic(
|
||||
critic_field="number",
|
||||
weight=weight,
|
||||
value_range=value_range,
|
||||
match_threshold=match_threshold,
|
||||
)
|
||||
result = critic.evaluate(expected=expected, actual=actual)
|
||||
assert result["match"] == expected_match
|
||||
assert pytest.approx(result["score"], 0.01) == expected_score
|
||||
|
||||
|
||||
# Test SimilarityCritic.evaluate()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"expected, actual, weight, similarity_threshold, expected_match, min_expected_score",
|
||||
[
|
||||
("hello world", "hello world", 1.0, 0.8, True, 1.0),
|
||||
("hello world", "hello", 1.0, 0.8, False, 0.0),
|
||||
("The quick brown fox", "The quick brown fox jumps over the lazy dog", 1.0, 0.5, True, 0.5),
|
||||
("data science", "machine learning", 0.5, 0.3, False, 0.0),
|
||||
],
|
||||
)
|
||||
def test_similarity_critic_evaluate(
|
||||
expected, actual, weight, similarity_threshold, expected_match, min_expected_score
|
||||
):
|
||||
"""
|
||||
Test the SimilarityCritic's evaluate method to ensure it computes
|
||||
the similarity score between expected and actual strings and determines
|
||||
the match correctly based on the similarity threshold.
|
||||
"""
|
||||
critic = SimilarityCritic(
|
||||
critic_field="text",
|
||||
weight=weight,
|
||||
similarity_threshold=similarity_threshold,
|
||||
)
|
||||
result = critic.evaluate(expected=expected, actual=actual)
|
||||
assert result["match"] == expected_match
|
||||
assert result["score"] >= min_expected_score
|
||||
assert result["score"] >= 0.0
|
||||
assert result["score"] <= weight + 1e-6 # Allow a small epsilon for floating-point comparison
|
||||
|
||||
|
||||
# Test EvaluationResult accumulation and pass/fail logic
|
||||
|
||||
|
||||
def test_evaluation_result_accumulation():
|
||||
"""
|
||||
Test that EvaluationResult correctly accumulates scores and determines
|
||||
pass/fail status based on thresholds.
|
||||
"""
|
||||
evaluation = EvaluationResult()
|
||||
evaluation.add(
|
||||
field="field1",
|
||||
result={"match": True, "score": 0.8},
|
||||
weight=1.0,
|
||||
expected="expected_value",
|
||||
actual="actual_value",
|
||||
)
|
||||
evaluation.add(
|
||||
field="field2",
|
||||
result={"match": False, "score": 0.0},
|
||||
weight=0.5,
|
||||
expected="expected_value",
|
||||
actual="actual_value",
|
||||
)
|
||||
total_weight = 1.5
|
||||
expected_score = (0.8 * 1.0 + 0.0 * 0.5) / total_weight
|
||||
evaluation.compute_final_score(total_weight)
|
||||
assert evaluation.score == expected_score
|
||||
|
||||
|
||||
# Test EvalCase.evaluate()
|
||||
|
||||
|
||||
def test_eval_case_evaluate():
|
||||
"""
|
||||
Test EvalCase's evaluate method to ensure it calculates the overall score
|
||||
correctly based on tool selection and critics, and applies the rubric's
|
||||
thresholds to determine pass/fail/warning status.
|
||||
"""
|
||||
# Define expected tool calls and actual tool calls
|
||||
expected_tool_calls = [
|
||||
ExpectedToolCall(name="ToolA", args={"param": "value1"}),
|
||||
ExpectedToolCall(name="ToolB", args={"param": "value2"}),
|
||||
]
|
||||
actual_tool_calls = [
|
||||
("ToolA", {"param": "value1"}),
|
||||
("ToolB", {"param": "wrong_value"}),
|
||||
]
|
||||
|
||||
# Define critics
|
||||
critics = [
|
||||
BinaryCritic(critic_field="param", weight=1.0),
|
||||
]
|
||||
|
||||
# Create EvalCase with a rubric
|
||||
case = EvalCase(
|
||||
name="TestCase",
|
||||
system_message="System message",
|
||||
user_message="User message",
|
||||
expected_tool_calls=expected_tool_calls,
|
||||
critics=critics,
|
||||
rubric=EvalRubric(fail_threshold=0.75, warn_threshold=0.9, tool_selection_weight=1.0),
|
||||
)
|
||||
|
||||
# Evaluate the case
|
||||
result = case.evaluate(actual_tool_calls)
|
||||
|
||||
# Expected calculations:
|
||||
# - Tool selection score should be 2 * 1.0 = 2.0 (both tools are correct)
|
||||
# - First critic score: match (1.0)
|
||||
# - Second critic score: no match (0.0)
|
||||
# - Total critic score: 1.0 + 0.0 = 1.0
|
||||
# - Total weight: tool selection (2.0) + critics (2.0) = 4.0
|
||||
# - Total score: (2.0 + 1.0) / 4.0 = 0.75
|
||||
|
||||
assert result.score == 0.75
|
||||
assert result.passed is True
|
||||
|
||||
|
||||
# Test EvalCase with mismatched tool calls
|
||||
|
||||
|
||||
def test_eval_case_evaluate_mismatched_tools():
|
||||
"""
|
||||
Test EvalCase's evaluate method when the actual tool calls do not match
|
||||
the expected tool calls to ensure tool selection scoring is correct.
|
||||
"""
|
||||
expected_tool_calls = [
|
||||
ExpectedToolCall(name="ToolA", args={"param": "value"}),
|
||||
]
|
||||
actual_tool_calls = [
|
||||
("ToolB", {"param": "value"}),
|
||||
]
|
||||
|
||||
critics = [BinaryCritic(critic_field="param", weight=1.0)]
|
||||
|
||||
case = EvalCase(
|
||||
name="TestCase",
|
||||
system_message="",
|
||||
user_message="",
|
||||
expected_tool_calls=expected_tool_calls,
|
||||
critics=critics,
|
||||
rubric=EvalRubric(tool_selection_weight=1.0),
|
||||
)
|
||||
|
||||
result = case.evaluate(actual_tool_calls)
|
||||
|
||||
# Tool selection score should be 0.0 since the tools don't match
|
||||
# Critic is not evaluated since the tool selection failed
|
||||
# Total score: 0.0
|
||||
|
||||
assert result.score == 0.0
|
||||
assert result.passed is False
|
||||
|
||||
|
||||
# Test EvalCase with multiple critics and weights
|
||||
|
||||
|
||||
def test_eval_case_multiple_critics():
|
||||
"""
|
||||
Test EvalCase's evaluate method with multiple critics having different weights
|
||||
to ensure individual critic scores are correctly combined into the total score.
|
||||
"""
|
||||
expected_tool_calls = [
|
||||
ExpectedToolCall(name="ToolA", args={"param1": "value1", "param2": "value2"}),
|
||||
]
|
||||
actual_tool_calls = [
|
||||
("ToolA", {"param1": "value1", "param2": "wrong_value"}),
|
||||
]
|
||||
|
||||
critics = [
|
||||
BinaryCritic(critic_field="param1", weight=0.6),
|
||||
SimilarityCritic(critic_field="param2", weight=0.4, similarity_threshold=0.8),
|
||||
]
|
||||
|
||||
case = EvalCase(
|
||||
name="TestCase",
|
||||
system_message="",
|
||||
user_message="",
|
||||
expected_tool_calls=expected_tool_calls,
|
||||
critics=critics,
|
||||
rubric=EvalRubric(fail_threshold=0.7),
|
||||
)
|
||||
|
||||
result = case.evaluate(actual_tool_calls)
|
||||
|
||||
# Tool selection score: 1.0
|
||||
# Critic scores:
|
||||
# - param1: match (score 0.6)
|
||||
# - param2: likely not match (score ~0.0)
|
||||
# Total score: (1.0 + 0.6 + 0.0) / (1.0 + 0.6 + 0.4) = 1.6 / 2.0 = 0.8
|
||||
|
||||
assert pytest.approx(result.score, 0.01) == 0.8
|
||||
assert result.passed
|
||||
|
||||
|
||||
# Test EvalCase with missing expected and actual values in args
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"expected_args, actual_args, expected_score",
|
||||
[
|
||||
({"param": "value"}, {}, 1.0), # Missing actual value
|
||||
({}, {"param": "value"}, 1.0), # Missing expected value
|
||||
({"param": "value"}, {"param": "value"}, 2.0), # Both values present
|
||||
],
|
||||
)
|
||||
def test_eval_case_missing_values(expected_args, actual_args, expected_score):
|
||||
"""
|
||||
Test that when either expected or actual values are missing for a critic,
|
||||
the critic evaluation is skipped, and the total score is computed accordingly.
|
||||
"""
|
||||
expected_tool_calls = [ExpectedToolCall(name="ToolA", args=expected_args)]
|
||||
actual_tool_calls = [("ToolA", actual_args)]
|
||||
|
||||
critics = [BinaryCritic(critic_field="param", weight=1.0)]
|
||||
|
||||
case = EvalCase(
|
||||
name="TestCase",
|
||||
system_message="",
|
||||
user_message="",
|
||||
expected_tool_calls=expected_tool_calls,
|
||||
critics=critics,
|
||||
rubric=EvalRubric(tool_selection_weight=1.0),
|
||||
)
|
||||
|
||||
result = case.evaluate(actual_tool_calls)
|
||||
|
||||
# If critic is skipped, only tool selection score is counted
|
||||
# Otherwise, tool selection + critic score
|
||||
total_weight = 1.0 # At least tool selection weight
|
||||
if "param" in expected_args and "param" in actual_args:
|
||||
total_weight += 1.0 # Critic weight
|
||||
|
||||
expected_total_score = expected_score / total_weight
|
||||
|
||||
assert result.score == expected_total_score
|
||||
|
||||
|
||||
# Test that WeightError is raised for invalid critic weights
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"critic_class, weight",
|
||||
[
|
||||
(BinaryCritic, -0.1),
|
||||
(BinaryCritic, 1.1),
|
||||
(NumericCritic, -0.5),
|
||||
(SimilarityCritic, 1.5),
|
||||
],
|
||||
)
|
||||
def test_critic_invalid_weight(critic_class, weight):
|
||||
"""
|
||||
Test that initializing a critic with an invalid weight raises a WeightError.
|
||||
"""
|
||||
with pytest.raises(WeightError):
|
||||
if critic_class == NumericCritic:
|
||||
critic_class(critic_field="test_field", weight=weight, value_range=(0, 1))
|
||||
elif critic_class == SimilarityCritic:
|
||||
critic_class(critic_field="test_field", weight=weight)
|
||||
else:
|
||||
critic_class(critic_field="test_field", weight=weight)
|
||||
|
||||
|
||||
# Test NumericCritic with invalid value range
|
||||
|
||||
|
||||
def test_numeric_critic_invalid_range():
|
||||
"""
|
||||
Test that initializing a NumericCritic with an invalid value range raises a ValueError.
|
||||
"""
|
||||
with pytest.raises(ValueError):
|
||||
NumericCritic(critic_field="number", weight=1.0, value_range=(10, 0)) # Invalid range
|
||||
|
||||
|
||||
# Test SimilarityCritic with unsupported metric
|
||||
|
||||
|
||||
def test_similarity_critic_unsupported_metric():
|
||||
"""
|
||||
Test that initializing a SimilarityCritic with an unsupported metric raises a ValueError.
|
||||
"""
|
||||
with pytest.raises(ValueError):
|
||||
SimilarityCritic(critic_field="text", weight=1.0, metric="unsupported_metric")
|
||||
|
|
@ -133,6 +133,13 @@ def func_with_optional_param_with_default_value(
|
|||
pass
|
||||
|
||||
|
||||
@tool(desc="A function with an optional input parameter with bar syntax")
|
||||
def func_with_optional_param_with_bar_syntax(
|
||||
param1: Annotated[str | None, "First param"] = None,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@tool(desc="A function with multiple parameters, some with default values")
|
||||
def func_with_mixed_params(
|
||||
context: ToolContext,
|
||||
|
|
@ -456,6 +463,26 @@ def func_with_complex_return() -> dict[str, str]:
|
|||
},
|
||||
id="func_with_optional_param_with_default_value",
|
||||
),
|
||||
pytest.param(
|
||||
func_with_optional_param_with_bar_syntax,
|
||||
{
|
||||
"inputs": ToolInputs(
|
||||
parameters=[
|
||||
InputParameter(
|
||||
name="param1",
|
||||
description="First param",
|
||||
inferrable=True,
|
||||
required=False, # Because of Optional[str]
|
||||
value_schema=ValueSchema(val_type="string", enum=None),
|
||||
)
|
||||
]
|
||||
),
|
||||
"output": ToolOutput(
|
||||
available_modes=["null"], description="No description provided."
|
||||
),
|
||||
},
|
||||
id="func_with_optional_param_with_bar_syntax",
|
||||
),
|
||||
pytest.param(
|
||||
func_with_mixed_params,
|
||||
{
|
||||
|
|
|
|||
44
examples/modal-deploy.py
Normal file
44
examples/modal-deploy.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
import os
|
||||
|
||||
from modal import App, Image, asgi_app
|
||||
|
||||
os.environ["WORK_DIR"] = "/root"
|
||||
|
||||
# Define the FastAPI app
|
||||
app = App("arcade-ai-actor")
|
||||
|
||||
|
||||
image = (
|
||||
Image.debian_slim()
|
||||
.copy_local_dir("./dist", "/root/dist")
|
||||
.pip_install("/root/dist/arcade_ai-0.1.0-py3-none-any.whl")
|
||||
.pip_install("/root/dist/arcade_gmail-0.1.0-py3-none-any.whl")
|
||||
.pip_install("/root/dist/arcade_websearch-0.1.0-py3-none-any.whl")
|
||||
.pip_install("/root/dist/arcade_github-0.1.0-py3-none-any.whl")
|
||||
.pip_install("/root/dist/arcade_slack-0.1.0-py3-none-any.whl")
|
||||
.pip_install("fastapi>=0.110.0")
|
||||
.pip_install("uvicorn>=0.24.0")
|
||||
.pip_install("pydantic>=2.7.0")
|
||||
.copy_local_file("./arcade.toml", "/root/arcade.toml")
|
||||
)
|
||||
|
||||
|
||||
@app.function(image=image)
|
||||
@asgi_app()
|
||||
def fastapi_app():
|
||||
from fastapi import FastAPI
|
||||
|
||||
from arcade.actor.fastapi.actor import FastAPIActor
|
||||
from arcade.core.toolkit import Toolkit
|
||||
|
||||
web_app = FastAPI()
|
||||
|
||||
# Initialize app and Arcade FastAPIActor
|
||||
actor = FastAPIActor(web_app)
|
||||
|
||||
# Register toolkits we've installed
|
||||
toolkits = Toolkit.find_all_arcade_toolkits()
|
||||
for toolkit in toolkits:
|
||||
actor.register_toolkit(toolkit)
|
||||
|
||||
return web_app
|
||||
133
toolkits/gmail/evals/eval_gmail_tools.py
Normal file
133
toolkits/gmail/evals/eval_gmail_tools.py
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
from arcade_gmail.tools.gmail import (
|
||||
DateRange,
|
||||
get_emails,
|
||||
search_emails_by_header,
|
||||
write_draft,
|
||||
)
|
||||
|
||||
from arcade.sdk.eval import (
|
||||
BinaryCritic,
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
ExpectedToolCall,
|
||||
NumericCritic,
|
||||
SimilarityCritic,
|
||||
tool_eval,
|
||||
)
|
||||
|
||||
# Evaluation rubric
|
||||
rubric = EvalRubric(
|
||||
fail_threshold=0.7,
|
||||
warn_threshold=0.9,
|
||||
)
|
||||
|
||||
|
||||
@tool_eval("gpt-3.5-turbo")
|
||||
def gmail_eval_suite():
|
||||
suite = EvalSuite(
|
||||
name="Gmail Tools Evaluation",
|
||||
system="You are an AI assistant with access to Gmail tools. Use them to help the user with their email-related tasks.",
|
||||
)
|
||||
|
||||
# Register the Gmail tools
|
||||
suite.register_tool(write_draft)
|
||||
suite.register_tool(search_emails_by_header)
|
||||
suite.register_tool(get_emails)
|
||||
|
||||
# Write Draft Scenarios
|
||||
suite.add_case(
|
||||
name="Write Draft with specified recipient, subject, and body",
|
||||
user_message="Draft and email to john@example.com asking if we can meet tomorrow at 2 PM",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
name="WriteDraft",
|
||||
args={
|
||||
"recipient": "john@example.com",
|
||||
"subject": "Meeting Tomorrow",
|
||||
"body": "Hi John, Can we meet tomorrow at 2 PM? Thanks, Alice",
|
||||
},
|
||||
)
|
||||
],
|
||||
rubric=rubric,
|
||||
critics=[
|
||||
BinaryCritic(critic_field="recipient", weight=0.5),
|
||||
SimilarityCritic(critic_field="subject", weight=0.2),
|
||||
SimilarityCritic(critic_field="body", weight=0.3),
|
||||
],
|
||||
)
|
||||
|
||||
# Search Emails by Header Scenarios
|
||||
suite.add_case(
|
||||
name="Search for emails from a specific sender and time period",
|
||||
user_message="Find emails from alice@example.com sent last week",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
name="SearchEmailsByHeader",
|
||||
args={
|
||||
"sender": "alice@example.com",
|
||||
"date_range": DateRange.LAST_7_DAYS.value,
|
||||
"limit": 25,
|
||||
},
|
||||
)
|
||||
],
|
||||
rubric=rubric,
|
||||
critics=[
|
||||
BinaryCritic(critic_field="sender", weight=0.5),
|
||||
BinaryCritic(critic_field="date_range", weight=0.4),
|
||||
NumericCritic(critic_field="limit", weight=0.1, value_range=(1, 100)),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Search by subject and date range",
|
||||
user_message="Search for emails with 'Urgent' in the subject from the last 30 days",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
name="SearchEmailsByHeader",
|
||||
args={
|
||||
"subject": "Urgent",
|
||||
"date_range": DateRange.LAST_30_DAYS.value,
|
||||
"limit": 25,
|
||||
},
|
||||
)
|
||||
],
|
||||
rubric=rubric,
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="subject", weight=0.4),
|
||||
BinaryCritic(critic_field="date_range", weight=0.4),
|
||||
NumericCritic(critic_field="limit", weight=0.2, value_range=(1, 100)),
|
||||
],
|
||||
)
|
||||
|
||||
suite.extend_case(
|
||||
name="Followup search by subject and date range",
|
||||
user_message="show me more of those",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
name="SearchEmailsByHeader",
|
||||
args={
|
||||
"subject": "Urgent",
|
||||
"date_range": DateRange.LAST_30_DAYS.value,
|
||||
"limit": 50,
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Retrieve specific number of emails",
|
||||
user_message="Retrieve the last 10 emails in my inbox",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
name="GetEmails",
|
||||
args={"n_emails": 10},
|
||||
)
|
||||
],
|
||||
rubric=rubric,
|
||||
critics=[
|
||||
BinaryCritic(critic_field="n_emails", weight=0.8),
|
||||
NumericCritic(critic_field="n_emails", weight=0.2, value_range=(1, 20)),
|
||||
],
|
||||
)
|
||||
|
||||
return suite
|
||||
191
toolkits/slack/evals/eval_slack_messaging.py
Normal file
191
toolkits/slack/evals/eval_slack_messaging.py
Normal file
|
|
@ -0,0 +1,191 @@
|
|||
from arcade_slack.tools.chat import send_dm_to_user, send_message_to_channel
|
||||
|
||||
from arcade.core.catalog import ToolCatalog
|
||||
from arcade.sdk.eval import (
|
||||
BinaryCritic,
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
ExpectedToolCall,
|
||||
SimilarityCritic,
|
||||
tool_eval,
|
||||
)
|
||||
|
||||
# Evaluation rubric
|
||||
rubric = EvalRubric(
|
||||
fail_threshold=0.8,
|
||||
warn_threshold=0.9,
|
||||
)
|
||||
|
||||
|
||||
catalog = ToolCatalog()
|
||||
# Register the Slack tools
|
||||
catalog.add_tool(send_dm_to_user)
|
||||
catalog.add_tool(send_message_to_channel)
|
||||
|
||||
|
||||
@tool_eval()
|
||||
def slack_eval_suite() -> EvalSuite:
|
||||
"""Create an evaluation suite for Slack messaging tools."""
|
||||
suite = EvalSuite(
|
||||
name="Slack Messaging Tools Evaluation",
|
||||
system_message="You are an AI assistant that can send direct messages and post messages to channels in Slack using the provided tools.",
|
||||
catalog=catalog,
|
||||
rubric=rubric,
|
||||
)
|
||||
|
||||
# Send DM to User Scenarios
|
||||
suite.add_case(
|
||||
name="Send DM to user with clear username",
|
||||
user_message="Send a direct message to johndoe saying 'Hello, can we meet at 3 PM?'",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
name="SendDmToUser",
|
||||
args={
|
||||
"user_name": "johndoe",
|
||||
"message": "Hello, can we meet at 3 PM?",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="user_name", weight=0.5),
|
||||
SimilarityCritic(critic_field="message", weight=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Send DM with ambiguous username",
|
||||
user_message="Message John about the project deadline",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
name="SendDmToUser",
|
||||
args={
|
||||
"user_name": "john",
|
||||
"message": "Hi John, I wanted to check about the project deadline. Can you provide an update?",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="user_name", weight=0.4),
|
||||
SimilarityCritic(critic_field="message", weight=0.6),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Send DM with username in different format",
|
||||
user_message="DM Jane.Doe to reschedule our meeting",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
name="SendDmToUser",
|
||||
args={
|
||||
"user_name": "jane.doe",
|
||||
"message": "Hi Jane, I need to reschedule our meeting. When are you available?",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="user_name", weight=0.5),
|
||||
SimilarityCritic(critic_field="message", weight=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
# Send Message to Channel Scenarios
|
||||
suite.add_case(
|
||||
name="Send message to channel with clear name",
|
||||
user_message="Post 'The new feature is now live!' in the #announcements channel",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
name="SendMessageToChannel",
|
||||
args={
|
||||
"channel_name": "announcements",
|
||||
"message": "The new feature is now live!",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="channel_name", weight=0.5),
|
||||
SimilarityCritic(critic_field="message", weight=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Send message to channel with ambiguous name",
|
||||
user_message="Inform the engineering team about the upcoming maintenance in the general channel",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
name="SendMessageToChannel",
|
||||
args={
|
||||
"channel_name": "engineering",
|
||||
"message": "Attention team: There will be upcoming maintenance. Please save your work and expect some downtime.",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="channel_name", weight=0.4),
|
||||
SimilarityCritic(critic_field="message", weight=0.6),
|
||||
],
|
||||
)
|
||||
|
||||
# Adversarial Scenarios
|
||||
suite.add_case(
|
||||
name="Ambiguous between DM and channel message",
|
||||
user_message="Send 'Great job on the presentation!' to the team",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
name="SendMessageToChannel",
|
||||
args={
|
||||
"channel_name": "general",
|
||||
"message": "Great job on the presentation!",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="channel_name", weight=0.4),
|
||||
SimilarityCritic(critic_field="message", weight=0.6),
|
||||
],
|
||||
)
|
||||
|
||||
# Multiple recipients in DM request
|
||||
suite.add_case(
|
||||
name="Multiple recipients in DM request",
|
||||
user_message="Send a DM to Alice and Bob about pushing the meeting tomorrow. I have to much work to do.",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
name="SendDmToUser",
|
||||
args={
|
||||
"user_name": "alice",
|
||||
"message": "Hi Alice, about our meeting tomorrow, let's reschedule? I am swamped with work.",
|
||||
},
|
||||
),
|
||||
ExpectedToolCall(
|
||||
name="SendDmToUser",
|
||||
args={
|
||||
"user_name": "bob",
|
||||
"message": "Hi Bob, about our meeting tomorrow, let's reschedule? I am swamped with work.",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="user_name", weight=0.4),
|
||||
SimilarityCritic(critic_field="message", weight=0.6),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Channel name similar to username",
|
||||
user_message="Post 'sounds great!' in john-project channel",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
name="SendMessageToChannel",
|
||||
args={
|
||||
"channel_name": "john-project",
|
||||
"message": "Sounds great!",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="channel_name", weight=0.5),
|
||||
SimilarityCritic(critic_field="message", weight=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
return suite
|
||||
Loading…
Reference in a new issue