Eval Suite additions (#43)

Merge of #38 as rebase was terrible. @nbarbettini

---------

Co-authored-by: Nate Barbettini <nate@arcade-ai.com>
This commit is contained in:
Sam Partee 2024-09-19 22:08:39 -07:00 committed by GitHub
parent 447058f0ce
commit 5b7370c3f0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 81 additions and 14 deletions

View file

@ -224,10 +224,6 @@ def _format_evaluation(evaluation: "EvaluationResult") -> str:
A formatted string representation of the evaluation details.
"""
result_lines = []
# Include overall final score
result_lines.append(f"[bold]Final Score:[/bold] {evaluation.score:.2f}\n")
for critic_result in evaluation.results:
match_color = "green" if critic_result["match"] else "red"
field = critic_result["field"]

View file

@ -0,0 +1,65 @@
from arcade.core.catalog import ToolCatalog
from arcade_arithmetic.tools.arithmetic import add, sqrt
from arcade.sdk.eval import (
BinaryCritic,
EvalRubric,
EvalSuite,
ExpectedToolCall,
tool_eval,
)
# Evaluation rubric
rubric = EvalRubric(
fail_threshold=0.85,
warn_threshold=0.95,
)
# TODO: add_toolkit didn't work
catalog = ToolCatalog()
catalog.add_tool(add)
catalog.add_tool(sqrt)
@tool_eval("gpt-4o-mini")
def arithmetic_eval_suite():
suite = EvalSuite(
name="Arithmetic Tools Evaluation",
system="You are an AI assistant with access to arithmetic tools. Use them to help the user with their math-related tasks.",
catalog=catalog,
rubric=rubric,
)
suite.add_case(
name="Add two large numbers",
user_message="Add 12345 and 987654321",
expected_tool_calls=[
ExpectedToolCall(
"Add",
args={
"a": 12345,
"b": 987654321,
},
)
],
rubric=rubric,
critics=[
BinaryCritic(
critic_field="a", weight=0.5
), # TODO: weight should be optional
BinaryCritic(critic_field="b", weight=0.5),
],
)
suite.add_case(
name="Take the square root of a large number",
user_message="What is the square root of 3224990521?",
expected_tool_calls=[ExpectedToolCall(lambda: sqrt(3224990521))],
rubric=rubric,
critics=[
BinaryCritic(critic_field="a", weight=1.0),
],
)
return suite

View file

@ -21,7 +21,10 @@ from arcade.sdk.auth import SlackUser
)
def send_dm_to_user(
context: ToolContext,
user_name: Annotated[str, "The Slack username of the person you want to message"],
user_name: Annotated[
str,
"The Slack username of the person you want to message. Slack usernames are ALWAYS lowercase.",
],
message: Annotated[str, "The message you want to send"],
):
"""Send a direct message to a user in Slack."""
@ -82,7 +85,8 @@ def format_users(userListResponse: dict) -> str:
def send_message_to_channel(
context: ToolContext,
channel_name: Annotated[
str, "The Slack channel name where you want to send the message"
str,
"The Slack channel name where you want to send the message. Slack channel names are ALWAYS lowercase.",
],
message: Annotated[str, "The message you want to send"],
):

View file

@ -65,8 +65,8 @@ def slack_eval_suite() -> EvalSuite:
)
],
critics=[
SimilarityCritic(critic_field="user_name", weight=0.4),
SimilarityCritic(critic_field="message", weight=0.6),
SimilarityCritic(critic_field="user_name", weight=0.6),
SimilarityCritic(critic_field="message", weight=0.4),
],
)
@ -83,8 +83,8 @@ def slack_eval_suite() -> EvalSuite:
)
],
critics=[
BinaryCritic(critic_field="user_name", weight=0.5),
SimilarityCritic(critic_field="message", weight=0.5),
BinaryCritic(critic_field="user_name", weight=0.6),
SimilarityCritic(critic_field="message", weight=0.4),
],
)
@ -102,8 +102,8 @@ def slack_eval_suite() -> EvalSuite:
)
],
critics=[
BinaryCritic(critic_field="channel_name", weight=0.5),
SimilarityCritic(critic_field="message", weight=0.5),
BinaryCritic(critic_field="channel_name", weight=0.6),
SimilarityCritic(critic_field="message", weight=0.4),
],
)
@ -165,8 +165,10 @@ def slack_eval_suite() -> EvalSuite:
),
],
critics=[
SimilarityCritic(critic_field="user_name", weight=0.4),
SimilarityCritic(critic_field="message", weight=0.6),
SimilarityCritic(critic_field="user_name", weight=0.6),
SimilarityCritic(
critic_field="message", weight=0.4, similarity_threshold=0.7
),
],
)