arcade-mcp/toolkits/postgres/arcade_postgres/database_engine.py
Evan Tahler 40cdf2018d
Postgres Database Toolkit (#459)
Adds an example of a good "general case" SQL tool:
* enforces read-only mode
* hints to the LLM to discover the tables and schemas for the tables it
needs before any query
* uses RetryableToolErrors to hint to the LLM about what to do next

Docs: https://github.com/ArcadeAI/docs/pull/345

For testing, `TEST_POSTGRES_DATABASE_CONNECTION_STRING` has been set in
the repo (from Neon). Details are in 1Password.

<img width="1178" height="1091"
alt="464977013-49aff5e5-e301-4ca0-83b5-3ea742db2283"
src="https://github.com/user-attachments/assets/9344c27b-015d-4b91-907e-84f2e4193e16"
/>
2025-07-11 17:30:40 -07:00

104 lines
4 KiB
Python

from typing import Any, ClassVar
from urllib.parse import urlparse
from arcade_tdk.errors import RetryableToolError
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
# Row cap for query results. Not referenced in this module; presumably applied
# by the query tools that import it -- TODO confirm against callers.
MAX_ROWS_RETURNED = 1000
# Trivial statement that DatabaseEngine.get_instance executes to check that an
# engine can actually reach the database.
TEST_QUERY = "SELECT 1"
class DatabaseEngine:
    """Class-level cache of async SQLAlchemy engines, one per connection URL.

    All state lives on the class; the methods are class/static methods and the
    class is never instantiated by the code in this module.
    """

    # Not used by any method of this class; kept so external references to the
    # attribute keep working.
    _instance: ClassVar[None] = None
    # Maps the normalized ``postgresql+asyncpg://...`` URL to its engine.
    _engines: ClassVar[dict[str, AsyncEngine]] = {}

    @staticmethod
    def _to_async_url(connection_string: str) -> str:
        """Return ``connection_string`` rewritten to use the asyncpg driver.

        Any PostgreSQL-style scheme (``postgres``, ``postgresql``, or
        ``postgresql+<driver>``) becomes ``postgresql+asyncpg``. Other schemes
        are passed through unchanged. Only the scheme, netloc and path are
        kept.
        """
        parsed_url = urlparse(connection_string)
        # TODO: something strange with sslmode= and friends, so the query
        # string (and fragment) are dropped for now rather than forwarded.
        scheme = parsed_url.scheme
        # Compare the dialect part only, so an explicit driver suffix is
        # replaced instead of being appended to.
        if scheme.split("+", 1)[0] in ("postgres", "postgresql"):
            scheme = "postgresql+asyncpg"
        return f"{scheme}://{parsed_url.netloc}{parsed_url.path}"

    @classmethod
    async def get_instance(cls, connection_string: str) -> AsyncEngine:
        """Return a cached engine for ``connection_string``, verified to connect.

        The engine is created on first use and reused afterwards. Every call
        runs ``TEST_QUERY``; if that fails the engine's pool is disposed and
        the check is retried once.

        Raises:
            RetryableToolError: if the connection check fails twice. The
                engine is evicted from the cache in that case.
        """
        key = cls._to_async_url(connection_string)
        engine = cls._engines.get(key)
        if engine is None:
            engine = create_async_engine(key)
            cls._engines[key] = engine

        # Try a simple query to see if the connection is valid.
        try:
            async with engine.connect() as connection:
                await connection.execute(text(TEST_QUERY))
        except Exception:
            # Drop the pooled connections (they may be stale) before retrying.
            await engine.dispose()
        else:
            return engine

        # Second and final attempt, on a fresh pool.
        try:
            async with engine.connect() as connection:
                await connection.execute(text(TEST_QUERY))
        except Exception as e:
            # Do not keep a known-bad engine around: evict it so repeated bad
            # connection strings cannot accumulate in the cache.
            if cls._engines.get(key) is engine:
                del cls._engines[key]
            try:
                await engine.dispose()
            except Exception:  # noqa: S110 - best effort; report the connect error
                pass
            raise RetryableToolError(
                f"Connection failed: {e}",
                developer_message="Connection to postgres failed.",
                additional_prompt_content="Check the connection string and try again.",
            ) from e
        return engine

    @classmethod
    async def get_engine(cls, connection_string: str) -> Any:
        """Return an async context manager that yields the verified engine.

        Usage: ``async with await DatabaseEngine.get_engine(url) as engine:``.

        Raises:
            RetryableToolError: propagated from ``get_instance``.
        """
        engine = await cls.get_instance(connection_string)

        class ConnectionContextManager:
            """Hands out the shared engine without taking ownership of it."""

            def __init__(self, engine: AsyncEngine) -> None:
                self.engine = engine

            async def __aenter__(self) -> AsyncEngine:
                return self.engine

            async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
                # Intentionally a no-op: the engine stays in the class cache
                # for reuse and is only disposed by ``cleanup``.
                pass

        return ConnectionContextManager(engine)

    @classmethod
    async def cleanup(cls) -> None:
        """Clean up all cached engines. Call this when shutting down."""
        # Snapshot and empty the cache before awaiting: the cache must not be
        # iterated across await points (another task could insert into it),
        # and it ends up empty even if a dispose call raises.
        engines = list(cls._engines.values())
        cls._engines.clear()
        for engine in engines:
            await engine.dispose()

    @classmethod
    def clear_cache(cls) -> None:
        """Clear the engine cache without disposing engines. Use with caution."""
        cls._engines.clear()

    @classmethod
    def sanitize_query(cls, query: str) -> str:
        """
        Sanitize a query to not break our read-only session.

        THIS IS REALLY UNSAFE AND SHOULD NOT BE USED IN PRODUCTION. USE A DATABASE
        CONNECTION WITH A READ-ONLY USER AND PREPARE STATEMENTS.

        Any ";" is rejected, including a trailing one and ones inside string
        literals; there are valid reasons for the ";" character, and this
        prevents that. The first whitespace-delimited word must be SELECT
        (case-insensitive); leading whitespace and newlines are allowed.

        Returns:
            The query, unchanged.

        Raises:
            RetryableToolError: if the query contains ";" or does not start
                with SELECT.
        """
        if ";" in query:
            raise RetryableToolError(
                "Multiple statements are not allowed in a single query.",
                developer_message="Multiple statements are not allowed in a single query.",
                additional_prompt_content="Split your query into multiple queries and try again.",
            )
        # split() with no argument splits on any run of whitespace, so leading
        # blanks and "SELECT\n..." are handled; an empty query gives no words.
        words = query.split()
        if not words or words[0].upper() != "SELECT":
            raise RetryableToolError(
                "Only SELECT queries are allowed.",
                developer_message="Only SELECT queries are allowed.",
                additional_prompt_content="Use the <DiscoverTables> and <GetTableSchema> tools to discover the tables and try again.",
            )
        return query