arcade-mcp/arcade/arcade/sdk/eval/eval.py
Nate Barbettini 9d00295e33
Replace arcade.client with arcadepy (#119)
Closes: https://app.clickup.com/t/86b2k2962

---------

Co-authored-by: sdreyer <sterling@arcade-ai.com>
2024-10-23 15:29:02 -07:00

666 lines
24 KiB
Python

import asyncio
import functools
import inspect
import json
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable

from arcade.core.config_model import Config
from arcade.core.schema import TOOL_NAME_SEPARATOR

try:
    import numpy as np
    from scipy.optimize import linear_sum_assignment
except ImportError:
    raise ImportError(
        "Use `pip install arcade-ai[evals]` to install the required dependencies for evaluation."
    )

from openai import AsyncOpenAI

from arcade.sdk.error import WeightError

if TYPE_CHECKING:
    from arcade.core.catalog import ToolCatalog
    from arcade.sdk.eval.critic import Critic


@dataclass
class ExpectedToolCall:
    """
    A tool invocation the evaluated model is expected to make.

    Attributes:
        name: Name of the tool that should be called.
        args: Mapping of argument name to the value the tool is expected to receive.
    """

    name: str
    args: dict[str, Any]


@dataclass
class EvalRubric:
    """
    Pass/fail criteria applied when scoring an evaluation case.

    Attributes:
        fail_threshold: Minimum normalized score (0.0-1.0) required to pass.
        warn_threshold: Normalized score (0.0-1.0) used to decide whether a warning is issued.
        fail_on_tool_selection: Fail the case outright when the selected tools are wrong.
        fail_on_tool_call_quantity: Fail the case outright when the number of tool calls is wrong.
        tool_selection_weight: Weight (0.0-1.0) given to picking the right tool.
    """

    fail_threshold: float = 0.8
    warn_threshold: float = 0.9
    fail_on_tool_selection: bool = True
    fail_on_tool_call_quantity: bool = True
    tool_selection_weight: float = 1.0

    def __str__(self) -> str:
        """Render the two thresholds, one per line, each terminated by a newline."""
        summary_lines = (
            f"Fail threshold: {self.fail_threshold}",
            f"Warn threshold: {self.warn_threshold}",
        )
        return "".join(f"{line}\n" for line in summary_lines)


@dataclass
class EvaluationResult:
"""
Represents the result of an evaluation case.
Attributes:
score: The normalized evaluation score (0.0-1.0).
passed: Whether the evaluation passed based on the fail_threshold.
warning: Whether the evaluation issued a warning based on the warn_threshold.
results: A list of dictionaries containing the results for each critic.
failure_reason: If the evaluation failed completely due to settings in the rubric,
this field contains the reason for failure.
"""
score: float = 0.0
passed: bool = False
warning: bool = False
results: list[dict[str, Any]] = field(default_factory=list)
failure_reason: str | None = None
@property
def fail(self) -> bool:
return not self.passed and not self.warning
def add(
self,
field: str,
result: dict[str, Any],
weight: float,
expected: Any,
actual: Any,
) -> None:
"""
Add a critic result to the list of critic results.
Args:
field: The field name for the critic result.
result: A dictionary containing the critic result.
weight: The weight of the critic.
expected: The expected value for the critic.
actual: The actual value for the critic.
"""
self.results.append({
"field": field,
**result,
"weight": weight,
"expected": expected,
"actual": actual,
})
def score_tool_selection(self, expected: str, actual: str, weight: float) -> float:
"""
Score and record tool selection in results.
Args:
expected: The expected tool name.
actual: The actual tool name.
weight: The weight for tool selection.
Returns:
The score for the tool selection.
"""
score = weight if compare_tool_name(expected, actual) else 0.0
self.add(
"tool_selection",
{"match": compare_tool_name(expected, actual), "score": score},
weight,
expected,
actual,
)
return score
def compute_final_score(self, total_weight: float) -> None:
"""
Compute the final score by normalizing the total score with the total weight.
"""
total_score = sum(result["score"] for result in self.results)
self.score = total_score / total_weight if total_weight > 0 else 0.0
@dataclass
class EvalCase:
    """
    A single evaluation scenario within an EvalSuite.

    Attributes:
        name: A descriptive name for this evaluation case.
        system_message: The system message sent to the model.
        user_message: The user input sent to the model.
        expected_tool_calls: The tool calls the model is expected to make.
        critics: Critics used to score tool arguments; ``None`` is replaced by
            an empty list after initialization.
        additional_messages: Extra context messages.
        rubric: Pass/fail criteria and tool selection behavior.
    """

    name: str
    system_message: str
    user_message: str
    expected_tool_calls: list[ExpectedToolCall]
    critics: list["Critic"] | None = None
    additional_messages: list[dict[str, str]] = field(default_factory=list)
    rubric: EvalRubric = field(default_factory=EvalRubric)

    def __post_init__(self) -> None:
        """Validate critic weights, or default ``critics`` to an empty list when omitted."""
        if self.critics is None:
            self.critics = []
        else:
            self._validate_critics()

    def _validate_critics(self) -> None:
        """
        Validate the critic weights.

        Raises:
            WeightError: If the weights sum to more than 1.0, or any single
                weight is below 0.1.
        """
        if not self.critics:
            return

        total_weight = sum(critic.weight for critic in self.critics)
        if total_weight > 1.0:
            raise WeightError(f"Sum of critic weights must not exceed 1.0, got {total_weight}")

        for critic in self.critics:
            if critic.weight < 0.1:
                raise WeightError(f"Critic weights should be at least 0.1, got {critic.weight}")

    def check_tool_selection_failure(self, actual_tools: list[str]) -> bool:
        """
        Decide whether the case fails outright because the wrong tools were selected.

        Expected and actual names are each sorted and then compared pairwise.
        Pairing stops at the shorter list, so surplus names on either side
        are not examined here.

        Args:
            actual_tools: Names of the tools that were actually called.

        Returns:
            True if the rubric fails on tool selection and some pair does not
            match, False otherwise.
        """
        if not self.rubric.fail_on_tool_selection:
            return False

        def canonical(tool_name: str) -> str:
            # Sort by the same normalized, case-insensitive form that
            # compare_tool_name compares, so equivalent names always line up
            # regardless of separator style or letter case.
            return normalize_name(tool_name).lower()

        expected_sorted = sorted((tc.name for tc in self.expected_tool_calls), key=canonical)
        actual_sorted = sorted(actual_tools, key=canonical)
        return not all(
            compare_tool_name(expected, actual)
            for expected, actual in zip(expected_sorted, actual_sorted)
        )

    def check_tool_call_quantity_failure(self, actual_count: int) -> bool:
        """
        Decide whether the case fails outright because of the number of tool calls.

        Args:
            actual_count: The number of tool calls actually made.

        Returns:
            True if the rubric fails on quantity and the counts differ, False otherwise.
        """
        expected_count = len(self.expected_tool_calls)
        return self.rubric.fail_on_tool_call_quantity and expected_count != actual_count

    def evaluate(
        self,
        actual_tool_calls: list[tuple[str, dict[str, Any]]],
    ) -> EvaluationResult:
        """
        Evaluate the actual tool calls against the expected tool calls and critics.

        Args:
            actual_tool_calls: (tool name, arguments) tuples produced by the model.

        Returns:
            An EvaluationResult describing the score and pass/warn status.
        """
        evaluation_result = EvaluationResult()

        actual_tools = [tool_name for tool_name, _ in actual_tool_calls]
        actual_count = len(actual_tool_calls)

        # Hard failure: wrong number of tool calls.
        if self.check_tool_call_quantity_failure(actual_count):
            evaluation_result.score = 0.0
            evaluation_result.passed = False
            expected_count = len(self.expected_tool_calls)
            expected_tool_names = ", ".join(
                tool_call.name for tool_call in self.expected_tool_calls
            )
            evaluation_result.failure_reason = (
                f"Expected {expected_count} tool call(s), but got {actual_count}. "
                + f"\nExpected tool calls: {expected_tool_names}.\nActual tool calls: {', '.join(actual_tools)}"
            )
            return evaluation_result

        # Nothing expected and nothing called: trivially correct.
        if not self.expected_tool_calls and not actual_tools:
            evaluation_result.score = 1.0
            evaluation_result.passed = True
            return evaluation_result

        # Hard failure: wrong tools selected.
        if self.check_tool_selection_failure(actual_tools):
            evaluation_result.score = 0.0
            evaluation_result.passed = False
            expected_tools = [tc.name for tc in self.expected_tool_calls]
            evaluation_result.failure_reason = f"Tool selection mismatch. Expected tools: {expected_tools}, but got: {actual_tools}"
            return evaluation_result

        # Without critics there are no arguments to score.
        if not self.critics:
            evaluation_result.score = 1.0
            evaluation_result.passed = True
            return evaluation_result

        # Pair each expected call with the actual call that maximizes the
        # combined score (rows = expected, columns = actual).
        cost_matrix = self._create_cost_matrix(actual_tool_calls, self.expected_tool_calls)
        row_ind, col_ind = linear_sum_assignment(cost_matrix, maximize=True)

        num_expected = len(self.expected_tool_calls)
        num_actual = len(actual_tool_calls)
        total_weight = 0.0
        for i, j in zip(row_ind, col_ind):
            # The matrix is padded to a square; skip assignments that land on padding.
            if i >= num_expected or j >= num_actual:
                continue

            expected = self.expected_tool_calls[i]
            actual_name, actual_args = actual_tool_calls[j]

            # Tool selection
            evaluation_result.score_tool_selection(
                expected.name, actual_name, self.rubric.tool_selection_weight
            )
            total_weight += self.rubric.tool_selection_weight

            # Arguments, scored by the critics
            for critic in self.critics:
                expected_value = expected.args.get(critic.critic_field)
                actual_value = actual_args.get(critic.critic_field)

                # Always count the critic's weight, even when it raises:
                # a crashing critic is recorded with score 0.0 and must lower
                # the normalized score rather than drop out of the denominator.
                total_weight += critic.weight
                try:
                    result = critic.evaluate(expected_value, actual_value)
                    # Touch the score inside the try so a malformed result
                    # (missing or non-numeric "score") is handled like a failure.
                    _ = 0.0 + result["score"]
                except Exception as e:
                    # TODO: log or console
                    print(f"Critic evaluation failed for field '{critic.critic_field}': {e}")
                    result = {"match": False, "score": 0.0}

                evaluation_result.add(
                    critic.critic_field,
                    result,
                    critic.weight,
                    expected_value,
                    actual_value,
                )

        # Normalize the recorded scores by the total weight considered.
        evaluation_result.compute_final_score(total_weight)

        # Set pass/fail and warning status.
        # NOTE(review): as written, `warning` can only become True when
        # warn_threshold is below fail_threshold; with the EvalRubric defaults
        # (fail 0.8, warn 0.9) it is never set. Confirm the intended semantics
        # before changing this.
        evaluation_result.passed = evaluation_result.score >= self.rubric.fail_threshold
        evaluation_result.warning = (
            not evaluation_result.passed and evaluation_result.score >= self.rubric.warn_threshold
        )

        return evaluation_result

    def _create_cost_matrix(
        self,
        actual_tool_calls: list[tuple[str, dict[str, Any]]],
        expected_tool_calls: list[ExpectedToolCall],
    ) -> np.ndarray:
        """
        Build the square score matrix used for the assignment problem.

        Entry (i, j) is the score of pairing expected call i with actual call j:
        the tool selection weight when the names match, plus each critic's score
        for fields present (not None) on both sides. Padding cells stay at 0.

        Args:
            actual_tool_calls: (tool name, arguments) tuples produced by the model.
            expected_tool_calls: The expected tool calls.

        Returns:
            A numpy array of shape (n, n) where n is the larger of the two counts.
        """
        n = max(len(expected_tool_calls), len(actual_tool_calls))
        cost_matrix = np.zeros((n, n))

        for i, expected in enumerate(expected_tool_calls):
            for j, (actual_name, actual_args) in enumerate(actual_tool_calls):
                score = 0.0

                # Tool selection
                if compare_tool_name(expected.name, actual_name):
                    score += self.rubric.tool_selection_weight

                # Critics evaluation
                for critic in self.critics:  # type: ignore[union-attr]
                    expected_value = expected.args.get(critic.critic_field)
                    actual_value = actual_args.get(critic.critic_field)
                    if expected_value is None or actual_value is None:
                        continue
                    try:
                        result = critic.evaluate(expected_value, actual_value)
                        score += result.get("score", 0.0)
                    except Exception as e:
                        print(
                            f"Critic evaluation failed for field '{critic.critic_field}': {e}"
                        )

                cost_matrix[i, j] = score

        return cost_matrix


@dataclass
class EvalSuite:
    """
    A collection of EvalCases that can be run against a model.

    Provides methods to add cases (optionally building on the previous one)
    and to run every case against a given model.

    Attributes:
        name: The name of the evaluation suite.
        system_message: Default system message for cases that do not supply one.
        catalog: A ToolCatalog object containing the registered tools.
        cases: The evaluation cases in this suite.
        rubric: Default rubric for cases that do not supply one.
        max_concurrent: Maximum number of cases evaluated concurrently.
    """

    name: str
    system_message: str
    catalog: "ToolCatalog"
    cases: list[EvalCase] = field(default_factory=list)
    rubric: EvalRubric = field(default_factory=EvalRubric)
    max_concurrent: int = 1

    def _build_expected_tool_calls(
        self, expected_tool_calls: list[tuple[Callable, dict[str, Any]]]
    ) -> list[ExpectedToolCall]:
        """Convert (function, args) pairs into ExpectedToolCall objects with defaults filled in."""
        built = []
        for func, args in expected_tool_calls:
            args_with_defaults = self._fill_args_with_defaults(func, args)
            tool_name = str(self.catalog.find_tool_by_func(func).get_fully_qualified_name())
            built.append(ExpectedToolCall(name=tool_name, args=args_with_defaults))
        return built

    def add_case(
        self,
        name: str,
        user_message: str,
        expected_tool_calls: list[tuple[Callable, dict[str, Any]]],
        critics: list["Critic"] | None = None,
        system_message: str | None = None,
        rubric: EvalRubric | None = None,
        additional_messages: list[dict[str, str]] | None = None,
    ) -> None:
        """
        Add a new evaluation case to the suite.

        Args:
            name: The name of the evaluation case.
            user_message: The user's input message.
            expected_tool_calls: Expected tool calls as (function, args) tuples.
            critics: Critics used to evaluate the tool arguments.
            system_message: System message for this case; defaults to the suite's.
            rubric: Rubric for this case; defaults to the suite's.
            additional_messages: Optional additional context messages.
        """
        case = EvalCase(
            name=name,
            system_message=system_message or self.system_message,
            user_message=user_message,
            expected_tool_calls=self._build_expected_tool_calls(expected_tool_calls),
            rubric=rubric or self.rubric,
            critics=critics,
            additional_messages=additional_messages or [],
        )
        self.cases.append(case)

    def _fill_args_with_defaults(
        self, func: Callable, provided_args: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Produce a complete argument mapping for a tool function.

        Every parameter in the function's signature appears in the result:
        the provided value if given, else the declared default, else None.

        Args:
            func: The tool function.
            provided_args: The arguments that were supplied.

        Returns:
            A dictionary keyed by parameter name.
        """
        sig = inspect.signature(func)
        args_with_defaults = {}
        for param in sig.parameters.values():
            if param.name in provided_args:
                args_with_defaults[param.name] = provided_args[param.name]
            elif param.default is not inspect.Parameter.empty:
                args_with_defaults[param.name] = param.default
            else:
                args_with_defaults[param.name] = None  # or raise an error
        return args_with_defaults

    def extend_case(
        self,
        name: str,
        user_message: str,
        system_message: str | None = None,
        expected_tool_calls: list[tuple[Callable, dict[str, Any]]] | None = None,
        rubric: EvalRubric | None = None,
        critics: list["Critic"] | None = None,
        additional_messages: list[dict[str, str]] | None = None,
    ) -> None:
        """
        Add a new case that builds on the most recently added one.

        Anything not supplied is inherited from the last case.

        Args:
            name: The name of the extended case.
            user_message: The user message for the extended case.
            system_message: System message; inherited from the last case when omitted.
            expected_tool_calls: Expected tool calls; ``None`` inherits the last
                case's, while an explicit empty list means no tool call is expected.
            rubric: Rubric; inherited from the last case when omitted.
            critics: Critics; a copy of the last case's is used when omitted.
            additional_messages: Messages appended after the last case's
                additional messages, before the new user message.

        Raises:
            ValueError: If the suite has no case to extend.
        """
        if not self.cases:
            raise ValueError("No cases to extend. Add a case first.")

        last_case = self.cases[-1]

        # Start from the previous case's context messages, then append the new ones.
        new_additional_messages = [*last_case.additional_messages]
        if additional_messages:
            new_additional_messages.extend(additional_messages)

        if expected_tool_calls is not None:
            expected = self._build_expected_tool_calls(expected_tool_calls)
        else:
            # Copy so the two cases do not share one mutable list.
            expected = list(last_case.expected_tool_calls)

        # Create a new case, copying from the last one and updating fields.
        new_case = EvalCase(
            name=name,
            system_message=system_message or last_case.system_message,
            user_message=user_message,
            expected_tool_calls=expected,
            # Inherit from the previous case (not the suite default), matching
            # how system_message and critics are inherited.
            rubric=rubric or last_case.rubric,
            critics=critics or (last_case.critics.copy() if last_case.critics else None),
            additional_messages=new_additional_messages,
        )
        self.cases.append(new_case)

    async def run(self, client: AsyncOpenAI, model: str) -> dict[str, Any]:
        """
        Run every case in the suite against ``model``.

        Args:
            client: The AsyncOpenAI client instance.
            model: The model to evaluate.

        Returns:
            A dictionary with the keys "model", "rubric" and "cases"; "cases"
            holds one result dictionary per case, in suite order.

        Raises:
            ValueError: If the model calls a tool that is not in the catalog.
        """
        results: dict[str, Any] = {"model": model, "rubric": self.rubric, "cases": []}

        semaphore = asyncio.Semaphore(self.max_concurrent)
        # Built once and reused for every request.
        tool_names = [str(name) for name in self.catalog.get_tool_names()]

        async def sem_task(case: EvalCase) -> dict[str, Any]:
            async with semaphore:
                # Prepare messages
                messages = [{"role": "system", "content": case.system_message}]
                messages.extend(case.additional_messages)
                messages.append({"role": "user", "content": case.user_message})

                # Get the model response
                response = await client.chat.completions.create(  # type: ignore[call-overload]
                    model=model,
                    messages=messages,
                    tool_choice="auto",
                    tools=tool_names,
                    user="eval_user",
                    stream=False,
                )

                # Extract the predicted tool calls and fill in default arguments
                # so they are comparable with the expected calls.
                predicted_args = get_tool_args(response)
                filled_actual_tool_calls = []
                for tool_name, args in predicted_args:
                    tool = self.catalog.get_tool_by_name(tool_name)
                    if tool is None:
                        raise ValueError(f"Tool '{tool_name}' not found in catalog.")
                    func = tool.tool
                    args_with_defaults = self._fill_args_with_defaults(func, args)
                    filled_actual_tool_calls.append((tool_name, args_with_defaults))

                # Evaluate the case
                evaluation = case.evaluate(filled_actual_tool_calls)

                # Prepare the result
                return {
                    "name": case.name,
                    "input": case.user_message,
                    "expected_tool_calls": [
                        {"name": tc.name, "args": tc.args} for tc in case.expected_tool_calls
                    ],
                    "predicted_tool_calls": [
                        {"name": name, "args": args} for name, args in filled_actual_tool_calls
                    ],
                    "evaluation": evaluation,
                }

        tasks = [sem_task(case) for case in self.cases]
        results["cases"] = await asyncio.gather(*tasks)

        return results


def get_tool_args(chat_completion: Any) -> list[tuple[str, dict[str, Any]]]:
    """
    Extract the tool calls from the first choice of a chat completion.

    Args:
        chat_completion: The chat completion object.

    Returns:
        A list of (normalized tool name, arguments) tuples, empty when the
        message carries no tool calls. A tool call whose ``arguments`` payload
        is empty or None yields an empty argument dictionary.

    Raises:
        json.JSONDecodeError: If a non-empty ``arguments`` payload is not valid JSON.
    """
    message = chat_completion.choices[0].message
    if not message.tool_calls:
        return []

    tool_args_list: list[tuple[str, dict[str, Any]]] = []
    for tool_call in message.tool_calls:
        raw_arguments = tool_call.function.arguments
        # json.loads rejects "" and None, so map an empty payload to "no arguments".
        arguments = json.loads(raw_arguments) if raw_arguments else {}
        tool_args_list.append((normalize_name(tool_call.function.name), arguments))
    return tool_args_list


def compare_tool_name(expected: str, actual: str) -> bool:
    """
    Tell whether two tool names refer to the same tool.

    Both names are passed through normalize_name with the separators '-', '_'
    and '.', then compared case-insensitively. For example 'Google_ListEmails'
    and 'Google.ListEmails' are considered equal.

    Args:
        expected: The expected tool name.
        actual: The actual tool name.

    Returns:
        True if the normalized, lower-cased names are equal, False otherwise.
    """
    separators = "-_."
    canonical_expected, canonical_actual = (
        normalize_name(candidate, separators).lower() for candidate in (expected, actual)
    )
    return canonical_expected == canonical_actual


def normalize_name(name: str, separators: str = "-_.") -> str:
    """
    Rewrite a tool name so it uses TOOL_NAME_SEPARATOR throughout.

    Args:
        name: The tool name to normalize.
        separators: Characters to treat as separators; each one is replaced in turn.

    Returns:
        The name with every listed separator that differs from
        TOOL_NAME_SEPARATOR replaced by TOOL_NAME_SEPARATOR.
    """
    normalized = name
    for candidate in separators:
        if candidate == TOOL_NAME_SEPARATOR:
            # Already the target separator; nothing to replace.
            continue
        normalized = normalized.replace(candidate, TOOL_NAME_SEPARATOR)
    return normalized


def tool_eval() -> Callable[[Callable], Callable]:
    """
    Decorator factory that turns a function returning an EvalSuite into an
    async runner.

    The decorated function is called with no arguments and must return an
    EvalSuite. The resulting coroutine takes ``(config, model, max_concurrency=1)``,
    sets the suite's ``max_concurrent``, runs the suite through an AsyncOpenAI
    client built from ``config``, and returns a one-element list holding the
    suite's results. The wrapper carries ``__tool_eval__ = True``.
    """

    def decorate(func: Callable) -> Callable:
        @functools.wraps(func)
        async def run_suite(
            config: Config,
            model: str,
            max_concurrency: int = 1,
        ) -> list[dict[str, Any]]:
            suite = func()
            if not isinstance(suite, EvalSuite):
                raise TypeError("Eval function must return an EvalSuite")
            suite.max_concurrent = max_concurrency

            api_key = config.api.key
            base_url = config.engine_url + "/v1"  # TODO remove
            async with AsyncOpenAI(api_key=api_key, base_url=base_url) as client:
                suite_result = await suite.run(client, model)
                return [suite_result]

        # Marker attribute set on every wrapped eval function.
        run_suite.__tool_eval__ = True  # type: ignore[attr-defined]
        return run_suite

    return decorate