arcade-mcp/toolkits/postgres/tests/test_postgres.py
Eric Gustin ca11ec9fa3
Add toolkits (#514)
Add Linkedin, Zendesk, Math, and Postgres toolkits
2025-07-25 15:44:06 -07:00

189 lines
5.4 KiB
Python

import os
from os import environ
import pytest
import pytest_asyncio
from arcade_postgres.tools.postgres import (
DatabaseEngine,
discover_schemas,
discover_tables,
execute_select_query,
get_table_schema,
)
from arcade_tdk import ToolContext, ToolSecretItem
from arcade_tdk.errors import RetryableToolError
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine
DATABASE_CONNECTION_STRING = (
environ.get("TEST_POSTGRES_DATABASE_CONNECTION_STRING")
or "postgresql://evan@localhost:5432/postgres"
)
@pytest.fixture
def mock_context():
context = ToolContext()
context.secrets = []
context.secrets.append(
ToolSecretItem(key="DATABASE_CONNECTION_STRING", value=DATABASE_CONNECTION_STRING)
)
return context
# before the tests, restore the database from the dump
@pytest_asyncio.fixture(autouse=True)
async def restore_database():
with open(f"{os.path.dirname(__file__)}/dump.sql") as f:
engine = create_async_engine(
DATABASE_CONNECTION_STRING.replace("postgresql", "postgresql+asyncpg").split("?")[0]
)
async with engine.connect() as c:
queries = f.read().split(";")
await c.execute(text("BEGIN"))
for query in queries:
if query.strip():
await c.execute(text(query))
await c.commit()
await engine.dispose()
@pytest_asyncio.fixture(autouse=True)
async def cleanup_engines():
"""Clean up database engines after each test to prevent connection leaks."""
yield
# Clean up all cached engines after each test
await DatabaseEngine.cleanup()
@pytest.mark.asyncio
async def test_discover_schemas(mock_context) -> None:
assert await discover_schemas(mock_context) == ["public"]
@pytest.mark.asyncio
async def test_discover_tables(mock_context) -> None:
assert await discover_tables(mock_context) == ["users", "messages"]
@pytest.mark.asyncio
async def test_get_table_schema(mock_context) -> None:
assert await get_table_schema(mock_context, "public", "users") == [
"id: int (PRIMARY KEY)",
"name: str (INDEXED)",
"email: str (INDEXED)",
"password_hash: str",
"created_at: datetime",
"updated_at: datetime",
"status: str",
]
assert await get_table_schema(mock_context, "public", "messages") == [
"id: int (PRIMARY KEY)",
"body: str",
"user_id: int",
"created_at: datetime",
"updated_at: datetime",
]
@pytest.mark.asyncio
async def test_execute_select_query(mock_context) -> None:
assert await execute_select_query(
mock_context,
select_clause="id, name, email",
from_clause="users",
where_clause="id = 1",
) == [
"(1, 'Alice', 'alice@example.com')",
]
assert await execute_select_query(
mock_context,
select_clause="id, name, email",
from_clause="users",
order_by_clause="id",
limit=1,
offset=1,
) == [
"(2, 'Bob', 'bob@example.com')",
]
@pytest.mark.asyncio
async def test_execute_select_query_with_keywords(mock_context) -> None:
assert await execute_select_query(
mock_context,
select_clause="SELECT id, name, email",
from_clause="FROM users",
limit=1,
) == [
"(1, 'Alice', 'alice@example.com')",
]
@pytest.mark.asyncio
async def test_execute_select_query_with_join(mock_context) -> None:
assert await execute_select_query(
mock_context,
select_clause="u.id, u.name, u.email, m.id, m.body",
from_clause="users u",
join_clause="messages m ON u.id = m.user_id",
limit=1,
) == [
"(1, 'Alice', 'alice@example.com', 1, 'Hello everyone!')",
]
@pytest.mark.asyncio
async def test_execute_select_query_with_group_by(mock_context) -> None:
assert await execute_select_query(
mock_context,
select_clause="u.name, COUNT(m.id) AS message_count",
from_clause="messages m",
join_clause="users u ON m.user_id = u.id",
group_by_clause="u.name",
order_by_clause="message_count DESC",
limit=2,
) == [
"('Evan', 13)",
"('Alice', 3)",
]
@pytest.mark.asyncio
async def test_execute_select_query_with_no_results(mock_context) -> None:
# does not raise an error
assert (
await execute_select_query(
mock_context,
select_clause="id, name, email",
from_clause="users",
where_clause="id = 9999999999",
)
== []
)
@pytest.mark.asyncio
async def test_execute_select_query_with_problem(mock_context) -> None:
# 'foo' is not a valid id
with pytest.raises(RetryableToolError) as e:
await execute_select_query(
mock_context,
select_clause="*",
from_clause="users",
where_clause="id = 'foo'",
)
assert "Do not use * in the select clause" in str(e.value)
@pytest.mark.asyncio
async def test_execute_select_query_rejects_non_select(mock_context) -> None:
with pytest.raises(RetryableToolError) as e:
await execute_select_query(
mock_context,
select_clause="INSERT INTO users (name, email, password_hash) VALUES ('Luigi', 'luigi@example.com', 'password')",
from_clause="users",
)
assert "Only SELECT queries are allowed" in str(e.value)