Postgres Database Toolkit (#459)
Adds an example of a good "general case" SQL tool: * enforces read-only mode * hints to the LLM to discover the tables and schemas for the tables it needs before any query * uses RetryableToolErrors to hint to the LLM about what to do next Docs: https://github.com/ArcadeAI/docs/pull/345 For testing, `TEST_POSTGRES_DATABASE_CONNECTION_STRING` has been set in the repo (from Neon). details in 1 password. <img width="1178" height="1091" alt="464977013-49aff5e5-e301-4ca0-83b5-3ea742db2283" src="https://github.com/user-attachments/assets/9344c27b-015d-4b91-907e-84f2e4193e16" />
This commit is contained in:
parent
1b0547090c
commit
40cdf2018d
11 changed files with 729 additions and 0 deletions
2
.github/workflows/test-toolkits.yml
vendored
2
.github/workflows/test-toolkits.yml
vendored
|
|
@ -50,6 +50,8 @@ jobs:
|
|||
|
||||
- name: Test toolkit
|
||||
working-directory: toolkits/${{ matrix.toolkit }}
|
||||
env:
|
||||
TEST_POSTGRES_DATABASE_CONNECTION_STRING: ${{ secrets.TEST_POSTGRES_DATABASE_CONNECTION_STRING }} # TODO: dynamically only load the `TEST_${{ matrix.toolkit }}_DATABASE_CONNECTION_STRING secret`
|
||||
run: |
|
||||
# Run pytest and capture exit code
|
||||
uv run --active pytest -W ignore -v --cov=arcade_${{ matrix.toolkit }} --cov-report=xml || EXIT_CODE=$?
|
||||
|
|
|
|||
53
toolkits/postgres/Makefile
Normal file
53
toolkits/postgres/Makefile
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
.PHONY: help
|
||||
|
||||
help:
|
||||
@echo "🛠️ github Commands:\n"
|
||||
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
||||
|
||||
.PHONY: install
|
||||
install: ## Install the uv environment and install all packages with dependencies
|
||||
@echo "🚀 Creating virtual environment and installing all packages using uv"
|
||||
@uv sync --active --all-extras --no-sources
|
||||
@uv run pre-commit install
|
||||
@echo "✅ All packages and dependencies installed via uv"
|
||||
|
||||
.PHONY: install-local
|
||||
install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources
|
||||
@echo "🚀 Creating virtual environment and installing all packages using uv"
|
||||
@uv sync --active --all-extras
|
||||
@uv run pre-commit install
|
||||
@echo "✅ All packages and dependencies installed via uv"
|
||||
|
||||
.PHONY: build
|
||||
build: clean-build ## Build wheel file using poetry
|
||||
@echo "🚀 Creating wheel file"
|
||||
uv build
|
||||
|
||||
.PHONY: clean-build
|
||||
clean-build: ## clean build artifacts
|
||||
@echo "🗑️ Cleaning dist directory"
|
||||
rm -rf dist
|
||||
|
||||
.PHONY: test
|
||||
test: ## Test the code with pytest
|
||||
@echo "🚀 Testing code: Running pytest"
|
||||
@uv run pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml
|
||||
|
||||
.PHONY: coverage
|
||||
coverage: ## Generate coverage report
|
||||
@echo "coverage report"
|
||||
coverage report
|
||||
@echo "Generating coverage report"
|
||||
coverage html
|
||||
|
||||
.PHONY: bump-version
|
||||
bump-version: ## Bump the version in the pyproject.toml file by a patch version
|
||||
@echo "🚀 Bumping version in pyproject.toml"
|
||||
uv version --bump patch
|
||||
|
||||
.PHONY: check
|
||||
check: ## Run code quality tools.
|
||||
@echo "🚀 Linting code: Running pre-commit"
|
||||
@uv run pre-commit run -a
|
||||
@echo "🚀 Static type checking: Running mypy"
|
||||
@uv run mypy --config-file=pyproject.toml
|
||||
0
toolkits/postgres/arcade_postgres/__init__.py
Normal file
0
toolkits/postgres/arcade_postgres/__init__.py
Normal file
104
toolkits/postgres/arcade_postgres/database_engine.py
Normal file
104
toolkits/postgres/arcade_postgres/database_engine.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
from typing import Any, ClassVar
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from arcade_tdk.errors import RetryableToolError
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
||||
|
||||
MAX_ROWS_RETURNED = 1000
|
||||
TEST_QUERY = "SELECT 1"
|
||||
|
||||
|
||||
class DatabaseEngine:
|
||||
_instance: ClassVar[None] = None
|
||||
_engines: ClassVar[dict[str, AsyncEngine]] = {}
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls, connection_string: str) -> AsyncEngine:
|
||||
parsed_url = urlparse(connection_string)
|
||||
|
||||
# TODO: something strange with sslmode= and friends
|
||||
# query_params = parse_qs(parsed_url.query)
|
||||
# query_params = {
|
||||
# k: v[0] for k, v in query_params.items()
|
||||
# } # assume one value allowed for each query param
|
||||
|
||||
async_connection_string = f"{parsed_url.scheme.replace('postgresql', 'postgresql+asyncpg')}://{parsed_url.netloc}{parsed_url.path}"
|
||||
key = f"{async_connection_string}"
|
||||
if key not in cls._engines:
|
||||
cls._engines[key] = create_async_engine(async_connection_string)
|
||||
|
||||
# try a simple query to see if the connection is valid
|
||||
try:
|
||||
async with cls._engines[key].connect() as connection:
|
||||
await connection.execute(text(TEST_QUERY))
|
||||
return cls._engines[key]
|
||||
except Exception:
|
||||
await cls._engines[key].dispose()
|
||||
|
||||
# try again
|
||||
try:
|
||||
async with cls._engines[key].connect() as connection:
|
||||
await connection.execute(text(TEST_QUERY))
|
||||
return cls._engines[key]
|
||||
except Exception as e:
|
||||
raise RetryableToolError(
|
||||
f"Connection failed: {e}",
|
||||
developer_message="Connection to postgres failed.",
|
||||
additional_prompt_content="Check the connection string and try again.",
|
||||
) from e
|
||||
|
||||
@classmethod
|
||||
async def get_engine(cls, connection_string: str) -> Any:
|
||||
engine = await cls.get_instance(connection_string)
|
||||
|
||||
class ConnectionContextManager:
|
||||
def __init__(self, engine: AsyncEngine) -> None:
|
||||
self.engine = engine
|
||||
|
||||
async def __aenter__(self) -> AsyncEngine:
|
||||
return self.engine
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
# Connection cleanup is handled by the async context manager
|
||||
pass
|
||||
|
||||
return ConnectionContextManager(engine)
|
||||
|
||||
@classmethod
|
||||
async def cleanup(cls) -> None:
|
||||
"""Clean up all cached engines. Call this when shutting down."""
|
||||
for engine in cls._engines.values():
|
||||
await engine.dispose()
|
||||
cls._engines.clear()
|
||||
|
||||
@classmethod
|
||||
def clear_cache(cls) -> None:
|
||||
"""Clear the engine cache without disposing engines. Use with caution."""
|
||||
cls._engines.clear()
|
||||
|
||||
@classmethod
|
||||
def sanitize_query(cls, query: str) -> str:
|
||||
"""
|
||||
Sanitize a query to not break our read-only session.
|
||||
THIS IS REALLY UNSAFE AND SHOULD NOT BE USED IN PRODUCTION. USE A DATABASE CONNECTION WITH A READ-ONLY USER AND PREPARE STATEMENTS.
|
||||
There are also valid reasons for the ";" character, and this prevents that.
|
||||
"""
|
||||
|
||||
parts = query.split(";")
|
||||
if len(parts) > 1:
|
||||
raise RetryableToolError(
|
||||
"Multiple statements are not allowed in a single query.",
|
||||
developer_message="Multiple statements are not allowed in a single query.",
|
||||
additional_prompt_content="Split your query into multiple queries and try again.",
|
||||
)
|
||||
|
||||
words = parts[0].split(" ")
|
||||
if words[0].upper().strip() != "SELECT":
|
||||
raise RetryableToolError(
|
||||
"Only SELECT queries are allowed.",
|
||||
developer_message="Only SELECT queries are allowed.",
|
||||
additional_prompt_content="Use the <DiscoverTables> and <GetTableSchema> tools to discover the tables and try again.",
|
||||
)
|
||||
|
||||
return f"{query}"
|
||||
0
toolkits/postgres/arcade_postgres/tools/__init__.py
Normal file
0
toolkits/postgres/arcade_postgres/tools/__init__.py
Normal file
178
toolkits/postgres/arcade_postgres/tools/postgres.py
Normal file
178
toolkits/postgres/arcade_postgres/tools/postgres.py
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
from typing import Annotated, Any
|
||||
|
||||
from arcade_tdk import ToolContext, tool
|
||||
from arcade_tdk.errors import RetryableToolError
|
||||
from sqlalchemy import inspect, text
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
|
||||
from ..database_engine import MAX_ROWS_RETURNED, DatabaseEngine
|
||||
|
||||
|
||||
@tool(requires_secrets=["DATABASE_CONNECTION_STRING"])
|
||||
async def discover_schemas(
|
||||
context: ToolContext,
|
||||
) -> list[str]:
|
||||
"""Discover all the schemas in the postgres database."""
|
||||
async with await DatabaseEngine.get_engine(
|
||||
context.get_secret("DATABASE_CONNECTION_STRING")
|
||||
) as engine:
|
||||
schemas = await _get_schemas(engine)
|
||||
return schemas
|
||||
|
||||
|
||||
@tool(requires_secrets=["DATABASE_CONNECTION_STRING"])
|
||||
async def discover_tables(
|
||||
context: ToolContext,
|
||||
schema_name: Annotated[
|
||||
str, "The database schema to discover tables in (default value: 'public')"
|
||||
] = "public",
|
||||
) -> list[str]:
|
||||
"""Discover all the tables in the postgres database when the list of tables is not known.
|
||||
|
||||
THIS TOOL SHOULD ALWAYS BE USED BEFORE ANY OTHER TOOL THAT REQUIRES A TABLE NAME.
|
||||
"""
|
||||
async with await DatabaseEngine.get_engine(
|
||||
context.get_secret("DATABASE_CONNECTION_STRING")
|
||||
) as engine:
|
||||
tables = await _get_tables(engine, schema_name)
|
||||
return tables
|
||||
|
||||
|
||||
@tool(requires_secrets=["DATABASE_CONNECTION_STRING"])
|
||||
async def get_table_schema(
|
||||
context: ToolContext,
|
||||
schema_name: Annotated[str, "The database schema to get the table schema of"],
|
||||
table_name: Annotated[str, "The table to get the schema of"],
|
||||
) -> list[str]:
|
||||
"""
|
||||
Get the schema/structure of a postgres table in the postgres database when the schema is not known, and the name of the table is provided.
|
||||
|
||||
THIS TOOL SHOULD ALWAYS BE USED BEFORE EXECUTING ANY QUERY. ALL TABLES IN THE QUERY MUST BE DISCOVERED FIRST USING THE <DiscoverTables> TOOL.
|
||||
"""
|
||||
async with await DatabaseEngine.get_engine(
|
||||
context.get_secret("DATABASE_CONNECTION_STRING")
|
||||
) as engine:
|
||||
return await _get_table_schema(engine, schema_name, table_name)
|
||||
|
||||
|
||||
@tool(requires_secrets=["DATABASE_CONNECTION_STRING"])
|
||||
async def execute_query(
|
||||
context: ToolContext,
|
||||
query: Annotated[str, "The postgres SQL query to execute. Only SELECT queries are allowed."],
|
||||
) -> list[str]:
|
||||
"""
|
||||
You have a connection to a postgres database.
|
||||
Execute a query and return the results against the postgres database.
|
||||
|
||||
ONLY USE THIS TOOL IF YOU HAVE ALREADY LOADED THE SCHEMA OF THE TABLES YOU NEED TO QUERY. USE THE <GetTableSchema> TOOL TO LOAD THE SCHEMA IF NOT ALREADY KNOWN.
|
||||
|
||||
When running queries, follow these rules which will help avoid errors:
|
||||
* Always use case-insensitive queries to match strings in the query.
|
||||
* Always trim strings in the query.
|
||||
* Prefer LIKE queries over direct string matches or regex queries.
|
||||
* Only join on columns that are indexed or the primary key. Do not join on arbitrary columns.
|
||||
|
||||
Only SELECT queries are allowed. Do not use INSERT, UPDATE, DELETE, or other DML statements. This tool will reject them.
|
||||
|
||||
Unless otherwise specified, ensure that query has a LIMIT of 100 for all results. This tool will enforce that no more than 1000 rows are returned at maximum.
|
||||
"""
|
||||
async with await DatabaseEngine.get_engine(
|
||||
context.get_secret("DATABASE_CONNECTION_STRING")
|
||||
) as engine:
|
||||
try:
|
||||
return await _execute_query(engine, query)
|
||||
except Exception as e:
|
||||
raise RetryableToolError(
|
||||
f"Query failed: {e}",
|
||||
developer_message=f"Query '{query}' failed.",
|
||||
additional_prompt_content="Load the database schema <GetTableSchema> or use the <DiscoverTables> tool to discover the tables and try again.",
|
||||
retry_after_ms=10,
|
||||
) from e
|
||||
|
||||
|
||||
async def _get_schemas(engine: AsyncEngine) -> list[str]:
|
||||
"""Get all the schemas in the database"""
|
||||
async with engine.connect() as conn:
|
||||
|
||||
def get_schema_names(sync_conn: Any) -> list[str]:
|
||||
return list(inspect(sync_conn).get_schema_names())
|
||||
|
||||
schemas: list[str] = await conn.run_sync(get_schema_names)
|
||||
schemas = [schema for schema in schemas if schema != "information_schema"]
|
||||
|
||||
return schemas
|
||||
|
||||
|
||||
async def _get_tables(engine: AsyncEngine, schema_name: str) -> list[str]:
|
||||
"""Get all the tables in the database"""
|
||||
async with engine.connect() as conn:
|
||||
|
||||
def get_schema_names(sync_conn: Any) -> list[str]:
|
||||
return list(inspect(sync_conn).get_schema_names())
|
||||
|
||||
schemas: list[str] = await conn.run_sync(get_schema_names)
|
||||
tables = []
|
||||
for schema in schemas:
|
||||
if schema == schema_name:
|
||||
|
||||
def get_table_names(sync_conn: Any, s: str = schema) -> list[str]:
|
||||
return list(inspect(sync_conn).get_table_names(schema=s))
|
||||
|
||||
these_tables = await conn.run_sync(get_table_names)
|
||||
tables.extend(these_tables)
|
||||
return tables
|
||||
|
||||
|
||||
async def _get_table_schema(engine: AsyncEngine, schema_name: str, table_name: str) -> list[str]:
|
||||
"""Get the schema of a table"""
|
||||
async with engine.connect() as connection:
|
||||
|
||||
def get_columns(sync_conn: Any, t: str = table_name, s: str = schema_name) -> list[Any]:
|
||||
return list(inspect(sync_conn).get_columns(t, s))
|
||||
|
||||
columns_table = await connection.run_sync(get_columns)
|
||||
|
||||
# Get primary key information
|
||||
pk_constraint = await connection.run_sync(
|
||||
lambda sync_conn: inspect(sync_conn).get_pk_constraint(table_name, schema_name)
|
||||
)
|
||||
primary_keys = set(pk_constraint.get("constrained_columns", []))
|
||||
|
||||
# Get index information
|
||||
indexes = await connection.run_sync(
|
||||
lambda sync_conn: inspect(sync_conn).get_indexes(table_name, schema_name)
|
||||
)
|
||||
indexed_columns = set()
|
||||
for index in indexes:
|
||||
indexed_columns.update(index.get("column_names", []))
|
||||
|
||||
results = []
|
||||
for column in columns_table:
|
||||
column_name = column["name"]
|
||||
column_type = column["type"].python_type.__name__
|
||||
|
||||
# Build column description
|
||||
description = f"{column_name}: {column_type}"
|
||||
|
||||
# Add primary key indicator
|
||||
if column_name in primary_keys:
|
||||
description += " (PRIMARY KEY)"
|
||||
|
||||
# Add index indicator
|
||||
if column_name in indexed_columns:
|
||||
description += " (INDEXED)"
|
||||
|
||||
results.append(description)
|
||||
|
||||
return results[:MAX_ROWS_RETURNED]
|
||||
|
||||
|
||||
async def _execute_query(
|
||||
engine: AsyncEngine, query: str, params: dict[str, Any] | None = None
|
||||
) -> list[str]:
|
||||
"""Execute a query and return the results."""
|
||||
async with engine.connect() as connection:
|
||||
result = await connection.execute(text(DatabaseEngine.sanitize_query(query)), params)
|
||||
rows = result.fetchall()
|
||||
results = [str(row) for row in rows]
|
||||
return results[:MAX_ROWS_RETURNED]
|
||||
94
toolkits/postgres/evals/eval_postgres.py
Normal file
94
toolkits/postgres/evals/eval_postgres.py
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
import arcade_postgres
|
||||
from arcade_evals import (
|
||||
BinaryCritic,
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
ExpectedToolCall,
|
||||
SimilarityCritic,
|
||||
tool_eval,
|
||||
)
|
||||
from arcade_postgres.tools.postgres import (
|
||||
discover_tables,
|
||||
execute_query,
|
||||
get_table_schema,
|
||||
)
|
||||
from arcade_tdk import ToolCatalog
|
||||
|
||||
# Evaluation rubric
|
||||
rubric = EvalRubric(
|
||||
fail_threshold=0.85,
|
||||
warn_threshold=0.95,
|
||||
)
|
||||
|
||||
|
||||
catalog = ToolCatalog()
|
||||
catalog.add_module(arcade_postgres)
|
||||
|
||||
|
||||
@tool_eval()
|
||||
def sql_eval_suite() -> EvalSuite:
|
||||
suite = EvalSuite(
|
||||
name="sql Tools Evaluation",
|
||||
system_message=(
|
||||
"You are an AI assistant with access to sql tools. "
|
||||
"Use them to help the user with their tasks."
|
||||
),
|
||||
catalog=catalog,
|
||||
rubric=rubric,
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get user by id (schema known)",
|
||||
user_message="Tell me the name and email of user #1 in my database. The table 'users' has the following schema: id: int, name: str, email: str, password_hash: str, created_at: datetime, updated_at: datetime",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=execute_query, args={"query": "SELECT name, email FROM users WHERE id = 1"}
|
||||
)
|
||||
],
|
||||
rubric=rubric,
|
||||
critics=[SimilarityCritic(critic_field="query", weight=1.0)],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Discover tables",
|
||||
user_message="What tables are in my database?",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(func=discover_tables, args={}),
|
||||
],
|
||||
rubric=rubric,
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get table schema (schema provided)",
|
||||
user_message="What columns are in the table 'public.users' in my database?",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_table_schema, args={"schema_name": "public", "table_name": "users"}
|
||||
),
|
||||
],
|
||||
rubric=rubric,
|
||||
critics=[
|
||||
BinaryCritic(critic_field="schema_name", weight=0.5),
|
||||
BinaryCritic(critic_field="table_name", weight=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get table schema (schema not provided)",
|
||||
user_message="What columns are in the table 'users' in my database?",
|
||||
additional_messages=[
|
||||
{"role": "user", "content": "When not provided, the schema is 'public'."}
|
||||
],
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_table_schema, args={"schema_name": "public", "table_name": "users"}
|
||||
),
|
||||
],
|
||||
rubric=rubric,
|
||||
critics=[
|
||||
BinaryCritic(critic_field="schema_name", weight=0.5),
|
||||
BinaryCritic(critic_field="table_name", weight=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
return suite
|
||||
65
toolkits/postgres/pyproject.toml
Normal file
65
toolkits/postgres/pyproject.toml
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
[build-system]
|
||||
requires = [ "hatchling",]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "arcade_postgres"
|
||||
version = "0.1.0"
|
||||
description = "Tools to query and explore a postgres database"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"arcade-tdk>=2.0.0,<3.0.0",
|
||||
"psycopg2-binary>=2.9.10",
|
||||
"pydantic>=2.11.7",
|
||||
"sqlalchemy>=2.0.41",
|
||||
"psycopg2-binary>=2.9.10",
|
||||
"asyncpg>=0.30.0",
|
||||
"greenlet>=3.2.3",
|
||||
]
|
||||
[[project.authors]]
|
||||
name = "evantahler"
|
||||
email = "support@arcade.dev"
|
||||
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"arcade-ai[evals]>=2.0.0,<3.0.0",
|
||||
"arcade-serve>=2.0.0,<3.0.0",
|
||||
"pytest>=8.3.0,<8.4.0",
|
||||
"pytest-cov>=4.0.0,<4.1.0",
|
||||
"pytest-mock>=3.11.1,<3.12.0",
|
||||
"pytest-asyncio>=0.24.0,<0.25.0",
|
||||
"mypy>=1.5.1,<1.6.0",
|
||||
"pre-commit>=3.4.0,<3.5.0",
|
||||
"tox>=4.11.1,<4.12.0",
|
||||
"ruff>=0.7.4,<0.8.0",
|
||||
]
|
||||
|
||||
# Use local path sources for arcade libs when working locally
|
||||
[tool.uv.sources]
|
||||
arcade-ai = { path = "../../", editable = true }
|
||||
arcade-serve = { path = "../../libs/arcade-serve/", editable = true }
|
||||
arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true }
|
||||
|
||||
|
||||
[tool.mypy]
|
||||
files = [ "arcade_postgres/**/*.py",]
|
||||
python_version = "3.10"
|
||||
disallow_untyped_defs = "True"
|
||||
disallow_any_unimported = "True"
|
||||
no_implicit_optional = "True"
|
||||
check_untyped_defs = "True"
|
||||
warn_return_any = "True"
|
||||
warn_unused_ignores = "True"
|
||||
show_error_codes = "True"
|
||||
ignore_missing_imports = "True"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = [ "tests",]
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
|
||||
[tool.coverage.report]
|
||||
skip_empty = true
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = [ "arcade_postgres",]
|
||||
0
toolkits/postgres/tests/__init__.py
Normal file
0
toolkits/postgres/tests/__init__.py
Normal file
114
toolkits/postgres/tests/dump.sql
Normal file
114
toolkits/postgres/tests/dump.sql
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
DROP TABLE IF EXISTS "public"."messages";
|
||||
-- This script only contains the table creation statements and does not fully represent the table in the database. Do not use it as a backup.
|
||||
-- Sequence and defined type
|
||||
CREATE SEQUENCE IF NOT EXISTS messages_id_seq;
|
||||
-- Table Definition
|
||||
CREATE TABLE "public"."messages" (
|
||||
"id" int4 NOT NULL DEFAULT nextval('messages_id_seq'::regclass),
|
||||
"body" text NOT NULL,
|
||||
"user_id" int4 NOT NULL,
|
||||
"created_at" timestamp NOT NULL DEFAULT now(),
|
||||
"updated_at" timestamp NOT NULL DEFAULT now(),
|
||||
PRIMARY KEY ("id")
|
||||
);
|
||||
DROP TABLE IF EXISTS "public"."users";
|
||||
-- This script only contains the table creation statements and does not fully represent the table in the database. Do not use it as a backup.
|
||||
-- Sequence and defined type
|
||||
CREATE SEQUENCE IF NOT EXISTS users_id_seq;
|
||||
-- Table Definition
|
||||
CREATE TABLE "public"."users" (
|
||||
"id" int4 NOT NULL DEFAULT nextval('users_id_seq'::regclass),
|
||||
"name" varchar(256) NOT NULL,
|
||||
"email" text NOT NULL,
|
||||
"password_hash" text NOT NULL,
|
||||
"created_at" timestamp NOT NULL DEFAULT now(),
|
||||
"updated_at" timestamp NOT NULL DEFAULT now(),
|
||||
"status" varchar,
|
||||
PRIMARY KEY ("id")
|
||||
);
|
||||
INSERT INTO "public"."messages" (
|
||||
"id",
|
||||
"body",
|
||||
"user_id",
|
||||
"created_at",
|
||||
"updated_at"
|
||||
)
|
||||
VALUES (
|
||||
1,
|
||||
'Evan says hello',
|
||||
3,
|
||||
'2025-04-10 17:21:05.504468',
|
||||
'2025-04-10 17:21:05.504468'
|
||||
),
|
||||
(
|
||||
5100,
|
||||
'Hello! The current time is 2025-01-13T14:38:39.204Z',
|
||||
12,
|
||||
'2025-01-13 06:38:39.210897',
|
||||
'2025-01-13 06:38:39.210897'
|
||||
),
|
||||
(
|
||||
5101,
|
||||
'Hello! The current time is 2025-01-13T14:55:32.560Z',
|
||||
12,
|
||||
'2025-01-13 06:55:32.56934',
|
||||
'2025-01-13 06:55:32.56934'
|
||||
),
|
||||
(
|
||||
5102,
|
||||
'Hello! The current time is 2025-01-13T15:00:37.250Z',
|
||||
12,
|
||||
'2025-01-13 07:00:37.261816',
|
||||
'2025-01-13 07:00:37.261816'
|
||||
),
|
||||
(
|
||||
5319,
|
||||
'Hello! The current time is 2025-01-14T07:17:07.115Z',
|
||||
12,
|
||||
'2025-01-13 23:17:07.123393',
|
||||
'2025-01-13 23:17:07.123393'
|
||||
);
|
||||
INSERT INTO "public"."users" (
|
||||
"id",
|
||||
"name",
|
||||
"email",
|
||||
"password_hash",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"status"
|
||||
)
|
||||
VALUES (
|
||||
1,
|
||||
'Mario',
|
||||
'mario@example.com',
|
||||
'$argon2id$v=19$m=65536,t=2,p=1$tMg1Rd3IEDnp3iFKrqsF4Dsbw6/Cbf6seRB/H5bhaPg$zZj5yn4x3D3O3mDHcW2aczQNiYfAs3cw21XMEIgkF0E',
|
||||
'2024-09-01 20:49:38.759432',
|
||||
'2024-09-02 03:49:39.927',
|
||||
'active'
|
||||
),
|
||||
(
|
||||
3,
|
||||
'Evan',
|
||||
'evantahler@gmail.com',
|
||||
'$argon2id$v=19$m=65536,t=2,p=1$CvOMK1WUd99R7kYXpiBPNYw4OQP53pYIgeMnwz92mrE$HPthId4phMoPT1TWuCRHHCr9BSQA8XoUkQuB1HZsqTY',
|
||||
'2024-09-02 17:49:23.377425',
|
||||
'2024-09-02 17:49:23.377425',
|
||||
'active'
|
||||
),
|
||||
(
|
||||
12,
|
||||
'Admin',
|
||||
'admin@arcade.dev',
|
||||
'$argon2id$v=19$m=65536,t=2,p=1$paCAAD1HVZkncP/WvecuUO6zFXp2/8BISpgr5rXRxps$M5kBFc9JHHGNw9SXnPu2ggpJY0mFFCska7TXMrllndo',
|
||||
'2024-10-13 15:01:30.792909',
|
||||
'2024-10-13 15:01:30.792909',
|
||||
'inactive'
|
||||
);
|
||||
ALTER TABLE "public"."messages"
|
||||
ADD FOREIGN KEY ("user_id") REFERENCES "public"."users"("id");
|
||||
-- set pk to 13
|
||||
ALTER SEQUENCE users_id_seq RESTART WITH 13;
|
||||
-- Indices
|
||||
CREATE UNIQUE INDEX name_idx ON public.users USING btree (name);
|
||||
CREATE UNIQUE INDEX email_idx ON public.users USING btree (email);
|
||||
CREATE UNIQUE INDEX users_email_unique ON public.users USING btree (email);
|
||||
119
toolkits/postgres/tests/test_postgres.py
Normal file
119
toolkits/postgres/tests/test_postgres.py
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
import os
|
||||
from os import environ
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from arcade_postgres.tools.postgres import (
|
||||
DatabaseEngine,
|
||||
discover_schemas,
|
||||
discover_tables,
|
||||
execute_query,
|
||||
get_table_schema,
|
||||
)
|
||||
from arcade_tdk import ToolContext, ToolSecretItem
|
||||
from arcade_tdk.errors import RetryableToolError
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
DATABASE_CONNECTION_STRING = (
|
||||
environ.get("TEST_POSTGRES_DATABASE_CONNECTION_STRING")
|
||||
or "postgresql://evan@localhost:5432/postgres"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context():
|
||||
context = ToolContext()
|
||||
context.secrets = []
|
||||
context.secrets.append(
|
||||
ToolSecretItem(key="DATABASE_CONNECTION_STRING", value=DATABASE_CONNECTION_STRING)
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
|
||||
# before the tests, restore the database from the dump
|
||||
@pytest_asyncio.fixture(autouse=True)
|
||||
async def restore_database():
|
||||
with open(f"{os.path.dirname(__file__)}/dump.sql") as f:
|
||||
engine = create_async_engine(
|
||||
DATABASE_CONNECTION_STRING.replace("postgresql", "postgresql+asyncpg").split("?")[0]
|
||||
)
|
||||
async with engine.connect() as c:
|
||||
queries = f.read().split(";")
|
||||
await c.execute(text("BEGIN"))
|
||||
for query in queries:
|
||||
if query.strip():
|
||||
await c.execute(text(query))
|
||||
await c.commit()
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(autouse=True)
|
||||
async def cleanup_engines():
|
||||
"""Clean up database engines after each test to prevent connection leaks."""
|
||||
yield
|
||||
# Clean up all cached engines after each test
|
||||
await DatabaseEngine.cleanup()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_schemas(mock_context) -> None:
|
||||
assert await discover_schemas(mock_context) == ["public"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_tables(mock_context) -> None:
|
||||
assert await discover_tables(mock_context) == ["users", "messages"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_schema(mock_context) -> None:
|
||||
assert await get_table_schema(mock_context, "public", "users") == [
|
||||
"id: int (PRIMARY KEY)",
|
||||
"name: str (INDEXED)",
|
||||
"email: str (INDEXED)",
|
||||
"password_hash: str",
|
||||
"created_at: datetime",
|
||||
"updated_at: datetime",
|
||||
"status: str",
|
||||
]
|
||||
|
||||
assert await get_table_schema(mock_context, "public", "messages") == [
|
||||
"id: int (PRIMARY KEY)",
|
||||
"body: str",
|
||||
"user_id: int",
|
||||
"created_at: datetime",
|
||||
"updated_at: datetime",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_query(mock_context) -> None:
|
||||
assert await execute_query(mock_context, "SELECT id, name, email FROM users WHERE id = 1") == [
|
||||
"(1, 'Mario', 'mario@example.com')"
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_query_with_no_results(mock_context) -> None:
|
||||
# does not raise an error
|
||||
assert await execute_query(mock_context, "SELECT * FROM users WHERE id = 9999999999") == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_query_with_problem(mock_context) -> None:
|
||||
# 'foo' is not a valid id
|
||||
with pytest.raises(RetryableToolError) as e:
|
||||
await execute_query(mock_context, "SELECT * FROM users WHERE id = 'foo'")
|
||||
assert "invalid input syntax" in str(e.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_query_rejects_non_select(mock_context) -> None:
|
||||
with pytest.raises(RetryableToolError) as e:
|
||||
await execute_query(
|
||||
mock_context,
|
||||
"INSERT INTO users (name, email, password_hash) VALUES ('Luigi', 'luigi@example.com', 'password')",
|
||||
)
|
||||
assert "Only SELECT queries are allowed" in str(e.value)
|
||||
Loading…
Reference in a new issue