arcade-mcp/toolkits/postgres/arcade_postgres/database_engine.py
Evan Tahler 4144a42392
Make the postgres toolkit better (#509)
Per the discussions in the [blog
post](https://docs.google.com/document/d/1wZi1yKWKOyg1dpueA2eBvTtYlUjqqp7e_M1nuAoF8NY/edit?tab=t.0),
let's update the postgres toolkit!

<img width="2141" height="1364" alt="Screenshot 2025-07-22 at 5 28 06 PM" src="https://github.com/user-attachments/assets/1b6a5e0a-9429-4c16-9a0c-ac36ea520bea" />
2025-07-23 15:28:57 -07:00

180 lines
6.4 KiB
Python

import re
from typing import Any, ClassVar
from urllib.parse import urlparse

from arcade_tdk.errors import RetryableToolError
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
# Largest LIMIT that DatabaseEngine.sanitize_query accepts.
MAX_ROWS_RETURNED = 1_000
# Trivial statement executed to confirm that an engine can actually connect.
TEST_QUERY = "SELECT 1"
# First words of the select clause that mark a non-SELECT statement.
_WRITE_KEYWORDS = frozenset(
    {
        "INSERT",
        "UPDATE",
        "DELETE",
        "CREATE",
        "ALTER",
        "DROP",
        "TRUNCATE",
        "REINDEX",
        "VACUUM",
        "ANALYZE",
        "COMMENT",
    }
)

# A join clause that already names its join type (e.g. "LEFT JOIN b ON ...") is
# emitted verbatim instead of being prefixed with a bare JOIN.
_TYPED_JOIN = re.compile(
    r"^(?:NATURAL\s+)?(?:LEFT|RIGHT|FULL|INNER|CROSS)(?:\s+OUTER)?\s+JOIN\b",
    re.IGNORECASE,
)


def _strip_leading_keyword(clause: str, keyword: str) -> str:
    """Return *clause* without surrounding whitespace or a leading *keyword*.

    The match is case-insensitive, accepts any run of whitespace between the
    words of a multi-word keyword such as ``"GROUP BY"``, and requires a word
    boundary so that e.g. ``selected_at`` is not mistaken for ``SELECT``.
    """
    pattern = r"\s+".join(re.escape(word) for word in keyword.split())
    stripped = re.sub(rf"^{pattern}\b", "", clause.strip(), count=1, flags=re.IGNORECASE)
    return stripped.strip()


def _join_fragment(join_clause: str) -> str:
    """Return the SQL JOIN fragment for *join_clause*, or ``""`` if nothing remains."""
    clause = join_clause.strip()
    if _TYPED_JOIN.match(clause):
        return clause
    body = _strip_leading_keyword(clause, "JOIN")
    return f"JOIN {body}" if body else ""


def _to_async_url(connection_string: str) -> str:
    """Rewrite a Postgres URL to use the asyncpg driver, dropping its query string.

    ``postgres://``, ``postgresql://`` and ``postgresql+<driver>://`` all map to
    ``postgresql+asyncpg://``; any other scheme is passed through unchanged.
    """
    parsed = urlparse(connection_string)
    base_scheme = parsed.scheme.partition("+")[0]
    if base_scheme in ("postgres", "postgresql"):
        scheme = "postgresql+asyncpg"
    else:
        scheme = parsed.scheme
    # TODO: query parameters such as sslmode= are dropped; forwarding them to the
    # async driver did not work as-is and still needs investigation.
    return f"{scheme}://{parsed.netloc}{parsed.path}"


class _EngineContext:
    """Async context manager that yields a cached engine and leaves it open on exit."""

    def __init__(self, engine: AsyncEngine) -> None:
        self.engine = engine

    async def __aenter__(self) -> AsyncEngine:
        return self.engine

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # The engine is shared via DatabaseEngine's cache, so it is not disposed
        # here; DatabaseEngine.cleanup() disposes cached engines.
        return None


class DatabaseEngine:
    """Class-level cache of async SQLAlchemy engines plus a SELECT query builder."""

    # NOTE(review): not referenced in this module's visible code; kept for compatibility.
    _instance: ClassVar[None] = None
    # Engines keyed by their asyncpg connection URL (query string removed).
    _engines: ClassVar[dict[str, AsyncEngine]] = {}

    @classmethod
    async def get_instance(cls, connection_string: str) -> AsyncEngine:
        """Return a cached engine for *connection_string*, verified with ``TEST_QUERY``.

        The URL is rewritten to the asyncpg driver and its query string is dropped.
        An engine is created on first use. If the test query fails, the engine's
        pool is disposed and the query retried once.

        Raises:
            RetryableToolError: if the retry also fails. The engine is then evicted
                from the cache so a later call starts from a fresh engine.
        """
        key = _to_async_url(connection_string)
        engine = cls._engines.get(key)
        if engine is None:
            engine = create_async_engine(key)
            cls._engines[key] = engine

        try:
            await cls._check_connection(engine)
        except Exception:
            # Presumably the pool holds stale connections; reset it and retry once.
            await engine.dispose()
        else:
            return engine

        try:
            await cls._check_connection(engine)
        except Exception as e:
            # Do not keep an engine that cannot connect (e.g. a bad connection string).
            if cls._engines.get(key) is engine:
                del cls._engines[key]
            raise RetryableToolError(
                f"Connection failed: {e}",
                developer_message="Connection to postgres failed.",
                additional_prompt_content="Check the connection string and try again.",
            ) from e
        return engine

    @staticmethod
    async def _check_connection(engine: AsyncEngine) -> None:
        """Run ``TEST_QUERY`` on a new connection; any failure propagates."""
        async with engine.connect() as connection:
            await connection.execute(text(TEST_QUERY))

    @classmethod
    async def get_engine(cls, connection_string: str) -> _EngineContext:
        """Return an async context manager yielding the cached engine for *connection_string*.

        Leaving the context does not dispose the engine; it stays cached until
        :meth:`cleanup` is called.

        Raises:
            RetryableToolError: propagated from :meth:`get_instance`.
        """
        return _EngineContext(await cls.get_instance(connection_string))

    @classmethod
    async def cleanup(cls) -> None:
        """Clean up all cached engines. Call this when shutting down."""
        for engine in cls._engines.values():
            await engine.dispose()
        cls._engines.clear()

    @classmethod
    def clear_cache(cls) -> None:
        """Clear the engine cache without disposing engines. Use with caution."""
        cls._engines.clear()

    @classmethod
    def sanitize_query(  # noqa: C901
        cls,
        select_clause: str,
        from_clause: str,
        limit: int,
        offset: int,
        join_clause: str | None,
        where_clause: str | None,
        having_clause: str | None,
        group_by_clause: str | None,
        order_by_clause: str | None,
        with_clause: str | None,
    ) -> tuple[str, dict[str, Any]]:
        """Assemble a SELECT statement from its individual clauses.

        Each clause may be passed with or without its leading keyword
        (``"WHERE x > 1"`` and ``"x > 1"`` are equivalent). A join clause that
        starts with a join type such as ``LEFT JOIN`` is used verbatim.

        Clauses are interpolated into the SQL text as-is; only LIMIT and OFFSET are
        bound parameters. The first-word check rejects obvious non-SELECT
        statements but is not a defense against SQL injection.

        Returns:
            ``(query, parameters)`` where *query* uses ``:limit`` / ``:offset``
            placeholders and *parameters* supplies their values.

        Raises:
            RetryableToolError: if the select clause starts with a write/DDL
                keyword or is just ``*``, or if *limit* / *offset* is out of range.
        """
        select_clause = _strip_leading_keyword(select_clause, "SELECT")
        from_clause = _strip_leading_keyword(from_clause, "FROM")
        with_clause = _strip_leading_keyword(with_clause, "WITH") if with_clause else ""

        words = select_clause.split(maxsplit=1)
        if words and words[0].upper() in _WRITE_KEYWORDS:
            raise RetryableToolError(
                "Only SELECT queries are allowed.",
            )
        if select_clause == "*":
            raise RetryableToolError(
                "Do not use * in the select clause. Use a comma separated list of columns you wish to return.",
            )
        if limit > MAX_ROWS_RETURNED:
            raise RetryableToolError(
                f"Limit is too high. Maximum is {MAX_ROWS_RETURNED}.",
            )
        if offset < 0:
            raise RetryableToolError(
                "Offset must be greater than or equal to 0.",
                developer_message="Offset must be greater than or equal to 0.",
            )
        if limit <= 0:
            raise RetryableToolError(
                "Limit must be greater than 0.",
                developer_message="Limit must be greater than 0.",
            )

        parts = []
        if with_clause:
            parts.append(f"WITH {with_clause}")
        parts.append(f"SELECT {select_clause} FROM {from_clause}")  # noqa: S608
        if join_clause and (join := _join_fragment(join_clause)):
            parts.append(join)
        for keyword, clause in (
            ("WHERE", where_clause),
            ("GROUP BY", group_by_clause),
            ("HAVING", having_clause),
            ("ORDER BY", order_by_clause),
        ):
            # Whitespace-only or keyword-only clauses are skipped, not emitted empty.
            body = _strip_leading_keyword(clause, keyword) if clause else ""
            if body:
                parts.append(f"{keyword} {body}")
        parts.append("LIMIT :limit OFFSET :offset")

        return " ".join(parts), {"limit": limit, "offset": offset}