<!-- CURSOR_SUMMARY --> > [!NOTE] > **Medium Risk** > Touches multiple toolkits’ runtime entrypoints and context/error/auth plumbing, so breakage risk is mainly around invocation/packaging and tool execution wiring rather than business logic. > > **Overview** > Migrates the BrightData, ClickHouse, LinkedIn, Math, MongoDB, Postgres, and Zendesk OSS toolkits from `arcade-tdk` to `arcade-mcp-server` APIs by updating tool decorators, `Context` types, auth classes, and exception imports. > > Adds per-toolkit `__main__.py` files that construct an `MCPApp`, register module tools, and run via configurable transport/host/port; corresponding `pyproject.toml` updates bump versions, drop `arcade-tdk`/`arcade-serve` deps, and add `project.scripts` console entrypoints. > > Updates tests and eval suites to use `arcade_mcp_server.Context` (mocked) and switches eval `ToolCatalog` imports to `arcade_core`. > > <sup>Written by [Cursor Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit 9b3e31acb4b35e1d72efd47e2d279c5b19e3ecb0. This will update automatically on new commits. Configure [here](https://cursor.com/dashboard?tab=bugbot).</sup> <!-- /CURSOR_SUMMARY -->
180 lines
6.4 KiB
Python
180 lines
6.4 KiB
Python
from typing import Any, ClassVar
|
|
from urllib.parse import urlparse
|
|
|
|
from arcade_mcp_server.exceptions import RetryableToolError
|
|
from sqlalchemy import text
|
|
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
|
|
|
# Largest `limit` that DatabaseEngine.sanitize_query accepts; larger values are rejected.
MAX_ROWS_RETURNED = 1000
|
|
# Statement executed by DatabaseEngine.get_instance to check that an engine can connect.
TEST_QUERY = "SELECT 1"
|
|
|
|
|
|
class DatabaseEngine:
    """Class-level cache of async SQLAlchemy engines plus a SELECT-query builder.

    Engines are stored in ``_engines`` keyed by the rewritten (asyncpg) connection
    string, so repeated calls with the same connection string reuse one engine.
    All functionality is exposed through classmethods; the class is not instantiated
    by any code in this module.
    """

    _instance: ClassVar[None] = None
    _engines: ClassVar[dict[str, AsyncEngine]] = {}

    # Statements that must not appear as the first word of the SELECT clause.
    _FORBIDDEN_LEADING_KEYWORDS: ClassVar[tuple[str, ...]] = (
        "INSERT",
        "UPDATE",
        "DELETE",
        "CREATE",
        "ALTER",
        "DROP",
        "TRUNCATE",
        "REINDEX",
        "VACUUM",
        "ANALYZE",
        "COMMENT",
    )

    @staticmethod
    async def _check_connection(engine: AsyncEngine) -> None:
        """Open a connection on *engine* and run ``TEST_QUERY``; propagate any failure."""
        async with engine.connect() as connection:
            await connection.execute(text(TEST_QUERY))

    @staticmethod
    def _strip_leading_keyword(clause: str, keyword: str) -> str:
        """Return *clause* without a leading *keyword* (case-insensitive).

        The keyword only counts when it is the whole clause or is followed by a space.
        When it matches, the surrounding whitespace is stripped and the keyword is cut
        off (text after the keyword, including its leading space, is kept). When it
        does not match, *clause* is returned unchanged.

        Args:
            clause: Raw clause text supplied by the caller.
            keyword: Upper-case keyword to remove; may contain a space ("GROUP BY").
        """
        stripped = clause.strip()
        upper = stripped.upper()
        if upper == keyword or upper.startswith(keyword + " "):
            return stripped[len(keyword) :]
        return clause

    @classmethod
    async def get_instance(cls, connection_string: str) -> AsyncEngine:
        """Return a cached, connectivity-checked engine for *connection_string*.

        The engine is created on first use and verified by running ``TEST_QUERY``.
        If that check fails, the engine is disposed and the check is attempted once
        more before giving up.

        Args:
            connection_string: Database URL; ``postgresql`` in the scheme is rewritten
                to ``postgresql+asyncpg``.

        Returns:
            The cached ``AsyncEngine`` for the rewritten connection string.

        Raises:
            RetryableToolError: If the connectivity check fails on the second attempt.
        """
        parsed_url = urlparse(connection_string)

        # TODO: something strange with sslmode= and friends.
        # Until that is resolved the URL's query string is not forwarded: only the
        # scheme, netloc and path are used to build the engine URL.
        scheme = parsed_url.scheme.replace("postgresql", "postgresql+asyncpg")
        async_connection_string = f"{scheme}://{parsed_url.netloc}{parsed_url.path}"
        key = async_connection_string
        if key not in cls._engines:
            cls._engines[key] = create_async_engine(async_connection_string)
        engine = cls._engines[key]

        # try a simple query to see if the connection is valid
        try:
            await cls._check_connection(engine)
        except Exception:
            # Deliberate best effort: dispose the engine and fall through to one retry.
            await engine.dispose()
        else:
            return engine

        # try again
        try:
            await cls._check_connection(engine)
        except Exception as e:
            raise RetryableToolError(
                f"Connection failed: {e}",
                developer_message="Connection to postgres failed.",
                additional_prompt_content="Check the connection string and try again.",
            ) from e
        return engine

    @classmethod
    async def get_engine(cls, connection_string: str) -> Any:
        """Return an async context manager that yields the engine for *connection_string*.

        The engine is obtained (and checked) via ``get_instance`` before the context
        manager is built, so connection errors surface from this call rather than
        from ``async with``.

        Raises:
            RetryableToolError: Propagated from ``get_instance``.
        """
        engine = await cls.get_instance(connection_string)

        class ConnectionContextManager:
            """Thin ``async with`` wrapper that hands out the already-built engine."""

            def __init__(self, engine: AsyncEngine) -> None:
                self.engine = engine

            async def __aenter__(self) -> AsyncEngine:
                return self.engine

            async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
                # Nothing is released here: the engine stays in the class-level cache.
                pass

        return ConnectionContextManager(engine)

    @classmethod
    async def cleanup(cls) -> None:
        """Clean up all cached engines. Call this when shutting down."""
        for engine in cls._engines.values():
            await engine.dispose()
        cls._engines.clear()

    @classmethod
    def clear_cache(cls) -> None:
        """Clear the engine cache without disposing engines. Use with caution."""
        cls._engines.clear()

    @classmethod
    def sanitize_query(  # noqa: C901
        cls,
        select_clause: str,
        from_clause: str,
        limit: int,
        offset: int,
        join_clause: str | None,
        where_clause: str | None,
        having_clause: str | None,
        group_by_clause: str | None,
        order_by_clause: str | None,
        with_clause: str | None,
    ) -> tuple[str, dict[str, Any]]:
        """Assemble a SELECT statement from its clauses and validate the paging values.

        Each clause may be given with or without its leading keyword ("WHERE x = 1" or
        "x = 1"); a leading keyword is removed before the statement is assembled.

        Args:
            select_clause: Column list; must not be exactly ``*``.
            from_clause: Table expression.
            limit: Row limit; must satisfy ``0 < limit <= MAX_ROWS_RETURNED``.
            offset: Row offset; must be ``>= 0``.
            join_clause: Optional JOIN body.
            where_clause: Optional WHERE condition.
            having_clause: Optional HAVING condition.
            group_by_clause: Optional GROUP BY expression list.
            order_by_clause: Optional ORDER BY expression list.
            with_clause: Optional WITH (CTE) body, emitted as given.

        Returns:
            ``(query, parameters)`` where ``query`` ends in
            ``LIMIT :limit OFFSET :offset`` and ``parameters`` holds those two values.

        Raises:
            RetryableToolError: If the select clause starts with a non-SELECT statement
                keyword, is exactly ``*``, or if ``limit``/``offset`` are out of range.
        """
        # Remove the leading keywords from the clauses if they are present
        select_clause = cls._strip_leading_keyword(select_clause, "SELECT")
        from_clause = cls._strip_leading_keyword(from_clause, "FROM")
        if join_clause:
            join_clause = cls._strip_leading_keyword(join_clause, "JOIN")
        if where_clause:
            where_clause = cls._strip_leading_keyword(where_clause, "WHERE")
        # Two-word keywords need a prefix comparison: the first space-separated token
        # of a clause can never equal "GROUP BY" / "ORDER BY".
        if group_by_clause:
            group_by_clause = cls._strip_leading_keyword(group_by_clause, "GROUP BY")
        if order_by_clause:
            order_by_clause = cls._strip_leading_keyword(order_by_clause, "ORDER BY")
        if having_clause:
            having_clause = cls._strip_leading_keyword(having_clause, "HAVING")

        # NOTE(review): only the first word of the select clause is screened and the
        # clauses are interpolated verbatim below, so this is not a complete guard
        # against non-SELECT statements.
        first_select_word = select_clause.strip().split(" ")[0].upper()
        if first_select_word in cls._FORBIDDEN_LEADING_KEYWORDS:
            raise RetryableToolError(
                "Only SELECT queries are allowed.",
            )

        if select_clause.strip() == "*":
            raise RetryableToolError(
                "Do not use * in the select clause. Use a comma separated list of columns you wish to return.",
            )

        if limit > MAX_ROWS_RETURNED:
            raise RetryableToolError(
                f"Limit is too high. Maximum is {MAX_ROWS_RETURNED}.",
            )

        if offset < 0:
            raise RetryableToolError(
                "Offset must be greater than or equal to 0.",
                developer_message="Offset must be greater than or equal to 0.",
            )

        if limit <= 0:
            raise RetryableToolError(
                "Limit must be greater than 0.",
                developer_message="Limit must be greater than 0.",
            )

        # Build query with identifiers directly interpolated, but use parameters for values
        parts = []
        if with_clause:
            parts.append(f"WITH {with_clause}")
        parts.append(f"SELECT {select_clause} FROM {from_clause}")  # noqa: S608
        if join_clause:
            parts.append(f"JOIN {join_clause}")
        if where_clause:
            parts.append(f"WHERE {where_clause}")
        if group_by_clause:
            parts.append(f"GROUP BY {group_by_clause}")
        if having_clause:
            parts.append(f"HAVING {having_clause}")
        if order_by_clause:
            parts.append(f"ORDER BY {order_by_clause}")
        parts.append("LIMIT :limit OFFSET :offset")
        query = " ".join(parts)

        # Only use parameters for values, not identifiers
        parameters = {
            "limit": limit,
            "offset": offset,
        }

        return query, parameters
|