From 40cdf2018db1845ace1bae31cf288b9b0b8e272b Mon Sep 17 00:00:00 2001 From: Evan Tahler Date: Fri, 11 Jul 2025 17:30:40 -0700 Subject: [PATCH] Postgres Database Toolkit (#459) Adds an example of a good "general case" SQL tool: * enforces read-only mode * hints to the LLM to discover the tables and schemas for the tables it needs before any query * uses RetryableToolErrors to hint to the LLM about what to do next Docs: https://github.com/ArcadeAI/docs/pull/345 For testing, `TEST_POSTGRES_DATABASE_CONNECTION_STRING` has been set in the repo (from Neon). details in 1 password. 464977013-49aff5e5-e301-4ca0-83b5-3ea742db2283 --- .github/workflows/test-toolkits.yml | 2 + toolkits/postgres/Makefile | 53 ++++++ toolkits/postgres/arcade_postgres/__init__.py | 0 .../arcade_postgres/database_engine.py | 104 ++++++++++ .../arcade_postgres/tools/__init__.py | 0 .../arcade_postgres/tools/postgres.py | 178 ++++++++++++++++++ toolkits/postgres/evals/eval_postgres.py | 94 +++++++++ toolkits/postgres/pyproject.toml | 65 +++++++ toolkits/postgres/tests/__init__.py | 0 toolkits/postgres/tests/dump.sql | 114 +++++++++++ toolkits/postgres/tests/test_postgres.py | 119 ++++++++++++ 11 files changed, 729 insertions(+) create mode 100644 toolkits/postgres/Makefile create mode 100644 toolkits/postgres/arcade_postgres/__init__.py create mode 100644 toolkits/postgres/arcade_postgres/database_engine.py create mode 100644 toolkits/postgres/arcade_postgres/tools/__init__.py create mode 100644 toolkits/postgres/arcade_postgres/tools/postgres.py create mode 100644 toolkits/postgres/evals/eval_postgres.py create mode 100644 toolkits/postgres/pyproject.toml create mode 100644 toolkits/postgres/tests/__init__.py create mode 100644 toolkits/postgres/tests/dump.sql create mode 100644 toolkits/postgres/tests/test_postgres.py diff --git a/.github/workflows/test-toolkits.yml b/.github/workflows/test-toolkits.yml index 8f085092..3cefacf0 100644 --- a/.github/workflows/test-toolkits.yml +++ b/.github/workflows/test-toolkits.yml @@ -50,6 +50,8 @@ jobs: - name: Test toolkit working-directory: toolkits/${{ matrix.toolkit }} + env: + TEST_POSTGRES_DATABASE_CONNECTION_STRING: ${{ secrets.TEST_POSTGRES_DATABASE_CONNECTION_STRING }} # TODO: dynamically only load the `TEST_${{ matrix.toolkit }}_DATABASE_CONNECTION_STRING secret` run: | # Run pytest and capture exit code uv run --active pytest -W ignore -v --cov=arcade_${{ matrix.toolkit }} --cov-report=xml || EXIT_CODE=$? diff --git a/toolkits/postgres/Makefile b/toolkits/postgres/Makefile new file mode 100644 index 00000000..7e2c686e --- /dev/null +++ b/toolkits/postgres/Makefile @@ -0,0 +1,53 @@ +.PHONY: help + +help: + @echo "🛠️ github Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the uv environment and install all packages with dependencies + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras --no-sources + @uv run pre-commit install + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: install-local +install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras + @uv run pre-commit install + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + uv build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @uv run pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + coverage report + @echo "Generating coverage report" + coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file by a patch version + @echo "🚀 Bumping version in pyproject.toml" + uv version --bump patch + +.PHONY: check +check: ## Run code quality tools. + @echo "🚀 Linting code: Running pre-commit" + @uv run pre-commit run -a + @echo "🚀 Static type checking: Running mypy" + @uv run mypy --config-file=pyproject.toml diff --git a/toolkits/postgres/arcade_postgres/__init__.py b/toolkits/postgres/arcade_postgres/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/toolkits/postgres/arcade_postgres/database_engine.py b/toolkits/postgres/arcade_postgres/database_engine.py new file mode 100644 index 00000000..3a3ac88f --- /dev/null +++ b/toolkits/postgres/arcade_postgres/database_engine.py @@ -0,0 +1,104 @@ +from typing import Any, ClassVar +from urllib.parse import urlparse + +from arcade_tdk.errors import RetryableToolError +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine + +MAX_ROWS_RETURNED = 1000 +TEST_QUERY = "SELECT 1" + + +class DatabaseEngine: + _instance: ClassVar[None] = None + _engines: ClassVar[dict[str, AsyncEngine]] = {} + + @classmethod + async def get_instance(cls, connection_string: str) -> AsyncEngine: + parsed_url = urlparse(connection_string) + + # TODO: something strange with sslmode= and friends + # query_params = parse_qs(parsed_url.query) + # query_params = { + # k: v[0] for k, v in query_params.items() + # } # assume one value allowed for each query param + + async_connection_string = f"{parsed_url.scheme.replace('postgresql', 'postgresql+asyncpg')}://{parsed_url.netloc}{parsed_url.path}" + key = f"{async_connection_string}" + if key not in cls._engines: + cls._engines[key] = create_async_engine(async_connection_string) + + # try a simple query to see if the connection is valid + try: + async with cls._engines[key].connect() as connection: + await connection.execute(text(TEST_QUERY)) + return cls._engines[key] + except Exception: + await cls._engines[key].dispose() + + # try again + try: + async with cls._engines[key].connect() as connection: + await connection.execute(text(TEST_QUERY)) + return cls._engines[key] + except Exception as e: + raise RetryableToolError( + f"Connection failed: {e}", + developer_message="Connection to postgres failed.", + additional_prompt_content="Check the connection string and try again.", + ) from e + + @classmethod + async def get_engine(cls, connection_string: str) -> Any: + engine = await cls.get_instance(connection_string) + + class ConnectionContextManager: + def __init__(self, engine: AsyncEngine) -> None: + self.engine = engine + + async def __aenter__(self) -> AsyncEngine: + return self.engine + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + # Connection cleanup is handled by the async context manager + pass + + return ConnectionContextManager(engine) + + @classmethod + async def cleanup(cls) -> None: + """Clean up all cached engines. Call this when shutting down.""" + for engine in cls._engines.values(): + await engine.dispose() + cls._engines.clear() + + @classmethod + def clear_cache(cls) -> None: + """Clear the engine cache without disposing engines. Use with caution.""" + cls._engines.clear() + + @classmethod + def sanitize_query(cls, query: str) -> str: + """ + Sanitize a query to not break our read-only session. + THIS IS REALLY UNSAFE AND SHOULD NOT BE USED IN PRODUCTION. USE A DATABASE CONNECTION WITH A READ-ONLY USER AND PREPARE STATEMENTS. + There are also valid reasons for the ";" character, and this prevents that. + """ + + parts = query.split(";") + if len(parts) > 1: + raise RetryableToolError( + "Multiple statements are not allowed in a single query.", + developer_message="Multiple statements are not allowed in a single query.", + additional_prompt_content="Split your query into multiple queries and try again.", + ) + + words = parts[0].split(" ") + if words[0].upper().strip() != "SELECT": + raise RetryableToolError( + "Only SELECT queries are allowed.", + developer_message="Only SELECT queries are allowed.", + additional_prompt_content="Use the and tools to discover the tables and try again.", + ) + + return f"{query}" diff --git a/toolkits/postgres/arcade_postgres/tools/__init__.py b/toolkits/postgres/arcade_postgres/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/toolkits/postgres/arcade_postgres/tools/postgres.py b/toolkits/postgres/arcade_postgres/tools/postgres.py new file mode 100644 index 00000000..02a445ae --- /dev/null +++ b/toolkits/postgres/arcade_postgres/tools/postgres.py @@ -0,0 +1,178 @@ +from typing import Annotated, Any + +from arcade_tdk import ToolContext, tool +from arcade_tdk.errors import RetryableToolError +from sqlalchemy import inspect, text +from sqlalchemy.ext.asyncio import AsyncEngine + +from ..database_engine import MAX_ROWS_RETURNED, DatabaseEngine + + +@tool(requires_secrets=["DATABASE_CONNECTION_STRING"]) +async def discover_schemas( + context: ToolContext, +) -> list[str]: + """Discover all the schemas in the postgres database.""" + async with await DatabaseEngine.get_engine( + context.get_secret("DATABASE_CONNECTION_STRING") + ) as engine: + schemas = await _get_schemas(engine) + return schemas + + +@tool(requires_secrets=["DATABASE_CONNECTION_STRING"]) +async def discover_tables( + context: ToolContext, + schema_name: Annotated[ + str, "The database schema to discover tables in (default value: 'public')" + ] = "public", +) -> list[str]: + """Discover all the tables in the postgres database when the list of tables is not known. + + THIS TOOL SHOULD ALWAYS BE USED BEFORE ANY OTHER TOOL THAT REQUIRES A TABLE NAME. + """ + async with await DatabaseEngine.get_engine( + context.get_secret("DATABASE_CONNECTION_STRING") + ) as engine: + tables = await _get_tables(engine, schema_name) + return tables + + +@tool(requires_secrets=["DATABASE_CONNECTION_STRING"]) +async def get_table_schema( + context: ToolContext, + schema_name: Annotated[str, "The database schema to get the table schema of"], + table_name: Annotated[str, "The table to get the schema of"], +) -> list[str]: + """ + Get the schema/structure of a postgres table in the postgres database when the schema is not known, and the name of the table is provided. + + THIS TOOL SHOULD ALWAYS BE USED BEFORE EXECUTING ANY QUERY. ALL TABLES IN THE QUERY MUST BE DISCOVERED FIRST USING THE TOOL. + """ + async with await DatabaseEngine.get_engine( + context.get_secret("DATABASE_CONNECTION_STRING") + ) as engine: + return await _get_table_schema(engine, schema_name, table_name) + + +@tool(requires_secrets=["DATABASE_CONNECTION_STRING"]) +async def execute_query( + context: ToolContext, + query: Annotated[str, "The postgres SQL query to execute. Only SELECT queries are allowed."], +) -> list[str]: + """ + You have a connection to a postgres database. + Execute a query and return the results against the postgres database. + + ONLY USE THIS TOOL IF YOU HAVE ALREADY LOADED THE SCHEMA OF THE TABLES YOU NEED TO QUERY. USE THE TOOL TO LOAD THE SCHEMA IF NOT ALREADY KNOWN. + + When running queries, follow these rules which will help avoid errors: + * Always use case-insensitive queries to match strings in the query. + * Always trim strings in the query. + * Prefer LIKE queries over direct string matches or regex queries. + * Only join on columns that are indexed or the primary key. Do not join on arbitrary columns. + + Only SELECT queries are allowed. Do not use INSERT, UPDATE, DELETE, or other DML statements. This tool will reject them. + + Unless otherwise specified, ensure that query has a LIMIT of 100 for all results. This tool will enforce that no more than 1000 rows are returned at maximum. + """ + async with await DatabaseEngine.get_engine( + context.get_secret("DATABASE_CONNECTION_STRING") + ) as engine: + try: + return await _execute_query(engine, query) + except Exception as e: + raise RetryableToolError( + f"Query failed: {e}", + developer_message=f"Query '{query}' failed.", + additional_prompt_content="Load the database schema or use the tool to discover the tables and try again.", + retry_after_ms=10, + ) from e + + +async def _get_schemas(engine: AsyncEngine) -> list[str]: + """Get all the schemas in the database""" + async with engine.connect() as conn: + + def get_schema_names(sync_conn: Any) -> list[str]: + return list(inspect(sync_conn).get_schema_names()) + + schemas: list[str] = await conn.run_sync(get_schema_names) + schemas = [schema for schema in schemas if schema != "information_schema"] + + return schemas + + +async def _get_tables(engine: AsyncEngine, schema_name: str) -> list[str]: + """Get all the tables in the database""" + async with engine.connect() as conn: + + def get_schema_names(sync_conn: Any) -> list[str]: + return list(inspect(sync_conn).get_schema_names()) + + schemas: list[str] = await conn.run_sync(get_schema_names) + tables = [] + for schema in schemas: + if schema == schema_name: + + def get_table_names(sync_conn: Any, s: str = schema) -> list[str]: + return list(inspect(sync_conn).get_table_names(schema=s)) + + these_tables = await conn.run_sync(get_table_names) + tables.extend(these_tables) + return tables + + +async def _get_table_schema(engine: AsyncEngine, schema_name: str, table_name: str) -> list[str]: + """Get the schema of a table""" + async with engine.connect() as connection: + + def get_columns(sync_conn: Any, t: str = table_name, s: str = schema_name) -> list[Any]: + return list(inspect(sync_conn).get_columns(t, s)) + + columns_table = await connection.run_sync(get_columns) + + # Get primary key information + pk_constraint = await connection.run_sync( + lambda sync_conn: inspect(sync_conn).get_pk_constraint(table_name, schema_name) + ) + primary_keys = set(pk_constraint.get("constrained_columns", [])) + + # Get index information + indexes = await connection.run_sync( + lambda sync_conn: inspect(sync_conn).get_indexes(table_name, schema_name) + ) + indexed_columns = set() + for index in indexes: + indexed_columns.update(index.get("column_names", [])) + + results = [] + for column in columns_table: + column_name = column["name"] + column_type = column["type"].python_type.__name__ + + # Build column description + description = f"{column_name}: {column_type}" + + # Add primary key indicator + if column_name in primary_keys: + description += " (PRIMARY KEY)" + + # Add index indicator + if column_name in indexed_columns: + description += " (INDEXED)" + + results.append(description) + + return results[:MAX_ROWS_RETURNED] + + +async def _execute_query( + engine: AsyncEngine, query: str, params: dict[str, Any] | None = None +) -> list[str]: + """Execute a query and return the results.""" + async with engine.connect() as connection: + result = await connection.execute(text(DatabaseEngine.sanitize_query(query)), params) + rows = result.fetchall() + results = [str(row) for row in rows] + return results[:MAX_ROWS_RETURNED] diff --git a/toolkits/postgres/evals/eval_postgres.py b/toolkits/postgres/evals/eval_postgres.py new file mode 100644 index 00000000..ef2517e1 --- /dev/null +++ b/toolkits/postgres/evals/eval_postgres.py @@ -0,0 +1,94 @@ +import arcade_postgres +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + SimilarityCritic, + tool_eval, +) +from arcade_postgres.tools.postgres import ( + discover_tables, + execute_query, + get_table_schema, +) +from arcade_tdk import ToolCatalog + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.85, + warn_threshold=0.95, +) + + +catalog = ToolCatalog() +catalog.add_module(arcade_postgres) + + +@tool_eval() +def sql_eval_suite() -> EvalSuite: + suite = EvalSuite( + name="sql Tools Evaluation", + system_message=( + "You are an AI assistant with access to sql tools. " + "Use them to help the user with their tasks." + ), + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="Get user by id (schema known)", + user_message="Tell me the name and email of user #1 in my database. The table 'users' has the following schema: id: int, name: str, email: str, password_hash: str, created_at: datetime, updated_at: datetime", + expected_tool_calls=[ + ExpectedToolCall( + func=execute_query, args={"query": "SELECT name, email FROM users WHERE id = 1"} + ) + ], + rubric=rubric, + critics=[SimilarityCritic(critic_field="query", weight=1.0)], + ) + + suite.add_case( + name="Discover tables", + user_message="What tables are in my database?", + expected_tool_calls=[ + ExpectedToolCall(func=discover_tables, args={}), + ], + rubric=rubric, + ) + + suite.add_case( + name="Get table schema (schema provided)", + user_message="What columns are in the table 'public.users' in my database?", + expected_tool_calls=[ + ExpectedToolCall( + func=get_table_schema, args={"schema_name": "public", "table_name": "users"} + ), + ], + rubric=rubric, + critics=[ + BinaryCritic(critic_field="schema_name", weight=0.5), + BinaryCritic(critic_field="table_name", weight=0.5), + ], + ) + + suite.add_case( + name="Get table schema (schema not provided)", + user_message="What columns are in the table 'users' in my database?", + additional_messages=[ + {"role": "user", "content": "When not provided, the schema is 'public'."} + ], + expected_tool_calls=[ + ExpectedToolCall( + func=get_table_schema, args={"schema_name": "public", "table_name": "users"} + ), + ], + rubric=rubric, + critics=[ + BinaryCritic(critic_field="schema_name", weight=0.5), + BinaryCritic(critic_field="table_name", weight=0.5), + ], + ) + + return suite diff --git a/toolkits/postgres/pyproject.toml b/toolkits/postgres/pyproject.toml new file mode 100644 index 00000000..cd4c055e --- /dev/null +++ b/toolkits/postgres/pyproject.toml @@ -0,0 +1,65 @@ +[build-system] +requires = [ "hatchling",] +build-backend = "hatchling.build" + +[project] +name = "arcade_postgres" +version = "0.1.0" +description = "Tools to query and explore a postgres database" +requires-python = ">=3.10" +dependencies = [ + "arcade-tdk>=2.0.0,<3.0.0", + "psycopg2-binary>=2.9.10", + "pydantic>=2.11.7", + "sqlalchemy>=2.0.41", + "psycopg2-binary>=2.9.10", + "asyncpg>=0.30.0", + "greenlet>=3.2.3", +] +[[project.authors]] +name = "evantahler" +email = "support@arcade.dev" + + +[project.optional-dependencies] +dev = [ + "arcade-ai[evals]>=2.0.0,<3.0.0", + "arcade-serve>=2.0.0,<3.0.0", + "pytest>=8.3.0,<8.4.0", + "pytest-cov>=4.0.0,<4.1.0", + "pytest-mock>=3.11.1,<3.12.0", + "pytest-asyncio>=0.24.0,<0.25.0", + "mypy>=1.5.1,<1.6.0", + "pre-commit>=3.4.0,<3.5.0", + "tox>=4.11.1,<4.12.0", + "ruff>=0.7.4,<0.8.0", +] + +# Use local path sources for arcade libs when working locally +[tool.uv.sources] +arcade-ai = { path = "../../", editable = true } +arcade-serve = { path = "../../libs/arcade-serve/", editable = true } +arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } + + +[tool.mypy] +files = [ "arcade_postgres/**/*.py",] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes = "True" +ignore_missing_imports = "True" + +[tool.pytest.ini_options] +testpaths = [ "tests",] +asyncio_default_fixture_loop_scope = "function" + +[tool.coverage.report] +skip_empty = true + +[tool.hatch.build.targets.wheel] +packages = [ "arcade_postgres",] diff --git a/toolkits/postgres/tests/__init__.py b/toolkits/postgres/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/toolkits/postgres/tests/dump.sql b/toolkits/postgres/tests/dump.sql new file mode 100644 index 00000000..ef16710d --- /dev/null +++ b/toolkits/postgres/tests/dump.sql @@ -0,0 +1,114 @@ +DROP TABLE IF EXISTS "public"."messages"; +-- This script only contains the table creation statements and does not fully represent the table in the database. Do not use it as a backup. +-- Sequence and defined type +CREATE SEQUENCE IF NOT EXISTS messages_id_seq; +-- Table Definition +CREATE TABLE "public"."messages" ( + "id" int4 NOT NULL DEFAULT nextval('messages_id_seq'::regclass), + "body" text NOT NULL, + "user_id" int4 NOT NULL, + "created_at" timestamp NOT NULL DEFAULT now(), + "updated_at" timestamp NOT NULL DEFAULT now(), + PRIMARY KEY ("id") +); +DROP TABLE IF EXISTS "public"."users"; +-- This script only contains the table creation statements and does not fully represent the table in the database. Do not use it as a backup. +-- Sequence and defined type +CREATE SEQUENCE IF NOT EXISTS users_id_seq; +-- Table Definition +CREATE TABLE "public"."users" ( + "id" int4 NOT NULL DEFAULT nextval('users_id_seq'::regclass), + "name" varchar(256) NOT NULL, + "email" text NOT NULL, + "password_hash" text NOT NULL, + "created_at" timestamp NOT NULL DEFAULT now(), + "updated_at" timestamp NOT NULL DEFAULT now(), + "status" varchar, + PRIMARY KEY ("id") +); +INSERT INTO "public"."messages" ( + "id", + "body", + "user_id", + "created_at", + "updated_at" + ) +VALUES ( + 1, + 'Evan says hello', + 3, + '2025-04-10 17:21:05.504468', + '2025-04-10 17:21:05.504468' + ), + ( + 5100, + 'Hello! The current time is 2025-01-13T14:38:39.204Z', + 12, + '2025-01-13 06:38:39.210897', + '2025-01-13 06:38:39.210897' + ), + ( + 5101, + 'Hello! The current time is 2025-01-13T14:55:32.560Z', + 12, + '2025-01-13 06:55:32.56934', + '2025-01-13 06:55:32.56934' + ), + ( + 5102, + 'Hello! The current time is 2025-01-13T15:00:37.250Z', + 12, + '2025-01-13 07:00:37.261816', + '2025-01-13 07:00:37.261816' + ), + ( + 5319, + 'Hello! The current time is 2025-01-14T07:17:07.115Z', + 12, + '2025-01-13 23:17:07.123393', + '2025-01-13 23:17:07.123393' + ); +INSERT INTO "public"."users" ( + "id", + "name", + "email", + "password_hash", + "created_at", + "updated_at", + "status" + ) +VALUES ( + 1, + 'Mario', + 'mario@example.com', + '$argon2id$v=19$m=65536,t=2,p=1$tMg1Rd3IEDnp3iFKrqsF4Dsbw6/Cbf6seRB/H5bhaPg$zZj5yn4x3D3O3mDHcW2aczQNiYfAs3cw21XMEIgkF0E', + '2024-09-01 20:49:38.759432', + '2024-09-02 03:49:39.927', + 'active' + ), + ( + 3, + 'Evan', + 'evantahler@gmail.com', + '$argon2id$v=19$m=65536,t=2,p=1$CvOMK1WUd99R7kYXpiBPNYw4OQP53pYIgeMnwz92mrE$HPthId4phMoPT1TWuCRHHCr9BSQA8XoUkQuB1HZsqTY', + '2024-09-02 17:49:23.377425', + '2024-09-02 17:49:23.377425', + 'active' + ), + ( + 12, + 'Admin', + 'admin@arcade.dev', + '$argon2id$v=19$m=65536,t=2,p=1$paCAAD1HVZkncP/WvecuUO6zFXp2/8BISpgr5rXRxps$M5kBFc9JHHGNw9SXnPu2ggpJY0mFFCska7TXMrllndo', + '2024-10-13 15:01:30.792909', + '2024-10-13 15:01:30.792909', + 'inactive' + ); +ALTER TABLE "public"."messages" +ADD FOREIGN KEY ("user_id") REFERENCES "public"."users"("id"); +-- set pk to 13 +ALTER SEQUENCE users_id_seq RESTART WITH 13; +-- Indices +CREATE UNIQUE INDEX name_idx ON public.users USING btree (name); +CREATE UNIQUE INDEX email_idx ON public.users USING btree (email); +CREATE UNIQUE INDEX users_email_unique ON public.users USING btree (email); diff --git a/toolkits/postgres/tests/test_postgres.py b/toolkits/postgres/tests/test_postgres.py new file mode 100644 index 00000000..04b4acdd --- /dev/null +++ b/toolkits/postgres/tests/test_postgres.py @@ -0,0 +1,119 @@ +import os +from os import environ + +import pytest +import pytest_asyncio +from arcade_postgres.tools.postgres import ( + DatabaseEngine, + discover_schemas, + discover_tables, + execute_query, + get_table_schema, +) +from arcade_tdk import ToolContext, ToolSecretItem +from arcade_tdk.errors import RetryableToolError +from sqlalchemy import text +from sqlalchemy.ext.asyncio import create_async_engine + +DATABASE_CONNECTION_STRING = ( + environ.get("TEST_POSTGRES_DATABASE_CONNECTION_STRING") + or "postgresql://evan@localhost:5432/postgres" +) + + +@pytest.fixture +def mock_context(): + context = ToolContext() + context.secrets = [] + context.secrets.append( + ToolSecretItem(key="DATABASE_CONNECTION_STRING", value=DATABASE_CONNECTION_STRING) + ) + + return context + + +# before the tests, restore the database from the dump +@pytest_asyncio.fixture(autouse=True) +async def restore_database(): + with open(f"{os.path.dirname(__file__)}/dump.sql") as f: + engine = create_async_engine( + DATABASE_CONNECTION_STRING.replace("postgresql", "postgresql+asyncpg").split("?")[0] + ) + async with engine.connect() as c: + queries = f.read().split(";") + await c.execute(text("BEGIN")) + for query in queries: + if query.strip(): + await c.execute(text(query)) + await c.commit() + await engine.dispose() + + +@pytest_asyncio.fixture(autouse=True) +async def cleanup_engines(): + """Clean up database engines after each test to prevent connection leaks.""" + yield + # Clean up all cached engines after each test + await DatabaseEngine.cleanup() + + +@pytest.mark.asyncio +async def test_discover_schemas(mock_context) -> None: + assert await discover_schemas(mock_context) == ["public"] + + +@pytest.mark.asyncio +async def test_discover_tables(mock_context) -> None: + assert await discover_tables(mock_context) == ["users", "messages"] + + +@pytest.mark.asyncio +async def test_get_table_schema(mock_context) -> None: + assert await get_table_schema(mock_context, "public", "users") == [ + "id: int (PRIMARY KEY)", + "name: str (INDEXED)", + "email: str (INDEXED)", + "password_hash: str", + "created_at: datetime", + "updated_at: datetime", + "status: str", + ] + + assert await get_table_schema(mock_context, "public", "messages") == [ + "id: int (PRIMARY KEY)", + "body: str", + "user_id: int", + "created_at: datetime", + "updated_at: datetime", + ] + + +@pytest.mark.asyncio +async def test_execute_query(mock_context) -> None: + assert await execute_query(mock_context, "SELECT id, name, email FROM users WHERE id = 1") == [ + "(1, 'Mario', 'mario@example.com')" + ] + + +@pytest.mark.asyncio +async def test_execute_query_with_no_results(mock_context) -> None: + # does not raise an error + assert await execute_query(mock_context, "SELECT * FROM users WHERE id = 9999999999") == [] + + +@pytest.mark.asyncio +async def test_execute_query_with_problem(mock_context) -> None: + # 'foo' is not a valid id + with pytest.raises(RetryableToolError) as e: + await execute_query(mock_context, "SELECT * FROM users WHERE id = 'foo'") + assert "invalid input syntax" in str(e.value) + + +@pytest.mark.asyncio +async def test_execute_query_rejects_non_select(mock_context) -> None: + with pytest.raises(RetryableToolError) as e: + await execute_query( + mock_context, + "INSERT INTO users (name, email, password_hash) VALUES ('Luigi', 'luigi@example.com', 'password')", + ) + assert "Only SELECT queries are allowed" in str(e.value)