import asyncio
import functools
import inspect
import json
from collections import Counter
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable

import numpy as np
from arcade_core.converters.openai import OpenAIToolList, to_openai
from arcade_core.schema import TOOL_NAME_SEPARATOR
from openai import AsyncOpenAI
from scipy.optimize import linear_sum_assignment

from arcade_evals.critic import NoneCritic
from arcade_evals.errors import WeightError

if TYPE_CHECKING:
    from arcade_core import ToolCatalog

    from arcade_evals.critic import Critic


@dataclass
class ExpectedToolCall:
    """
    Represents an expected tool call with the function itself and arguments.

    Attributes:
        func: The function itself.
        args: A dictionary containing the expected arguments for the tool.
    """

    func: Callable
    args: dict[str, Any]


@dataclass
class NamedExpectedToolCall:
    """
    Represents a tool call with its name and arguments.

    Attributes:
        name: The name of the tool.
        args: A dictionary containing the expected arguments for the tool.
    """

    name: str
    args: dict[str, Any]


@dataclass
class EvalRubric:
    """
    Defines the rubric for evaluating an AI model's performance on a task.

    Attributes:
        fail_threshold: The minimum score required to pass the evaluation (between 0.0 and 1.0).
        warn_threshold: The score threshold for issuing a warning (between 0.0 and 1.0).
        fail_on_tool_selection: Whether to fail the evaluation if the tool selection is incorrect.
        fail_on_tool_call_quantity: Whether to fail the evaluation if the number of tool calls
            is incorrect.
        tool_selection_weight: The weight assigned to the tool selection score
            (between 0.0 and 1.0).
    """

    fail_threshold: float = 0.8
    warn_threshold: float = 0.9
    fail_on_tool_selection: bool = True
    fail_on_tool_call_quantity: bool = True
    tool_selection_weight: float = 1.0

    def __str__(self) -> str:
        return f"Fail threshold: {self.fail_threshold}\nWarn threshold: {self.warn_threshold}\n"


@dataclass
class EvaluationResult:
    """
    Represents the result of an evaluation case.

    Attributes:
        score: The normalized evaluation score (0.0-1.0).
        passed: Whether the evaluation passed based on the fail_threshold.
        warning: Whether the evaluation issued a warning based on the warn_threshold.
        results: A list of dictionaries containing the results for each critic.
        failure_reason: If the evaluation failed completely due to settings in the rubric,
            this field contains the reason for failure.
    """

    score: float = 0.0
    passed: bool = False
    warning: bool = False
    results: list[dict[str, Any]] = field(default_factory=list)
    failure_reason: str | None = None

    @property
    def fail(self) -> bool:
        """True when the case neither passed nor was flagged as a warning."""
        return not self.passed and not self.warning

    def add(
        self,
        field: str,
        result: dict[str, Any],
        weight: float,
        expected: Any,
        actual: Any,
    ) -> None:
        """
        Add a critic result to the list of critic results.

        Args:
            field: The field name for the critic result.
            result: A dictionary containing the critic result.
            weight: The weight of the critic.
            expected: The expected value for the critic.
            actual: The actual value for the critic.
        """
        self.results.append({
            "field": field,
            **result,
            "weight": weight,
            "expected": expected,
            "actual": actual,
        })

    def score_tool_selection(self, expected: str, actual: str, weight: float) -> float:
        """
        Score and record tool selection in results.

        Args:
            expected: The expected tool name.
            actual: The actual tool name.
            weight: The weight for tool selection.

        Returns:
            The score for the tool selection (``weight`` on a match, otherwise 0.0).
        """
        matched = compare_tool_name(expected, actual)
        score = weight if matched else 0.0
        self.add("tool_selection", {"match": matched, "score": score}, weight, expected, actual)
        return score

    def compute_final_score(self, total_weight: float) -> None:
        """
        Compute the final score by normalizing the total score with the total weight.

        The total score is the sum of the "score" entries of every recorded result.
        A non-positive ``total_weight`` yields a score of 0.0.
        """
        total_score = sum(result["score"] for result in self.results)
        self.score = total_score / total_weight if total_weight > 0 else 0.0


@dataclass
class EvalCase:
    """
    Represents a single evaluation case within an EvalSuite.

    Attributes:
        name: A descriptive name for this evaluation case.
        system_message: The system message to be sent to the AI model.
        user_message: The user input to be sent to the AI model.
        expected_tool_calls: A list of NamedExpectedToolCall objects representing the expected
            tool calls.
        critics: A list of Critic objects used to evaluate tool arguments.
        additional_messages: Optional list of additional context messages.
        rubric: An EvalRubric object defining pass/fail criteria and tool selection behavior.
    """

    name: str
    system_message: str
    user_message: str
    expected_tool_calls: list[NamedExpectedToolCall]
    critics: list["Critic"] | None = None
    additional_messages: list[dict[str, str]] = field(default_factory=list)
    rubric: EvalRubric = field(default_factory=EvalRubric)

    def __post_init__(self) -> None:
        if self.critics is not None:
            self._validate_critics()
        else:
            # if no critics are provided, set to empty list
            self.critics = []

    def _validate_critics(self) -> None:
        """
        Validate the critic weights.

        Raises:
            WeightError: If the sum of critic weights exceeds 1.0, or if a critic other than
                a NoneCritic has a weight below 0.1.
        """
        if not self.critics:
            return

        total_weight = sum(critic.weight for critic in self.critics)
        if total_weight > 1.0:
            raise WeightError(f"Sum of critic weights must not exceed 1.0, got {total_weight}")

        for critic in self.critics:
            if critic.weight < 0.1 and not isinstance(critic, NoneCritic):
                raise WeightError(f"Critic weights should be at least 0.1, got {critic.weight}")

    def check_tool_selection_failure(self, actual_tools: list[str]) -> bool:
        """
        Check if tool selection failure should occur.

        Both name lists are sorted by the same normalized key that ``compare_tool_name``
        uses, so names differing only in case or separator ('-', '_', '.') line up with
        each other. Sorting by the raw strings could pair the wrong names and report
        a spurious mismatch.

        Args:
            actual_tools: The list of actual tool names used.

        Returns:
            True if tool selection failure should occur, False otherwise.
        """
        if not self.rubric.fail_on_tool_selection:
            return False

        sorted_expected_tools = sorted(
            (tc.name for tc in self.expected_tool_calls), key=_tool_name_key
        )
        sorted_actual_tools = sorted(actual_tools, key=_tool_name_key)
        return not all(
            compare_tool_name(expected, actual)
            for expected, actual in zip(sorted_expected_tools, sorted_actual_tools)
        )

    def check_tool_call_quantity_failure(self, actual_count: int) -> bool:
        """
        Check if tool call quantity failure should occur.

        Args:
            actual_count: The number of actual tool calls made.

        Returns:
            True if tool call quantity failure should occur, False otherwise.
        """
        expected_count = len(self.expected_tool_calls)
        return self.rubric.fail_on_tool_call_quantity and expected_count != actual_count

    def evaluate(
        self,
        actual_tool_calls: list[tuple[str, dict[str, Any]]],
    ) -> EvaluationResult:
        """
        Evaluate the actual tool calls against the expected tool calls and critics.

        Expected and actual calls are paired with the Hungarian algorithm
        (``linear_sum_assignment``), maximizing tool-selection plus critic scores.
        Each pair is then scored and the total is normalized by the accumulated weight.

        Args:
            actual_tool_calls: A list of tuples containing the actual tool name and arguments.

        Returns:
            An EvaluationResult object containing the evaluation results.
        """
        evaluation_result = EvaluationResult()

        actual_tools = [tool_name for tool_name, _ in actual_tool_calls]
        actual_count = len(actual_tool_calls)

        if self.check_tool_call_quantity_failure(actual_count):
            evaluation_result.score = 0.0
            evaluation_result.passed = False
            expected_count = len(self.expected_tool_calls)
            expected_tool_names = ", ".join(
                tool_call.name for tool_call in self.expected_tool_calls
            )
            evaluation_result.failure_reason = (
                f"Expected {expected_count} tool call(s), but got {actual_count}. "
                + f"\nExpected tool calls: {expected_tool_names}.\nActual tool calls: {', '.join(actual_tools)}"
            )
            return evaluation_result

        if not self.expected_tool_calls and not actual_tools:
            evaluation_result.score = 1.0
            evaluation_result.passed = True
            return evaluation_result

        if self.check_tool_selection_failure(actual_tools):
            evaluation_result.score = 0.0
            evaluation_result.passed = False
            expected_tools = [tc.name for tc in self.expected_tool_calls]
            evaluation_result.failure_reason = f"Tool selection mismatch. Expected tools: {expected_tools}, but got: {actual_tools}"
            return evaluation_result

        if not self.critics:
            evaluation_result.score = 1.0
            evaluation_result.passed = True
            return evaluation_result

        # Create a cost matrix for the assignment problem
        cost_matrix = self._create_cost_matrix(actual_tool_calls, self.expected_tool_calls)

        # Use the Linear Sum Assignment algorithm to find the optimal assignment
        row_ind, col_ind = linear_sum_assignment(cost_matrix, maximize=True)

        num_expected = len(self.expected_tool_calls)
        total_weight = 0.0

        for i, j in zip(row_ind, col_ind):
            # The matrix is padded to be square; padded rows/columns have no real call.
            if i >= num_expected or j >= actual_count:
                continue

            expected = self.expected_tool_calls[i]
            actual_name, actual_args = actual_tool_calls[j]

            # Tool selection (recorded in evaluation_result.results)
            evaluation_result.score_tool_selection(
                expected.name, actual_name, self.rubric.tool_selection_weight
            )
            total_weight += self.rubric.tool_selection_weight

            # Evaluate arguments using critics
            for critic in self.critics:
                expected_value = expected.args.get(critic.critic_field)
                actual_value = actual_args.get(critic.critic_field)
                try:
                    result = critic.evaluate(expected_value, actual_value)
                    # compute_final_score reads result["score"]; reject malformed results
                    # here so they are recorded as a 0.0 score rather than crashing later.
                    if "score" not in result:
                        raise KeyError("score")
                    total_weight += critic.weight
                    evaluation_result.add(
                        critic.critic_field,
                        result,
                        critic.weight,
                        expected_value,
                        actual_value,
                    )
                except Exception as e:
                    # TODO: log or console
                    print(f"Critic evaluation failed for field '{critic.critic_field}': {e}")
                    evaluation_result.add(
                        critic.critic_field,
                        {"match": False, "score": 0.0},
                        critic.weight,
                        expected_value,
                        actual_value,
                    )

        # Compute the final score
        evaluation_result.compute_final_score(total_weight)

        # Set pass/fail and warning status.
        # NOTE(review): with the default thresholds (fail 0.8 < warn 0.9) `warning` can never
        # be True, because any score >= warn_threshold already passes. Confirm the intended
        # semantics of the warning band.
        evaluation_result.passed = evaluation_result.score >= self.rubric.fail_threshold
        evaluation_result.warning = (
            not evaluation_result.passed and evaluation_result.score >= self.rubric.warn_threshold
        )

        return evaluation_result

    def _create_cost_matrix(
        self,
        actual_tool_calls: list[tuple[str, dict[str, Any]]],
        expected_tool_calls: list[NamedExpectedToolCall],
    ) -> np.ndarray:
        """
        Create a cost matrix for the assignment problem.

        Rows index expected calls and columns index actual calls. The matrix is square
        (padded with zeros to the larger count). Each real cell holds the tool-selection
        weight (on a name match) plus the critic scores for fields present on both sides.

        Args:
            actual_tool_calls: A list of tuples of actual tool calls.
            expected_tool_calls: A list of NamedExpectedToolCall instances.

        Returns:
            A numpy array representing the cost matrix.
        """
        num_expected = len(expected_tool_calls)
        num_actual = len(actual_tool_calls)
        n = max(num_expected, num_actual)

        cost_matrix = np.zeros((n, n))

        for i in range(num_expected):
            expected = expected_tool_calls[i]
            for j in range(num_actual):
                actual_name, actual_args = actual_tool_calls[j]
                score = 0.0

                # Tool selection
                if compare_tool_name(expected.name, actual_name):
                    score += self.rubric.tool_selection_weight

                # Critics evaluation
                for critic in self.critics:  # type: ignore[union-attr]
                    expected_value = expected.args.get(critic.critic_field)
                    actual_value = actual_args.get(critic.critic_field)
                    if expected_value is not None and actual_value is not None:
                        try:
                            result = critic.evaluate(expected_value, actual_value)
                            score += result.get("score", 0.0)
                        except Exception as e:
                            print(
                                f"Critic evaluation failed for field '{critic.critic_field}': {e}"
                            )
                cost_matrix[i, j] = score

        return cost_matrix


@dataclass
class EvalSuite:
    """
    A suite for evaluating AI model performance on specific tasks or scenarios.

    EvalSuite manages a collection of EvalCases, each representing a specific test scenario.
    It provides methods to add cases, register tools, and run evaluations against specified
    models.

    Attributes:
        name: The name of the evaluation suite.
        system_message: The system message to be used for all cases in this suite.
        catalog: A ToolCatalog object containing registered tools.
        cases: A list of EvalCase objects representing individual test scenarios.
        rubric: The default evaluation rubric for cases in this suite.
        max_concurrent: Maximum number of concurrent evaluations.
    """

    name: str
    system_message: str
    catalog: "ToolCatalog"
    cases: list[EvalCase] = field(default_factory=list)
    rubric: EvalRubric = field(default_factory=EvalRubric)
    max_concurrent: int = 1

    def _convert_to_named_expected_tool_call(
        self, tc: ExpectedToolCall | tuple[Callable, dict[str, Any]]
    ) -> NamedExpectedToolCall:
        """
        Convert an ExpectedToolCall or a tuple to a NamedExpectedToolCall
        with default arguments populated.

        Args:
            tc: The tool call, either as an ExpectedToolCall or a tuple.

        Returns:
            A NamedExpectedToolCall instance.
        """
        if isinstance(tc, tuple):
            func, args = tc
        else:
            func = tc.func
            args = tc.args
        args_with_defaults = self._fill_args_with_defaults(func, args)
        tool_name = str(self.catalog.find_tool_by_func(func).get_fully_qualified_name())
        return NamedExpectedToolCall(name=tool_name, args=args_with_defaults)

    def add_case(
        self,
        name: str,
        user_message: str,
        expected_tool_calls: list[ExpectedToolCall] | list[tuple[Callable, dict[str, Any]]],
        critics: list["Critic"] | None = None,
        system_message: str | None = None,
        rubric: EvalRubric | None = None,
        additional_messages: list[dict[str, str]] | None = None,
    ) -> None:
        """
        Add a new evaluation case to the suite.

        Args:
            name: The name of the evaluation case.
            user_message: The user's input message.
            expected_tool_calls: A list of expected tool calls as ExpectedToolCall instances.
            critics: List of critics to evaluate the tool arguments. The list passed in is
                not modified.
            system_message: The system message to be used.
            rubric: The evaluation rubric for this case.
            additional_messages: Optional list of additional messages for context.
        """
        expected_tool_calls_with_defaults = [
            self._convert_to_named_expected_tool_call(tc) for tc in expected_tool_calls
        ]

        # Add NoneCritics for any expected tool call fields not in the critics list
        critics = self._add_none_critics(expected_tool_calls_with_defaults, critics)
        self._validate_critics(critics, name)

        case = EvalCase(
            name=name,
            system_message=system_message or self.system_message,
            user_message=user_message,
            expected_tool_calls=expected_tool_calls_with_defaults,
            rubric=rubric or self.rubric,
            critics=critics,
            additional_messages=additional_messages or [],
        )
        self.cases.append(case)

    def _add_none_critics(
        self,
        expected_tool_calls_with_defaults: list[NamedExpectedToolCall],
        critics: list["Critic"] | None,
    ) -> list["Critic"]:
        """
        Add NoneCritics for any fields in the expected tool calls that are not already
        in the critics list.

        Works on a copy, so the caller's list (which may be reused across cases or belong
        to another case) is never mutated.

        Args:
            expected_tool_calls_with_defaults: The list of expected tool calls with defaults.
            critics: The list of critics.

        Returns:
            A new list: the given critics followed by any added NoneCritics.
        """
        critics = list(critics) if critics else []
        critic_field_names = {critic.critic_field for critic in critics}

        for tc in expected_tool_calls_with_defaults:
            for field_name in tc.args:
                if field_name not in critic_field_names:
                    critics.append(NoneCritic(critic_field=field_name))
                    critic_field_names.add(field_name)
        return critics

    def _validate_critics(self, critics: list["Critic"] | None, name: str) -> None:
        """
        Validate the critics.

        Args:
            critics: The list of critics.
            name: The name of the evaluation case.

        Raises:
            ValueError: If multiple critics are detected for the same field.
        """
        if critics is None:
            return

        field_counts = Counter(critic.critic_field for critic in critics)
        # Sorted so the error message is deterministic.
        duplicate_fields = sorted(
            field_name for field_name, count in field_counts.items() if count > 1
        )

        if duplicate_fields:
            raise ValueError(
                f"Multiple critics detected for the field(s) '{', '.join(duplicate_fields)}' in evaluation case '{name}'. Only one critic per field is permitted."
            )

    def _fill_args_with_defaults(
        self, func: Callable, provided_args: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Fill in default arguments for a tool function.

        Every parameter in ``func``'s signature gets an entry: the provided value if given,
        otherwise the parameter default, otherwise None. Provided keys that are not
        parameters of ``func`` are dropped.

        Args:
            func: The tool function.
            provided_args: The provided arguments.

        Returns:
            A dictionary with default arguments filled in.
        """
        sig = inspect.signature(func)
        args_with_defaults = {}
        for param in sig.parameters.values():
            if param.name in provided_args:
                args_with_defaults[param.name] = provided_args[param.name]
            elif param.default is not inspect.Parameter.empty:
                args_with_defaults[param.name] = param.default
            else:
                # Required parameter that was not supplied.
                args_with_defaults[param.name] = None
        return args_with_defaults

    def extend_case(
        self,
        name: str,
        user_message: str,
        system_message: str | None = None,
        expected_tool_calls: list[ExpectedToolCall]
        | list[tuple[Callable, dict[str, Any]]]
        | None = None,
        rubric: EvalRubric | None = None,
        critics: list["Critic"] | None = None,
        additional_messages: list[dict[str, str]] | None = None,
    ) -> None:
        """
        Extend the last added case with new information.

        Any argument left unset is copied from the last case.

        Args:
            name: The name of the extended case.
            user_message: The new user message for this extended case.
            system_message: The new system message for this extended case.
            expected_tool_calls: New or updated expected tool calls.
            rubric: A new rubric (if different from the last case).
            critics: New critics (if different from the last case).
            additional_messages: New additional messages (if different from the last case),
                to be added before the new user message.

        Raises:
            ValueError: If no case has been added yet.
        """
        if not self.cases:
            raise ValueError("No cases to extend. Add a case first.")

        last_case = self.cases[-1]

        # Create a new message list with the previous case's messages and user message
        new_additional_messages = [*last_case.additional_messages]
        if additional_messages:
            new_additional_messages.extend(additional_messages)

        expected = last_case.expected_tool_calls
        if expected_tool_calls:
            expected = [self._convert_to_named_expected_tool_call(tc) for tc in expected_tool_calls]

        # Add NoneCritics for any expected tool call fields not in the critics list.
        # _add_none_critics copies its input, so the last case's critics stay untouched.
        critics = self._add_none_critics(expected, critics or last_case.critics)
        self._validate_critics(critics, name)

        # Create a new case, copying from the last one and updating fields
        new_case = EvalCase(
            name=name,
            system_message=system_message or last_case.system_message,
            user_message=user_message,
            expected_tool_calls=expected,
            rubric=rubric or last_case.rubric,
            critics=critics,
            additional_messages=new_additional_messages,
        )
        self.cases.append(new_case)

    async def run(self, client: AsyncOpenAI, model: str) -> dict[str, Any]:
        """
        Run the evaluation suite.

        Cases run concurrently, with at most ``max_concurrent`` in flight at a time.

        Args:
            client: The AsyncOpenAI client instance.
            model: The model to evaluate.

        Returns:
            A dictionary containing the evaluation results.

        Raises:
            ValueError: If the model calls a tool that is not in the catalog.
        """
        results: dict[str, Any] = {"model": model, "rubric": self.rubric, "cases": []}

        semaphore = asyncio.Semaphore(self.max_concurrent)
        # The catalog is not modified during a run, so format the tool list once.
        tools = get_formatted_tools(self.catalog, tool_format="openai")

        async def sem_task(case: EvalCase) -> dict[str, Any]:
            async with semaphore:
                # Prepare messages
                messages = [{"role": "system", "content": case.system_message}]
                messages.extend(case.additional_messages)
                messages.append({"role": "user", "content": case.user_message})

                # Get the model response
                response = await client.chat.completions.create(  # type: ignore[call-overload]
                    model=model,
                    messages=messages,
                    tool_choice="auto",
                    tools=tools,
                    user="eval_user",
                    seed=42,
                    stream=False,
                )

                # Extract and fill default arguments for actual tool calls
                predicted_args = get_tool_args(response)

                filled_actual_tool_calls = []
                for tool_name, args in predicted_args:
                    tool = self.catalog.get_tool_by_name(tool_name)
                    if tool is None:
                        raise ValueError(f"Tool '{tool_name}' not found in catalog.")
                    func = tool.tool
                    args_with_defaults = self._fill_args_with_defaults(func, args)
                    filled_actual_tool_calls.append((tool_name, args_with_defaults))

                # Evaluate the case
                evaluation = case.evaluate(filled_actual_tool_calls)

                # Prepare the result
                return {
                    "name": case.name,
                    "input": case.user_message,
                    "expected_tool_calls": [
                        {"name": tc.name, "args": tc.args} for tc in case.expected_tool_calls
                    ],
                    "predicted_tool_calls": [
                        {"name": name, "args": args} for name, args in filled_actual_tool_calls
                    ],
                    "evaluation": evaluation,
                }

        tasks = [sem_task(case) for case in self.cases]
        case_results = await asyncio.gather(*tasks)

        results["cases"] = case_results
        return results


def get_formatted_tools(catalog: "ToolCatalog", tool_format: str = "openai") -> OpenAIToolList:
    """Get the formatted tools from the catalog.

    Args:
        catalog: The catalog of Arcade tools.
        tool_format: The format of the tools to return. Only "openai" is supported.

    Returns:
        The formatted tools.

    Raises:
        ValueError: If ``tool_format`` is not supported.
    """
    if tool_format == "openai":
        return [to_openai(tool) for tool in catalog]
    raise ValueError(f"Tool format for '{tool_format}' is not supported")


def get_tool_args(chat_completion: Any) -> list[tuple[str, dict[str, Any]]]:
    """
    Returns the tool arguments from the chat completion object.

    Only the first choice is inspected. Tool names are normalized with ``normalize_name``
    and the JSON-encoded arguments are decoded with ``json.loads``.

    Args:
        chat_completion: The chat completion object.

    Returns:
        A list of tuples containing the tool name and arguments.
    """
    tool_args_list: list[tuple[str, dict[str, Any]]] = []
    message = chat_completion.choices[0].message
    if message.tool_calls:
        for tool_call in message.tool_calls:
            tool_args_list.append((
                normalize_name(tool_call.function.name),
                json.loads(tool_call.function.arguments),
            ))
    return tool_args_list


def _tool_name_key(name: str) -> str:
    """Return the case- and separator-insensitive key used to compare tool names."""
    return normalize_name(name, "-_.").lower()


def compare_tool_name(expected: str, actual: str) -> bool:
    """
    Compare the tool names by replacing all separators with the TOOL_NAME_SEPARATOR
    and comparing the normalized names, ignoring case.

    Converts names like 'Google_ListEmails' to 'Google.ListEmails' if
    TOOL_NAME_SEPARATOR is '.'.

    Args:
        expected: The expected tool name.
        actual: The actual tool name.

    Returns:
        True if the normalized tool names match, False otherwise.
    """
    return _tool_name_key(expected) == _tool_name_key(actual)


def normalize_name(name: str, separators: str = "-_.") -> str:
    """Replace every character of ``separators`` in ``name`` with TOOL_NAME_SEPARATOR."""
    for sep in separators:
        if sep != TOOL_NAME_SEPARATOR:
            name = name.replace(sep, TOOL_NAME_SEPARATOR)
    return name


def tool_eval() -> Callable[[Callable], Callable]:
    """
    Decorator factory marking a function that returns an EvalSuite as a tool evaluation.

    The wrapped coroutine builds the suite, runs it against ``model`` with an AsyncOpenAI
    client and returns a one-element list holding the suite's results.
    """

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        async def wrapper(
            provider_api_key: str,
            model: str,
            max_concurrency: int = 1,
        ) -> list[dict[str, Any]]:
            suite = func()
            if not isinstance(suite, EvalSuite):
                raise TypeError("Eval function must return an EvalSuite")
            suite.max_concurrent = max_concurrency
            results = []
            async with AsyncOpenAI(
                api_key=provider_api_key,
            ) as client:
                result = await suite.run(client, model)
                results.append(result)
            return results

        wrapper.__tool_eval__ = True  # type: ignore[attr-defined]
        return wrapper

    return decorator