Evals New Features (#208)

# PR Description
This PR adds ~~four~~ three improvements to evals.

~~## 1. Add parameterized eval cases~~
~~Adds a new method named `add_parameterized_case`. Just like pytest’s
parameterized tests, eval cases can be parameterized with multiple user
messages. Adds a case to the `EvalSuite` for each user message. All
cases have the same expected tool call(s), params, additional_messages.
This reduces duplicate code and makes it easy to observe how a model
performs based on increasingly more difficult prompts.~~
```python
""" NO LONGER IN THIS PR
user_messages = [
    "Call the delete tweet by id tool with the tweet ID '148975632'.",
    "Delete the tweet with ID '148975632'.",
    "I don't want to have this tweet (148975632) on my account anymore.",
    "do the opposite of post for https://x.com/x/status/148975632",
]

suite.add_parameterized_case(
    name="Delete a tweet by ID",
    user_messages=user_messages,
    expected_tool_calls=[
        ExpectedToolCall(
            func=delete_tweet_by_id,
            args={"tweet_id": "148975632"},
        )
    ],
    critics=[
        BinaryCritic(
            critic_field="tweet_id",
            weight=1.0,
        ),
    ],
)
"""
```

~~PASSED Delete a tweet by ID (user_message 1 of 4) -- Score: 100.00%~~
~~PASSED Delete a tweet by ID (user_message 2 of 4) -- Score: 100.00%~~
~~PASSED Delete a tweet by ID (user_message 3 of 4) -- Score: 100.00%~~
~~FAILED Delete a tweet by ID (user_message 4 of 4) -- Score: 0.00%~~
~~Summary -- Total: 4 -- Passed: 3 -- Failed: 1~~

## 2. Parameters that are not explicitly criticized are assigned a `NoneCritic`.
A `NoneCritic` has no effect on the evaluation results and does not
actually evaluate anything. Parameters that have a `NoneCritic` are displayed
as ‘un-criticized’ in the evaluation summary (if the `-d` flag is used).

![image](https://github.com/user-attachments/assets/300756ec-9b53-436a-9cf9-fc61d0b00c01)


## 3. Add a hardcoded `seed` parameter for evals.
The seed parameter helps produce (mostly) consistent outputs, which
improves the reproducibility of evaluations.

## 4. Disallow more than one critic for the same field.
Raises a `ValueError` if more than one critic is assigned to a field.

---------

Co-authored-by: Eric Gustin <eric@arcade-ai.com>
This commit is contained in:
Eric Gustin 2025-02-05 15:22:08 -08:00 committed by GitHub
parent 149c25d967
commit be2539602f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 526 additions and 306 deletions

View file

@ -206,20 +206,31 @@ def _format_evaluation(evaluation: "EvaluationResult") -> str:
result_lines.append(f"[bold red]Failure Reason:[/bold red] {evaluation.failure_reason}")
else:
for critic_result in evaluation.results:
match_color = "green" if critic_result["match"] else "red"
is_criticized = critic_result.get("is_criticized", True)
match_color = (
"yellow" if not is_criticized else "green" if critic_result["match"] else "red"
)
field = critic_result["field"]
score = critic_result["score"]
weight = critic_result["weight"]
expected = critic_result["expected"]
actual = critic_result["actual"]
result_lines.append(
f"[bold]{field}:[/bold] "
f"[{match_color}]Match: {critic_result['match']}"
f"\n Score: {score:.2f}/{weight:.2f}[/{match_color}]"
f"\n Expected: {expected}"
f"\n Actual: {actual}"
)
if is_criticized:
result_lines.append(
f"[bold]{field}:[/bold] "
f"[{match_color}]Match: {critic_result['match']}"
f"\n Score: {score:.2f}/{weight:.2f}[/{match_color}]"
f"\n Expected: {expected}"
f"\n Actual: {actual}"
)
else:
result_lines.append(
f"[bold]{field}:[/bold] "
f"[{match_color}]Un-criticized[/{match_color}]"
f"\n Expected: {expected}"
f"\n Actual: {actual}"
)
return "\n".join(result_lines)

View file

@ -1,4 +1,4 @@
from .critic import BinaryCritic, DatetimeCritic, NumericCritic, SimilarityCritic
from .critic import BinaryCritic, DatetimeCritic, NoneCritic, NumericCritic, SimilarityCritic
from .eval import EvalRubric, EvalSuite, ExpectedToolCall, NamedExpectedToolCall, tool_eval
__all__ = [
@ -8,6 +8,7 @@ __all__ = [
"EvalSuite",
"ExpectedToolCall",
"NamedExpectedToolCall",
"NoneCritic",
"NumericCritic",
"SimilarityCritic",
"tool_eval",

View file

@ -23,6 +23,25 @@ class Critic(ABC):
pass
@dataclass
class NoneCritic(Critic):
    """
    Placeholder critic for a field that no other critic was assigned to.

    It performs no comparison and always carries a weight of zero, so it has no
    effect on the evaluation results. Its result marks the field as not criticized.
    """

    weight: float = 0.0

    def __post_init__(self) -> None:
        # Whatever weight was passed in is discarded before the base-class hook runs.
        self.weight = 0.0
        super().__post_init__()

    def evaluate(self, expected: Any, actual: Any) -> dict[str, Any]:
        """Return a fixed 'not criticized' result; neither argument is inspected."""
        outcome: dict[str, Any] = {}
        outcome["match"] = None
        outcome["score"] = self.weight
        outcome["is_criticized"] = False
        return outcome
@dataclass
class BinaryCritic(Critic):
"""

View file

@ -19,6 +19,7 @@ except ImportError:
from openai import AsyncOpenAI
from arcade.sdk.errors import WeightError
from arcade.sdk.eval.critic import NoneCritic
if TYPE_CHECKING:
from arcade.sdk import ToolCatalog
@ -201,7 +202,7 @@ class EvalCase:
raise WeightError(f"Sum of critic weights must not exceed 1.0, got {total_weight}")
for critic in self.critics:
if critic.weight < 0.1:
if critic.weight < 0.1 and not isinstance(critic, NoneCritic):
raise WeightError(f"Critic weights should be at least 0.1, got {critic.weight}")
def check_tool_selection_failure(self, actual_tools: list[str]) -> bool:
@ -463,6 +464,11 @@ class EvalSuite:
self._convert_to_named_expected_tool_call(tc) for tc in expected_tool_calls
]
# Add NoneCritics for any expected tool call fields not in the critics list
critics = self._add_none_critics(expected_tool_calls_with_defaults, critics)
self._validate_critics(critics, name)
case = EvalCase(
name=name,
system_message=system_message or self.system_message,
@ -474,6 +480,54 @@ class EvalSuite:
)
self.cases.append(case)
def _add_none_critics(
    self,
    expected_tool_calls_with_defaults: list[NamedExpectedToolCall],
    critics: list["Critic"] | None,
) -> list["Critic"]:
    """
    Return the critics extended with a NoneCritic for every un-criticized field.

    Each argument name in the expected tool calls that has no critic yet receives
    exactly one NoneCritic, even when several tool calls share that argument name.

    The list passed in is never modified: a new list is always returned. This
    matters when a caller reuses one critics list for several cases, because
    NoneCritics added for one case must not leak into the next.

    Args:
        expected_tool_calls_with_defaults: The list of expected tool calls with defaults.
        critics: The list of critics, or None / empty if there are none.

    Returns:
        A new list holding the given critics followed by the added NoneCritics.
    """
    # Copy so the caller's list is left untouched.
    all_critics = list(critics) if critics else []
    covered_fields = {critic.critic_field for critic in all_critics}

    for tool_call in expected_tool_calls_with_defaults:
        for field_name in tool_call.args:
            if field_name in covered_fields:
                continue
            all_critics.append(NoneCritic(critic_field=field_name))
            # Remember the field so a later tool call with the same argument
            # name does not get a second NoneCritic.
            covered_fields.add(field_name)

    return all_critics
def _validate_critics(self, critics: list["Critic"] | None, name: str) -> None:
"""
Validate the critics.
Args:
critics: The list of critics.
name: The name of the evaluation case.
Raises:
ValueError: If multiple critics are detected for the same field.
"""
if critics is None:
return
critic_fields = [critic.critic_field for critic in critics]
duplicate_fields = {field for field in critic_fields if critic_fields.count(field) > 1}
if duplicate_fields:
raise ValueError(
f"Multiple critics detected for the field(s) '{', '.join(duplicate_fields)}' in evaluation case '{name}'. Only one critic per field is permitted."
)
def _fill_args_with_defaults(
self, func: Callable, provided_args: dict[str, Any]
) -> dict[str, Any]:
@ -539,6 +593,13 @@ class EvalSuite:
if expected_tool_calls:
expected = [self._convert_to_named_expected_tool_call(tc) for tc in expected_tool_calls]
# Add NoneCritics for any expected tool call fields not in the critics list
critics = self._add_none_critics(
expected, critics or (last_case.critics.copy() if last_case.critics else None)
)
self._validate_critics(critics, name)
# Create a new case, copying from the last one and updating fields
new_case = EvalCase(
name=name,
@ -546,7 +607,7 @@ class EvalSuite:
user_message=user_message,
expected_tool_calls=expected,
rubric=rubric or self.rubric,
critics=critics or (last_case.critics.copy() if last_case.critics else None),
critics=critics,
additional_messages=new_additional_messages,
)
self.cases.append(new_case)
@ -581,6 +642,7 @@ class EvalSuite:
tool_choice="auto",
tools=(str(name) for name in tool_names),
user="eval_user",
seed=42,
stream=False,
)

View file

@ -1,113 +1,37 @@
from datetime import timedelta
from unittest.mock import Mock
import pytest
import pytz
from dateutil import parser
from arcade.sdk import tool
from arcade.sdk.errors import WeightError
from arcade.sdk.eval import (
BinaryCritic,
DatetimeCritic,
EvalRubric,
ExpectedToolCall,
NamedExpectedToolCall,
NumericCritic,
NoneCritic,
SimilarityCritic,
)
from arcade.sdk.eval.eval import EvalCase, EvalSuite, EvaluationResult
# Test BinaryCritic.evaluate()
@tool
def mock_tool(param1: str):
pass
@pytest.mark.parametrize(
"expected, actual, weight, expected_match, expected_score",
[
("value", "value", 1.0, True, 1.0),
("value", "different", 1.0, False, 0.0),
(10, 10, 0.5, True, 0.5),
(10, 20, 0.5, False, 0.0),
],
)
def test_binary_critic_evaluate(expected, actual, weight, expected_match, expected_score):
"""
Test the BinaryCritic's evaluate method to ensure it correctly computes
the match and score based on expected and actual values.
"""
critic = BinaryCritic(critic_field="test_field", weight=weight)
result = critic.evaluate(expected=expected, actual=actual)
assert result["match"] == expected_match
assert result["score"] == expected_score
@tool
def mock_tool_no_args():
pass
# Test NumericCritic.evaluate()
@pytest.mark.parametrize(
"expected, actual, value_range, weight, match_threshold, expected_match, expected_score",
[
(5, 5, (0, 10), 1.0, 0.8, True, 1.0),
(5, 6, (0, 10), 1.0, 0.8, True, 0.9),
(0, 10, (0, 10), 1.0, 0.8, False, 0.0),
(2, 8, (0, 10), 1.0, 0.5, False, 0.4),
(50, 60, (0, 100), 0.5, 0.9, True, 0.45),
],
)
def test_numeric_critic_evaluate(
expected, actual, value_range, weight, match_threshold, expected_match, expected_score
@tool
def mock_tool_multiple_args(
param1: str, param2: str, param3: str = "value3", param4: str = "value4"
):
"""
Test the NumericCritic's evaluate method to ensure it calculates
the correct score based on the proportion of the difference between
expected and actual values within a specified range.
"""
critic = NumericCritic(
critic_field="number",
weight=weight,
value_range=value_range,
match_threshold=match_threshold,
)
result = critic.evaluate(expected=expected, actual=actual)
assert result["match"] == expected_match
assert pytest.approx(result["score"], 0.01) == expected_score
# Test SimilarityCritic.evaluate()
@pytest.mark.parametrize(
"expected, actual, weight, similarity_threshold, expected_match, min_expected_score",
[
("hello world", "hello world", 1.0, 0.8, True, 1.0),
("hello world", "hello", 1.0, 0.8, False, 0.0),
("The quick brown fox", "The quick brown fox jumps over the lazy dog", 1.0, 0.5, True, 0.5),
("data science", "machine learning", 0.5, 0.3, False, 0.0),
],
)
def test_similarity_critic_evaluate(
expected, actual, weight, similarity_threshold, expected_match, min_expected_score
):
"""
Test the SimilarityCritic's evaluate method to ensure it computes
the similarity score between expected and actual strings and determines
the match correctly based on the similarity threshold.
"""
critic = SimilarityCritic(
critic_field="text",
weight=weight,
similarity_threshold=similarity_threshold,
)
result = critic.evaluate(expected=expected, actual=actual)
assert result["match"] == expected_match
assert result["score"] >= min_expected_score
assert result["score"] >= 0.0
assert result["score"] <= weight + 1e-6 # Allow a small epsilon for floating-point comparison
pass
# Test EvaluationResult accumulation and pass/fail logic
def test_evaluation_result_accumulation():
"""
Test that EvaluationResult correctly accumulates scores and determines
@ -135,8 +59,6 @@ def test_evaluation_result_accumulation():
# Test EvalCase.evaluate()
def test_eval_case_evaluate():
"""
Test EvalCase's evaluate method to ensure it calculates the overall score
@ -184,8 +106,6 @@ def test_eval_case_evaluate():
# Test EvalCase with mismatched tool calls
def test_eval_case_evaluate_mismatched_tools():
"""
Test EvalCase's evaluate method when the actual tool calls do not match
@ -220,8 +140,6 @@ def test_eval_case_evaluate_mismatched_tools():
# Test EvalCase with multiple critics and weights
def test_eval_case_multiple_critics():
"""
Test EvalCase's evaluate method with multiple critics having different weights
@ -261,8 +179,6 @@ def test_eval_case_multiple_critics():
# Test EvalCase with missing expected and actual values in args
def test_eval_case_with_none_values():
"""
Test that when expected or actual values are None, the critic evaluates them appropriately.
@ -290,196 +206,11 @@ def test_eval_case_with_none_values():
assert result.score == 2.0 / 2.0 # Full score (tool selection + critic score)
# Test that WeightError is raised for invalid critic weights
@pytest.mark.parametrize(
"critic_class, weight",
[
(BinaryCritic, -0.1),
(BinaryCritic, 1.1),
(NumericCritic, -0.5),
(SimilarityCritic, 1.5),
],
)
def test_critic_invalid_weight(critic_class, weight):
"""
Test that initializing a critic with an invalid weight raises a WeightError.
"""
with pytest.raises(WeightError):
if critic_class == NumericCritic:
critic_class(critic_field="test_field", weight=weight, value_range=(0, 1))
elif critic_class == SimilarityCritic:
critic_class(critic_field="test_field", weight=weight)
else:
critic_class(critic_field="test_field", weight=weight)
# Test NumericCritic with invalid value range
def test_numeric_critic_invalid_range():
"""
Test that initializing a NumericCritic with an invalid value range raises a ValueError.
"""
with pytest.raises(ValueError):
NumericCritic(critic_field="number", weight=1.0, value_range=(10, 0)) # Invalid range
# Test SimilarityCritic with unsupported metric
def test_similarity_critic_unsupported_metric():
"""
Test that initializing a SimilarityCritic with an unsupported metric raises a ValueError.
"""
with pytest.raises(ValueError):
SimilarityCritic(critic_field="text", weight=1.0, metric="unsupported_metric")
# Test DatetimeCritic
# Parameterized tests for DatetimeCritic with various datetime formats and default timezones
@pytest.mark.parametrize(
"critic_params, expected, actual, expected_match, expected_score",
[
# Test with time component and timezone
(
{"critic_field": "start_datetime", "weight": 1.0},
"2024-09-26T12:00:00-07:00",
"2024-09-26T12:00:00-07:00",
True,
1.0,
),
# Test without time component (dates only)
(
{"critic_field": "start_datetime", "weight": 1.0},
"2024-09-26",
"2024-09-26",
True,
1.0,
),
# Test with and without timezone (assumes UTC)
(
{"critic_field": "start_datetime", "weight": 1.0},
"2024-09-26T12:00:00Z",
"2024-09-26T12:00:00",
True,
1.0,
),
# Test naive datetimes
(
{"critic_field": "start_datetime", "weight": 1.0},
"2024-09-26T12:00:00",
"2024-09-26T12:00:00",
True,
1.0,
),
],
)
def test_datetime_critic_basic(critic_params, expected, actual, expected_match, expected_score):
"""
Test DatetimeCritic with various datetime formats and default timezones.
"""
critic = DatetimeCritic(**critic_params)
result = critic.evaluate(expected, actual)
assert result["match"] == expected_match
assert result["score"] == expected_score
# Parameterized tests for DatetimeCritic's handling of tolerances and max differences
@pytest.mark.parametrize(
"critic_params, expected, actual, expected_match, expected_score_func",
[
# Test time difference within tolerance
(
{"critic_field": "start_datetime", "weight": 1.0, "tolerance": timedelta(seconds=60)},
"2024-09-26T12:00:00",
"2024-09-26T12:00:30",
True,
lambda critic: critic.weight,
),
# Test time difference outside tolerance but within max_difference
(
{
"critic_field": "start_datetime",
"weight": 1.0,
"tolerance": timedelta(seconds=60),
"max_difference": timedelta(minutes=5),
},
"2024-09-26T12:00:00",
"2024-09-26T12:04:00",
False,
lambda critic: critic.weight * (1 - (240 / 300)),
),
# Test time difference exceeds max_difference
(
{
"critic_field": "start_datetime",
"weight": 1.0,
"max_difference": timedelta(minutes=5),
},
"2024-09-26T12:00:00",
"2024-09-26T12:10:00",
False,
lambda critic: 0.0,
),
],
)
def test_datetime_critic_tolerances(
critic_params, expected, actual, expected_match, expected_score_func
):
"""
Test DatetimeCritic's handling of tolerances and max differences.
"""
critic = DatetimeCritic(**critic_params)
result = critic.evaluate(expected, actual)
assert result["match"] == expected_match
expected_score = expected_score_func(critic)
assert pytest.approx(result["score"], abs=1e-6) == expected_score
def test_datetime_critic_naive_and_timezone_aware():
"""
Test DatetimeCritic when comparing naive and timezone-aware datetimes.
"""
critic = DatetimeCritic(critic_field="start_datetime", weight=1.0)
expected = "2024-09-26T12:00:00Z"
actual = "2024-09-26T07:00:00"
result = critic.evaluate(expected, actual)
assert result["match"] is False
# Compute expected score based on time difference
expected_dt = parser.parse(expected)
actual_dt = parser.parse(actual)
if actual_dt.tzinfo is None:
actual_dt = pytz.utc.localize(actual_dt)
if expected_dt.tzinfo is None:
expected_dt = pytz.utc.localize(expected_dt)
time_diff_seconds = abs((expected_dt - actual_dt).total_seconds())
if time_diff_seconds <= critic.tolerance.total_seconds():
expected_score = critic.weight
elif time_diff_seconds >= critic.max_difference.total_seconds():
expected_score = 0.0
else:
ratio = 1 - (time_diff_seconds / critic.max_difference.total_seconds())
expected_score = critic.weight * ratio
assert pytest.approx(result["score"], abs=1e-6) == expected_score
# Test EvalSuite.add_case()
def test_eval_suite_add_case():
"""
Test that add_case correctly adds a new evaluation case to the suite.
"""
@tool
def mock_tool(param: str):
pass
mock_catalog = Mock()
mock_catalog.find_tool_by_func.return_value.get_fully_qualified_name.return_value = "MockTool"
@ -488,11 +219,11 @@ def test_eval_suite_add_case():
expected_tool_calls = [
ExpectedToolCall(
func=mock_tool,
args={"param": "value"},
args={"param1": "value"},
),
(
mock_tool,
{"param": "value"},
{"param1": "value"},
),
]
@ -509,10 +240,10 @@ def test_eval_suite_add_case():
assert case.user_message == "User message"
assert case.system_message == "System message"
assert case.expected_tool_calls[0] == NamedExpectedToolCall(
name="MockTool", args={"param": "value"}
name="MockTool", args={"param1": "value"}
)
assert case.expected_tool_calls[1] == NamedExpectedToolCall(
name="MockTool", args={"param": "value"}
name="MockTool", args={"param1": "value"}
)
@ -521,11 +252,6 @@ def test_eval_suite_extend_case():
"""
Test that extend_case correctly extends the last added case with new information.
"""
@tool
def mock_tool(param: str):
pass
mock_catalog = Mock()
mock_catalog.find_tool_by_func.return_value.get_fully_qualified_name.return_value = "MockTool"
@ -534,11 +260,11 @@ def test_eval_suite_extend_case():
expected_tool_calls = [
ExpectedToolCall(
func=mock_tool,
args={"param": "value"},
args={"param1": "value"},
),
(
mock_tool,
{"param": "value"},
{"param1": "value"},
),
]
@ -564,8 +290,124 @@ def test_eval_suite_extend_case():
assert extended_case.system_message == "System message"
assert len(extended_case.expected_tool_calls) == 2
assert extended_case.expected_tool_calls[0] == NamedExpectedToolCall(
name="MockTool", args={"param": "value"}
name="MockTool", args={"param1": "value"}
)
assert extended_case.expected_tool_calls[1] == NamedExpectedToolCall(
name="MockTool", args={"param": "value"}
name="MockTool", args={"param1": "value"}
)
def test_eval_suite_validate_critics_raises_value_error():
    """
    Two critics that target the same field must make _validate_critics raise a ValueError.
    """
    suite = EvalSuite(name="TestSuite", system_message="System message", catalog=Mock())
    clashing_critics = [
        BinaryCritic(critic_field="param", weight=0.5),
        SimilarityCritic(critic_field="param", weight=0.5),
    ]

    with pytest.raises(ValueError):
        suite._validate_critics(clashing_critics, "TestCase")
def test_eval_suite_validate_critics_no_error():
    """
    A single critic on a single field must pass _validate_critics without raising.
    """
    suite = EvalSuite(name="TestSuite", system_message="System message", catalog=Mock())
    single_critic = [BinaryCritic(critic_field="param1", weight=0.5)]

    suite._validate_critics(single_critic, "TestCase")
@pytest.mark.parametrize(
    "expected_tool_calls, critics, expected_critics_count, expected_critics_types",
    [
        (
            # Test case 1: No arguments, expect no critics
            [NamedExpectedToolCall(name="MockToolNoArgs", args={})],
            None,
            0,
            [],
        ),
        (
            # Test case 2: Single argument, expect one NoneCritic
            [NamedExpectedToolCall(name="MockTool", args={"param1": "value"})],
            None,
            1,
            [(NoneCritic, "param1")],
        ),
        (
            # Test case 3: Multiple arguments with some critics, expect BinaryCritics
            # for specified fields and NoneCritics for others
            [
                NamedExpectedToolCall(
                    name="MockToolMultipleArgs",
                    args={
                        "param1": "value1",
                        "param2": "value2",
                        "param3": "value3",
                        "param4": "value4",
                    },
                )
            ],
            [
                BinaryCritic(critic_field="param1", weight=0.5),
                BinaryCritic(critic_field="param2", weight=0.5),
            ],
            4,
            [
                (BinaryCritic, "param1"),
                (BinaryCritic, "param2"),
                (NoneCritic, "param3"),
                (NoneCritic, "param4"),
            ],
        ),
        (
            # Test case 4: Mixed tool calls with multiple critics, expect BinaryCritics
            # for specified fields and NoneCritics for others
            [
                NamedExpectedToolCall(name="MockTool", args={"param1": "value"}),
                NamedExpectedToolCall(name="MockToolNoArgs", args={}),
                NamedExpectedToolCall(
                    name="MockToolMultipleArgs",
                    args={
                        "param1": "value1",
                        "param2": "value2",
                        "param3": "value3",
                        "param4": "value4",
                    },
                ),
            ],
            [
                BinaryCritic(critic_field="param1", weight=0.3),
                BinaryCritic(critic_field="param2", weight=0.3),
                BinaryCritic(critic_field="param3", weight=0.3),
            ],
            4,
            [
                (BinaryCritic, "param1"),
                (BinaryCritic, "param2"),
                (BinaryCritic, "param3"),
                (NoneCritic, "param4"),
            ],
        ),
    ],
)
def test_eval_suite_add_none_critics(
    expected_tool_calls, critics, expected_critics_count, expected_critics_types
):
    """
    Check that _add_none_critics returns the expected number of critics and that
    each position holds the expected critic type for the expected field.
    """
    suite = EvalSuite(name="TestSuite", system_message="System message", catalog=Mock())
    critics_with_none = suite._add_none_critics(expected_tool_calls, critics)

    assert len(critics_with_none) == expected_critics_count
    for position, (critic_type, field_name) in enumerate(expected_critics_types):
        critic = critics_with_none[position]
        assert isinstance(critic, critic_type)
        assert critic.critic_field == field_name

View file

@ -0,0 +1,286 @@
from datetime import timedelta
import pytest
import pytz
from dateutil import parser
from arcade.sdk.errors import WeightError
from arcade.sdk.eval import (
BinaryCritic,
DatetimeCritic,
NoneCritic,
NumericCritic,
SimilarityCritic,
)
# NoneCritic initialization: the weight ends up as 0.0 whatever value is passed in
@pytest.mark.parametrize("weight, expected_weight", [(0.0, 0.0), (0.5, 0.0)])
def test_none_critic_initialization(weight, expected_weight):
    """NoneCritic keeps its field name and ends up with the expected weight."""
    none_critic = NoneCritic(critic_field="my_field", weight=weight)
    assert none_critic.weight == expected_weight
    assert none_critic.critic_field == "my_field"
# NoneCritic.evaluate(): no match verdict, zero score, flagged as not criticized
def test_none_critic_evaluate():
    """NoneCritic.evaluate() returns match=None, score=0.0 and is_criticized=False."""
    outcome = NoneCritic(critic_field="my_field").evaluate(
        expected="expected_value", actual="actual_value"
    )
    assert outcome["match"] is None
    assert outcome["score"] == 0.0
    assert outcome["is_criticized"] is False
# BinaryCritic.evaluate(): equal and unequal expected/actual pairs
@pytest.mark.parametrize(
    "expected, actual, weight, expected_match, expected_score",
    [
        ("value", "value", 1.0, True, 1.0),
        ("value", "different", 1.0, False, 0.0),
        (10, 10, 0.5, True, 0.5),
        (10, 20, 0.5, False, 0.0),
    ],
)
def test_binary_critic_evaluate(expected, actual, weight, expected_match, expected_score):
    """
    Check that BinaryCritic.evaluate() reports the expected match flag and score
    for each parametrized expected/actual pair.
    """
    outcome = BinaryCritic(critic_field="test_field", weight=weight).evaluate(
        expected=expected, actual=actual
    )
    assert outcome["match"] == expected_match
    assert outcome["score"] == expected_score
# NumericCritic.evaluate(): match flag and score for values inside a range
@pytest.mark.parametrize(
    "expected, actual, value_range, weight, match_threshold, expected_match, expected_score",
    [
        (5, 5, (0, 10), 1.0, 0.8, True, 1.0),
        (5, 6, (0, 10), 1.0, 0.8, True, 0.9),
        (0, 10, (0, 10), 1.0, 0.8, False, 0.0),
        (2, 8, (0, 10), 1.0, 0.5, False, 0.4),
        (50, 60, (0, 100), 0.5, 0.9, True, 0.45),
    ],
)
def test_numeric_critic_evaluate(
    expected, actual, value_range, weight, match_threshold, expected_match, expected_score
):
    """
    Check that NumericCritic.evaluate() reports the expected match flag and a score
    that equals the expected score within a 1% relative tolerance.
    """
    numeric_critic = NumericCritic(
        critic_field="number",
        weight=weight,
        value_range=value_range,
        match_threshold=match_threshold,
    )
    outcome = numeric_critic.evaluate(expected=expected, actual=actual)

    assert outcome["match"] == expected_match
    assert pytest.approx(outcome["score"], rel=0.01) == expected_score
# SimilarityCritic.evaluate(): match flag and score bounds for string pairs
@pytest.mark.parametrize(
    "expected, actual, weight, similarity_threshold, expected_match, min_expected_score",
    [
        ("hello world", "hello world", 1.0, 0.8, True, 1.0),
        ("hello world", "hello", 1.0, 0.8, False, 0.0),
        ("The quick brown fox", "The quick brown fox jumps over the lazy dog", 1.0, 0.5, True, 0.5),
        ("data science", "machine learning", 0.5, 0.3, False, 0.0),
    ],
)
def test_similarity_critic_evaluate(
    expected, actual, weight, similarity_threshold, expected_match, min_expected_score
):
    """
    Check that SimilarityCritic.evaluate() reports the expected match flag and a
    score that is at least the expected minimum and lies between 0 and the weight.
    """
    similarity_critic = SimilarityCritic(
        critic_field="text",
        weight=weight,
        similarity_threshold=similarity_threshold,
    )
    outcome = similarity_critic.evaluate(expected=expected, actual=actual)
    score = outcome["score"]

    assert outcome["match"] == expected_match
    assert score >= min_expected_score
    # The upper bound allows a small epsilon for floating-point comparison.
    assert 0.0 <= score <= weight + 1e-6
# Invalid critic weights must be rejected with a WeightError
@pytest.mark.parametrize(
    "critic_class, weight",
    [
        (BinaryCritic, -0.1),
        (BinaryCritic, 1.1),
        (NumericCritic, -0.5),
        (SimilarityCritic, 1.5),
    ],
)
def test_critic_invalid_weight(critic_class, weight):
    """
    Check that constructing a critic with an invalid weight raises a WeightError.
    """
    init_kwargs = {"critic_field": "test_field", "weight": weight}
    # NumericCritic is the only class here that is also given a value_range.
    if critic_class == NumericCritic:
        init_kwargs["value_range"] = (0, 1)

    with pytest.raises(WeightError):
        critic_class(**init_kwargs)
# NumericCritic must reject a value range given as (10, 0)
def test_numeric_critic_invalid_range():
    """
    Check that constructing a NumericCritic with value_range=(10, 0) raises a ValueError.
    """
    invalid_range = (10, 0)
    with pytest.raises(ValueError):
        NumericCritic(critic_field="number", weight=1.0, value_range=invalid_range)
# SimilarityCritic must reject a metric it does not support
def test_similarity_critic_unsupported_metric():
    """
    Check that constructing a SimilarityCritic with an unsupported metric raises a ValueError.
    """
    init_kwargs = {"critic_field": "text", "weight": 1.0, "metric": "unsupported_metric"}
    with pytest.raises(ValueError):
        SimilarityCritic(**init_kwargs)
# DatetimeCritic
# Parametrized cases covering several datetime formats with default critic settings
@pytest.mark.parametrize(
    "critic_params, expected, actual, expected_match, expected_score",
    [
        # Test with time component and timezone
        (
            {"critic_field": "start_datetime", "weight": 1.0},
            "2024-09-26T12:00:00-07:00",
            "2024-09-26T12:00:00-07:00",
            True,
            1.0,
        ),
        # Test without time component (dates only)
        (
            {"critic_field": "start_datetime", "weight": 1.0},
            "2024-09-26",
            "2024-09-26",
            True,
            1.0,
        ),
        # Test with and without timezone (assumes UTC)
        (
            {"critic_field": "start_datetime", "weight": 1.0},
            "2024-09-26T12:00:00Z",
            "2024-09-26T12:00:00",
            True,
            1.0,
        ),
        # Test naive datetimes
        (
            {"critic_field": "start_datetime", "weight": 1.0},
            "2024-09-26T12:00:00",
            "2024-09-26T12:00:00",
            True,
            1.0,
        ),
    ],
)
def test_datetime_critic_basic(critic_params, expected, actual, expected_match, expected_score):
    """
    Check that DatetimeCritic.evaluate() reports the expected match flag and score
    for each parametrized pair of datetime strings.
    """
    outcome = DatetimeCritic(**critic_params).evaluate(expected, actual)
    assert outcome["match"] == expected_match
    assert outcome["score"] == expected_score
# Parametrized cases for DatetimeCritic's tolerance and max_difference handling
@pytest.mark.parametrize(
    "critic_params, expected, actual, expected_match, expected_score_func",
    [
        # Test time difference within tolerance
        (
            {"critic_field": "start_datetime", "weight": 1.0, "tolerance": timedelta(seconds=60)},
            "2024-09-26T12:00:00",
            "2024-09-26T12:00:30",
            True,
            lambda critic: critic.weight,
        ),
        # Test time difference outside tolerance but within max_difference
        (
            {
                "critic_field": "start_datetime",
                "weight": 1.0,
                "tolerance": timedelta(seconds=60),
                "max_difference": timedelta(minutes=5),
            },
            "2024-09-26T12:00:00",
            "2024-09-26T12:04:00",
            False,
            lambda critic: critic.weight * (1 - (240 / 300)),
        ),
        # Test time difference exceeds max_difference
        (
            {
                "critic_field": "start_datetime",
                "weight": 1.0,
                "max_difference": timedelta(minutes=5),
            },
            "2024-09-26T12:00:00",
            "2024-09-26T12:10:00",
            False,
            lambda critic: 0.0,
        ),
    ],
)
def test_datetime_critic_tolerances(
    critic_params, expected, actual, expected_match, expected_score_func
):
    """
    Check DatetimeCritic.evaluate() against cases inside the tolerance, between the
    tolerance and max_difference, and beyond max_difference.

    The expected score is computed from the constructed critic by the case's
    expected_score_func and compared with an absolute tolerance of 1e-6.
    """
    datetime_critic = DatetimeCritic(**critic_params)
    outcome = datetime_critic.evaluate(expected, actual)

    assert outcome["match"] == expected_match
    assert pytest.approx(outcome["score"], abs=1e-6) == expected_score_func(datetime_critic)
def test_datetime_critic_naive_and_timezone_aware():
    """
    Check DatetimeCritic when one datetime string carries a timezone and the other does not.

    The expected score is recomputed here from the time gap between the two
    values (a parsed value without tzinfo is localized to UTC) and compared with
    the critic's score using an absolute tolerance of 1e-6.
    """
    expected = "2024-09-26T12:00:00Z"
    actual = "2024-09-26T07:00:00"
    datetime_critic = DatetimeCritic(critic_field="start_datetime", weight=1.0)

    outcome = datetime_critic.evaluate(expected, actual)
    assert outcome["match"] is False

    # Compute the expected score from the time difference
    expected_dt = parser.parse(expected)
    actual_dt = parser.parse(actual)
    if actual_dt.tzinfo is None:
        actual_dt = pytz.utc.localize(actual_dt)
    if expected_dt.tzinfo is None:
        expected_dt = pytz.utc.localize(expected_dt)
    gap_seconds = abs((expected_dt - actual_dt).total_seconds())

    if gap_seconds <= datetime_critic.tolerance.total_seconds():
        expected_score = datetime_critic.weight
    elif gap_seconds >= datetime_critic.max_difference.total_seconds():
        expected_score = 0.0
    else:
        remaining_fraction = 1 - (gap_seconds / datetime_critic.max_difference.total_seconds())
        expected_score = datetime_critic.weight * remaining_fraction

    assert pytest.approx(outcome["score"], abs=1e-6) == expected_score

View file

@ -138,7 +138,6 @@ def docs_eval_suite() -> EvalSuite:
critics=[
BinaryCritic(critic_field="document_id", weight=0.3),
SimilarityCritic(critic_field="text_content", weight=0.3),
BinaryCritic(critic_field="document_id", weight=0.3),
],
additional_messages=additional_messages,
)