arcade-mcp/toolkits/postgres/evals/eval_postgres.py
Eric Gustin ca11ec9fa3
Add toolkits (#514)
Add Linkedin, Zendesk, Math, and Postgres toolkits
2025-07-25 15:44:06 -07:00

94 lines
2.7 KiB
Python

import arcade_postgres
from arcade_evals import (
BinaryCritic,
EvalRubric,
EvalSuite,
ExpectedToolCall,
SimilarityCritic,
tool_eval,
)
from arcade_postgres.tools.postgres import (
discover_tables,
execute_query,
get_table_schema,
)
from arcade_tdk import ToolCatalog
# Tool catalog exposed to the model under evaluation: every tool that the
# arcade_postgres module registers.
catalog = ToolCatalog()
catalog.add_module(arcade_postgres)

# Scoring thresholds shared by the suite below and by each of its cases.
rubric = EvalRubric(
    warn_threshold=0.95,
    fail_threshold=0.85,
)
@tool_eval()
def sql_eval_suite() -> EvalSuite:
    """Build the evaluation suite for the Postgres (SQL) tools.

    Four cases are registered, in order:

    1. Answer a lookup question with a single ``execute_query`` call when
       the table's columns are spelled out in the prompt.
    2. List the database's tables with ``discover_tables``.
    3. Fetch the columns of ``public.users`` via ``get_table_schema`` when
       the schema name is part of the question.
    4. The same lookup when the question omits the schema and an extra
       message supplies ``'public'`` instead.

    Returns:
        An ``EvalSuite`` wired to the module-level ``catalog`` and ``rubric``.
    """

    def users_schema_call():
        # Cases 3 and 4 both expect exactly this get_table_schema call; a
        # new ExpectedToolCall is built each time so cases share no objects.
        return ExpectedToolCall(
            func=get_table_schema,
            args={"schema_name": "public", "table_name": "users"},
        )

    def schema_and_table_critics():
        # Exact-match critics on both arguments, weighted equally; a fresh
        # list per case, mirroring separate inline literals.
        return [
            BinaryCritic(critic_field="schema_name", weight=0.5),
            BinaryCritic(critic_field="table_name", weight=0.5),
        ]

    suite = EvalSuite(
        name="sql Tools Evaluation",
        system_message=(
            "You are an AI assistant with access to sql tools. "
            "Use them to help the user with their tasks."
        ),
        catalog=catalog,
        rubric=rubric,
    )

    # Case 1: the prompt already describes the table, so no discovery call is
    # expected. The generated SQL is scored with a SimilarityCritic on the
    # "query" argument rather than an exact-match critic.
    suite.add_case(
        name="Get user by id (schema known)",
        user_message=(
            "Tell me the name and email of user #1 in my database. "
            "The table 'users' has the following schema: "
            "id: int, name: str, email: str, password_hash: str, "
            "created_at: datetime, updated_at: datetime"
        ),
        expected_tool_calls=[
            ExpectedToolCall(
                func=execute_query,
                args={"query": "SELECT name, email FROM users WHERE id = 1"},
            )
        ],
        rubric=rubric,
        critics=[SimilarityCritic(critic_field="query", weight=1.0)],
    )

    # Case 2: the expected call takes no arguments, so no critics are attached.
    suite.add_case(
        name="Discover tables",
        user_message="What tables are in my database?",
        expected_tool_calls=[ExpectedToolCall(func=discover_tables, args={})],
        rubric=rubric,
    )

    # Case 3: schema and table are both named in the question ("public.users").
    suite.add_case(
        name="Get table schema (schema provided)",
        user_message="What columns are in the table 'public.users' in my database?",
        expected_tool_calls=[users_schema_call()],
        rubric=rubric,
        critics=schema_and_table_critics(),
    )

    # Case 4: the question names only the table; the default schema comes from
    # an additional user message passed alongside the question.
    suite.add_case(
        name="Get table schema (schema not provided)",
        user_message="What columns are in the table 'users' in my database?",
        additional_messages=[
            {"role": "user", "content": "When not provided, the schema is 'public'."}
        ],
        expected_tool_calls=[users_schema_call()],
        rubric=rubric,
        critics=schema_and_table_critics(),
    )

    return suite