Adds an example of a good "general-case" SQL tool that:

* enforces read-only mode;
* hints to the LLM to discover the schemas and tables it needs (and each table's columns) before running any query;
* raises `RetryableToolError` exceptions to hint to the LLM about what to do next.

Docs: https://github.com/ArcadeAI/docs/pull/345

For testing, `TEST_POSTGRES_DATABASE_CONNECTION_STRING` has been set in the repo (a Neon database); details are in 1Password.

<img width="1178" height="1091" alt="464977013-49aff5e5-e301-4ca0-83b5-3ea742db2283" src="https://github.com/user-attachments/assets/9344c27b-015d-4b91-907e-84f2e4193e16" />
119 lines
3.6 KiB
Python
119 lines
3.6 KiB
Python
import os
|
|
from os import environ
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from arcade_postgres.tools.postgres import (
|
|
DatabaseEngine,
|
|
discover_schemas,
|
|
discover_tables,
|
|
execute_query,
|
|
get_table_schema,
|
|
)
|
|
from arcade_tdk import ToolContext, ToolSecretItem
|
|
from arcade_tdk.errors import RetryableToolError
|
|
from sqlalchemy import text
|
|
from sqlalchemy.ext.asyncio import create_async_engine
|
|
|
|
# Connection string for the Postgres database the tests run against. An unset OR empty
# TEST_POSTGRES_DATABASE_CONNECTION_STRING falls back to a local server.
# NOTE(review): the fallback hard-codes a developer-specific user ("evan"); consider a
# generic default for other contributors.
DATABASE_CONNECTION_STRING = environ.get(
    "TEST_POSTGRES_DATABASE_CONNECTION_STRING", ""
) or "postgresql://evan@localhost:5432/postgres"
|
|
|
|
|
|
@pytest.fixture
def mock_context():
    """Return a ``ToolContext`` carrying the test database connection string as a secret.

    The secret is stored under the key ``DATABASE_CONNECTION_STRING`` with the value of
    the module-level constant of the same name (presumably the key the tools look up;
    confirm against the tool implementation).
    """
    connection_secret = ToolSecretItem(
        key="DATABASE_CONNECTION_STRING", value=DATABASE_CONNECTION_STRING
    )
    ctx = ToolContext()
    ctx.secrets = [connection_secret]
    return ctx
|
|
|
|
|
|
# Before each test, restore the database from the SQL dump stored next to this file.
@pytest_asyncio.fixture(autouse=True)
async def restore_database():
    """Reload ``dump.sql`` into the test database before every test.

    The dump is split on ``;`` and each non-empty statement is executed inside a single
    transaction that commits only if every statement succeeds. The engine is always
    disposed, even when a statement fails.

    NOTE(review): naive ``;`` splitting breaks on statements containing semicolons inside
    string literals or function bodies, and ``text()`` treats ``:word`` as a bind
    parameter -- keep the dump free of both.
    """
    dump_path = os.path.join(os.path.dirname(__file__), "dump.sql")
    # Read the whole dump up front so the file is not held open during database work.
    with open(dump_path, encoding="utf-8") as f:
        queries = f.read().split(";")

    # Swap in the asyncpg driver. Only the first occurrence (the scheme) is replaced, so a
    # user/host/database name containing "postgresql" is left untouched. The query string
    # is dropped, presumably because asyncpg rejects libpq-style options such as
    # ``sslmode`` -- confirm before relying on it.
    async_url = DATABASE_CONNECTION_STRING.replace("postgresql", "postgresql+asyncpg", 1)
    async_url = async_url.split("?")[0]

    engine = create_async_engine(async_url)
    try:
        # engine.begin() opens one transaction, commits on success and rolls back on
        # error; this replaces the manual "BEGIN" + commit() pair.
        async with engine.begin() as conn:
            for query in queries:
                if query.strip():
                    await conn.execute(text(query))
    finally:
        await engine.dispose()
|
|
|
|
|
|
@pytest_asyncio.fixture(autouse=True)
async def cleanup_engines():
    """Release the cached ``DatabaseEngine`` instances once each test finishes.

    Nothing happens during setup; the teardown step disposes of every cached engine so
    database connections do not leak from one test into the next.
    """
    yield  # let the test run first
    await DatabaseEngine.cleanup()  # teardown: drop all cached engines
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_discover_schemas(mock_context) -> None:
    """Schema discovery on the restored dump reports exactly the ``public`` schema."""
    schemas = await discover_schemas(mock_context)
    assert schemas == ["public"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_discover_tables(mock_context) -> None:
    """Table discovery reports exactly the ``users`` and ``messages`` tables.

    The comparison is order-insensitive: the expected order ("users" before "messages")
    is neither alphabetical nor otherwise specified, so asserting it would make the test
    depend on catalog ordering. Sorting still catches missing, extra, or duplicated names.
    """
    tables = await discover_tables(mock_context)
    assert sorted(tables) == ["messages", "users"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_table_schema(mock_context) -> None:
    """Each column is described as ``name: type``, tagged with PRIMARY KEY / INDEXED."""
    expected_columns_by_table = {
        "users": [
            "id: int (PRIMARY KEY)",
            "name: str (INDEXED)",
            "email: str (INDEXED)",
            "password_hash: str",
            "created_at: datetime",
            "updated_at: datetime",
            "status: str",
        ],
        "messages": [
            "id: int (PRIMARY KEY)",
            "body: str",
            "user_id: int",
            "created_at: datetime",
            "updated_at: datetime",
        ],
    }
    # Dict order is insertion order, so "users" is checked before "messages".
    for table_name, expected_columns in expected_columns_by_table.items():
        actual_columns = await get_table_schema(mock_context, "public", table_name)
        assert actual_columns == expected_columns
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_execute_query(mock_context) -> None:
    """A simple SELECT returns each row as the string form of a tuple."""
    sql = "SELECT id, name, email FROM users WHERE id = 1"
    rows = await execute_query(mock_context, sql)
    assert rows == ["(1, 'Mario', 'mario@example.com')"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_execute_query_with_no_results(mock_context) -> None:
    """A query matching no rows yields an empty list rather than raising."""
    rows = await execute_query(mock_context, "SELECT * FROM users WHERE id = 9999999999")
    assert rows == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_execute_query_with_problem(mock_context) -> None:
    """A database error (here 'foo' is not a valid integer id) surfaces as RetryableToolError.

    ``match`` is searched in ``str(exception)``, so this checks the same substring as an
    explicit ``in str(e.value)`` assertion.
    """
    with pytest.raises(RetryableToolError, match="invalid input syntax"):
        await execute_query(mock_context, "SELECT * FROM users WHERE id = 'foo'")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_execute_query_rejects_non_select(mock_context) -> None:
    """Write statements such as INSERT are refused with a RetryableToolError."""
    insert_sql = (
        "INSERT INTO users (name, email, password_hash) "
        "VALUES ('Luigi', 'luigi@example.com', 'password')"
    )
    with pytest.raises(RetryableToolError, match="Only SELECT queries are allowed"):
        await execute_query(mock_context, insert_sql)
|