Clickhouse Toolkit (#527)

Generated from the postgres toolkit
This commit is contained in:
Evan Tahler 2025-08-05 14:05:09 -07:00 committed by GitHub
parent 7888dc505e
commit a85fa76997
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 1326 additions and 13 deletions

116
.github/workflows/test-toolkits.yml vendored Normal file
View file

@ -0,0 +1,116 @@
name: Test Toolkits
on:
push:
branches:
- main
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
jobs:
setup:
runs-on: ubuntu-latest
outputs:
toolkits_with_gha_secrets: ${{ steps.load_toolkits.outputs.toolkits_with_gha_secrets }}
toolkits_without_gha_secrets: ${{ steps.load_toolkits.outputs.toolkits_without_gha_secrets }}
steps:
- name: Check out
uses: actions/checkout@v4
- name: determine toolkits with and without GHA secrets
id: load_toolkits
run: |
# Find all directories in toolkits/ that have a pyproject.toml
TOOLKITS=$(find toolkits -maxdepth 1 -type d -not -name "toolkits" -exec test -f {}/pyproject.toml \; -exec basename {} \; | jq -R -s -c 'split("\n")[:-1]')
TOOLKITS_WITH_GHA_SECRETS='["postgres", "clickhouse"]'
TOOLKITS_WITHOUT_GHA_SECRETS=$(echo "$TOOLKITS" | jq -c --argjson with "$TOOLKITS_WITH_GHA_SECRETS" '[.[] | select(. as $t | $with | index($t) | not)]')
echo "Found toolkits: $TOOLKITS"
echo "Found toolkits without GHA secrets: $TOOLKITS_WITHOUT_GHA_SECRETS"
echo "Found toolkits with GHA secrets: $TOOLKITS_WITH_GHA_SECRETS"
echo "toolkits_without_gha_secrets=$TOOLKITS_WITHOUT_GHA_SECRETS" >> $GITHUB_OUTPUT
echo "toolkits_with_gha_secrets=$TOOLKITS_WITH_GHA_SECRETS" >> $GITHUB_OUTPUT
test-toolkits:
needs: setup
runs-on: ubuntu-latest
strategy:
matrix:
toolkit: ${{ fromJson(needs.setup.outputs.toolkits_without_gha_secrets) }}
fail-fast: true
steps:
- name: Check out
uses: actions/checkout@v4
- name: Set up the environment
uses: ./.github/actions/setup-uv-env
- name: Install toolkit dependencies
working-directory: toolkits/${{ matrix.toolkit }}
run: uv pip install -e ".[dev]"
- name: Check toolkit
working-directory: toolkits/${{ matrix.toolkit }}
run: |
uv run --active pre-commit run -a
uv run --active mypy --config-file=pyproject.toml
- name: Test stand-alone toolkits (no secrets)
working-directory: toolkits/${{ matrix.toolkit }}
run: |
# Run pytest and capture exit code
uv run --active pytest -W ignore -v --cov=arcade_${{ matrix.toolkit }} --cov-report=xml || EXIT_CODE=$?
if [ "${EXIT_CODE:-0}" -eq 5 ]; then
echo "No tests found for toolkit ${{ matrix.toolkit }}, skipping..."
exit 0
elif [ "${EXIT_CODE:-0}" -ne 0 ]; then
exit ${EXIT_CODE}
fi
test-toolkits-with-gha-secrets:
needs: setup
runs-on: ubuntu-latest
strategy:
matrix:
toolkit: ${{ fromJson(needs.setup.outputs.toolkits_with_gha_secrets) }}
fail-fast: true
steps:
- name: Check out
uses: actions/checkout@v4
- name: Set up the environment
uses: ./.github/actions/setup-uv-env
- name: Install toolkit dependencies
working-directory: toolkits/${{ matrix.toolkit }}
run: uv pip install -e ".[dev]"
- name: Check toolkit
working-directory: toolkits/${{ matrix.toolkit }}
run: |
uv run --active pre-commit run -a
uv run --active mypy --config-file=pyproject.toml
- name: Test stand-alone toolkits (with secrets)
if: |
!github.event.pull_request.head.repo.fork
working-directory: toolkits/${{ matrix.toolkit }}
env:
TEST_POSTGRES_DATABASE_CONNECTION_STRING: ${{ secrets.TEST_POSTGRES_DATABASE_CONNECTION_STRING }} # TODO: dynamically only load the `TEST_${{ matrix.toolkit }}_DATABASE_CONNECTION_STRING secret`
TEST_CLICKHOUSE_DATABASE_CONNECTION_STRING: ${{ secrets.TEST_CLICKHOUSE_DATABASE_CONNECTION_STRING }}
run: |
# If there's a custom test_setup.sh file, run it
if [ -f tests/test_setup.sh ]; then
echo "Running custom test setup for ${{ matrix.toolkit }}..."
./tests/test_setup.sh
fi
# Run pytest and capture exit code
uv run --active pytest -W ignore -v --cov=arcade_${{ matrix.toolkit }} --cov-report=xml || EXIT_CODE=$?
if [ "${EXIT_CODE:-0}" -eq 5 ]; then
echo "No tests found for toolkit ${{ matrix.toolkit }}, skipping..."
exit 0
elif [ "${EXIT_CODE:-0}" -ne 0 ]; then
exit ${EXIT_CODE}
fi

View file

@ -0,0 +1,53 @@
.PHONY: help
help:
@echo "🛠️ clickhouse Commands:\n"
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
.PHONY: install
install: ## Install the uv environment and install all packages with dependencies
@echo "🚀 Creating virtual environment and installing all packages using uv"
@uv sync --active --all-extras --no-sources
@uv run pre-commit install
@echo "✅ All packages and dependencies installed via uv"
.PHONY: install-local
install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources
@echo "🚀 Creating virtual environment and installing all packages using uv"
@uv sync --active --all-extras
@uv run pre-commit install
@echo "✅ All packages and dependencies installed via uv"
.PHONY: build
build: clean-build ## Build wheel file using poetry
@echo "🚀 Creating wheel file"
uv build
.PHONY: clean-build
clean-build: ## clean build artifacts
@echo "🗑️ Cleaning dist directory"
rm -rf dist
.PHONY: test
test: ## Test the code with pytest
@echo "🚀 Testing code: Running pytest"
@uv run pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml
.PHONY: coverage
coverage: ## Generate coverage report
@echo "coverage report"
coverage report
@echo "Generating coverage report"
coverage html
.PHONY: bump-version
bump-version: ## Bump the version in the pyproject.toml file by a patch version
@echo "🚀 Bumping version in pyproject.toml"
uv version --bump patch
.PHONY: check
check: ## Run code quality tools.
@echo "🚀 Linting code: Running pre-commit"
@uv run pre-commit run -a
@echo "🚀 Static type checking: Running mypy"
@uv run mypy --config-file=pyproject.toml

View file

@ -0,0 +1,209 @@
import contextlib
from typing import Any, ClassVar
from urllib.parse import urlparse
import clickhouse_connect
from arcade_tdk.errors import RetryableToolError
MAX_ROWS_RETURNED = 1000
TEST_QUERY = "SELECT 1"
class DatabaseEngine:
_instance: ClassVar[None] = None
_clients: ClassVar[dict[str, Any]] = {}
@classmethod
async def get_instance(cls, connection_string: str) -> Any:
parsed_url = urlparse(connection_string)
# Extract connection parameters from the URL
host = parsed_url.hostname or "localhost"
port = parsed_url.port
database = parsed_url.path.lstrip("/") or "default"
username = parsed_url.username
password = parsed_url.password
# Handle different ClickHouse protocols
# clickhouse-connect only supports HTTP and HTTPS interfaces
if parsed_url.scheme in ["clickhouse+native"]:
# Convert native protocol to HTTP for clickhouse-connect compatibility
# Convert native port 9000 to HTTP port 8123
port = 8123 if port == 9000 else port or 8123
interface = "http"
elif parsed_url.scheme in ["clickhouse+https"]:
# For HTTPS protocol
port = port or 8443
interface = "https"
else:
# For HTTP or unspecified, use port 8123 by default
port = port or 8123
interface = "http"
key = f"{interface}://{host}:{port}/{database}"
if key not in cls._clients:
try:
# Create ClickHouse client
client_args: dict[str, Any] = {
"host": host,
"port": port,
"database": database,
"interface": interface,
}
if username:
client_args["username"] = username
if password:
client_args["password"] = password
client = clickhouse_connect.get_client(**client_args)
cls._clients[key] = client
# Test the connection
client.command(TEST_QUERY)
except Exception as e:
# Remove failed client from cache
cls._clients.pop(key, None)
raise RetryableToolError(
f"Connection failed: {e}",
developer_message="Connection to ClickHouse failed.",
additional_prompt_content="Check the connection string and try again.",
) from e
return cls._clients[key]
@classmethod
async def get_engine(cls, connection_string: str) -> Any:
client = await cls.get_instance(connection_string)
class ConnectionContextManager:
def __init__(self, client: Any) -> None:
self.client = client
async def __aenter__(self) -> Any:
return self.client
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
# Connection cleanup is handled by clickhouse-connect
pass
return ConnectionContextManager(client)
@classmethod
async def cleanup(cls) -> None:
"""Clean up all cached clients. Call this when shutting down."""
for client in cls._clients.values():
with contextlib.suppress(Exception):
client.close()
cls._clients.clear()
@classmethod
def clear_cache(cls) -> None:
"""Clear the client cache without disposing clients. Use with caution."""
cls._clients.clear()
@classmethod
def sanitize_query( # noqa: C901
cls,
select_clause: str,
from_clause: str,
limit: int,
offset: int,
join_clause: str | None,
where_clause: str | None,
having_clause: str | None,
group_by_clause: str | None,
order_by_clause: str | None,
with_clause: str | None,
) -> tuple[str, dict[str, Any]]:
# Remove the leading keywords from the clauses if they are present
if select_clause.strip().split(" ")[0].upper() == "SELECT":
select_clause = select_clause.strip()[6:]
if from_clause.strip().split(" ")[0].upper() == "FROM":
from_clause = from_clause.strip()[4:]
if join_clause and join_clause.strip().split(" ")[0].upper() == "JOIN":
join_clause = join_clause.strip()[4:]
if where_clause and where_clause.strip().split(" ")[0].upper() == "WHERE":
where_clause = where_clause.strip()[5:]
if group_by_clause and group_by_clause.strip().split(" ")[0].upper() == "GROUP BY":
group_by_clause = group_by_clause.strip()[8:]
if order_by_clause and order_by_clause.strip().split(" ")[0].upper() == "ORDER BY":
order_by_clause = order_by_clause.strip()[8:]
if having_clause and having_clause.strip().split(" ")[0].upper() == "HAVING":
having_clause = having_clause.strip()[6:]
first_select_word = select_clause.strip().split(" ")[0].upper()
if first_select_word in [
"INSERT",
"UPDATE",
"DELETE",
"CREATE",
"ALTER",
"DROP",
"TRUNCATE",
"REINDEX",
"VACUUM",
"ANALYZE",
"COMMENT",
"OPTIMIZE", # ClickHouse-specific
"SYSTEM", # ClickHouse-specific
]:
raise RetryableToolError(
"Only SELECT queries are allowed.",
)
if select_clause.strip() == "*":
raise RetryableToolError(
"Do not use * in the select clause. Use a comma separated list of columns you wish to return.",
)
if limit > MAX_ROWS_RETURNED:
raise RetryableToolError(
f"Limit is too high. Maximum is {MAX_ROWS_RETURNED}.",
)
if offset < 0:
raise RetryableToolError(
"Offset must be greater than or equal to 0.",
developer_message="Offset must be greater than or equal to 0.",
)
if limit <= 0:
raise RetryableToolError(
"Limit must be greater than 0.",
developer_message="Limit must be greater than 0.",
)
# Build query with identifiers directly interpolated, but use parameters for values
parts = []
if with_clause:
parts.append(f"WITH {with_clause}")
parts.append(f"SELECT {select_clause} FROM {from_clause}") # noqa: S608
if join_clause:
parts.append(f"JOIN {join_clause}")
if where_clause:
parts.append(f"WHERE {where_clause}")
if group_by_clause:
parts.append(f"GROUP BY {group_by_clause}")
if having_clause:
parts.append(f"HAVING {having_clause}")
if order_by_clause:
parts.append(f"ORDER BY {order_by_clause}")
parts.append("LIMIT :limit OFFSET :offset")
query = " ".join(parts)
# Only use parameters for values, not identifiers
parameters = {
"limit": limit,
"offset": offset,
}
return query, parameters

View file

@ -0,0 +1,291 @@
from typing import Annotated, Any
from arcade_tdk import ToolContext, tool
from arcade_tdk.errors import RetryableToolError
from ..database_engine import MAX_ROWS_RETURNED, DatabaseEngine
@tool(requires_secrets=["CLICKHOUSE_DATABASE_CONNECTION_STRING"])
async def discover_schemas(
context: ToolContext,
) -> list[str]:
"""Discover all the schemas in the ClickHouse database.
Note: ClickHouse doesn't have schemas like PostgreSQL, so this returns a default schema name.
"""
# ClickHouse doesn't have schemas like PostgreSQL, but we return a default for compatibility
return ["default"]
@tool(requires_secrets=["CLICKHOUSE_DATABASE_CONNECTION_STRING"])
async def discover_databases(
context: ToolContext,
) -> list[str]:
"""Discover all the databases in the ClickHouse database."""
async with await DatabaseEngine.get_engine(
context.get_secret("CLICKHOUSE_DATABASE_CONNECTION_STRING")
) as client:
databases = await _get_databases(client)
return databases
@tool(requires_secrets=["CLICKHOUSE_DATABASE_CONNECTION_STRING"])
async def discover_tables(
context: ToolContext,
) -> list[str]:
"""Discover all the tables in the ClickHouse database when the list of tables is not known.
ALWAYS use this tool before any other tool that requires a table name.
"""
async with await DatabaseEngine.get_engine(
context.get_secret("CLICKHOUSE_DATABASE_CONNECTION_STRING")
) as client:
tables = await _get_tables(client, "default")
return tables
@tool(requires_secrets=["CLICKHOUSE_DATABASE_CONNECTION_STRING"])
async def get_table_schema(
context: ToolContext,
schema_name: Annotated[str, "The schema to get the table schema of"],
table_name: Annotated[str, "The table to get the schema of"],
) -> list[str]:
"""
Get the schema/structure of a ClickHouse table in the ClickHouse database when the schema is not known, and the name of the table is provided.
This tool should ALWAYS be used before executing any query. All tables in the query must be discovered first using the <DiscoverTables> tool.
"""
async with await DatabaseEngine.get_engine(
context.get_secret("CLICKHOUSE_DATABASE_CONNECTION_STRING")
) as client:
return await _get_table_schema(client, "default", table_name)
@tool(requires_secrets=["CLICKHOUSE_DATABASE_CONNECTION_STRING"])
async def execute_select_query(
context: ToolContext,
select_clause: Annotated[
str,
"This is the part of the SQL query that comes after the SELECT keyword wish a comma separated list of columns you wish to return. Do not include the SELECT keyword.",
],
from_clause: Annotated[
str,
"This is the part of the SQL query that comes after the FROM keyword. Do not include the FROM keyword.",
],
limit: Annotated[
int,
"The maximum number of rows to return. This is the LIMIT clause of the query. Default: 100.",
] = 100,
offset: Annotated[
int, "The number of rows to skip. This is the OFFSET clause of the query. Default: 0."
] = 0,
join_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the JOIN keyword. Do not include the JOIN keyword. If no join is needed, leave this blank.",
] = None,
where_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the WHERE keyword. Do not include the WHERE keyword. If no where clause is needed, leave this blank.",
] = None,
having_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the HAVING keyword. Do not include the HAVING keyword. If no having clause is needed, leave this blank.",
] = None,
group_by_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the GROUP BY keyword. Do not include the GROUP BY keyword. If no group by clause is needed, leave this blank.",
] = None,
order_by_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the ORDER BY keyword. Do not include the ORDER BY keyword. If no order by clause is needed, leave this blank.",
] = None,
with_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the WITH keyword when basing the query on a virtual table. If no WITH clause is needed, leave this blank.",
] = None,
) -> list[str]:
"""
You have a connection to a ClickHouse database.
Execute a SELECT query and return the results against the ClickHouse database. No other queries (INSERT, UPDATE, DELETE, etc.) are allowed.
ONLY use this tool if you have already loaded the schema of the tables you need to query. Use the <GetTableSchema> tool to load the schema if not already known.
The final query will be constructed as follows:
SELECT {select_query_part} FROM {from_clause} JOIN {join_clause} WHERE {where_clause} HAVING {having_clause} ORDER BY {order_by_clause} LIMIT {limit} OFFSET {offset}
When running queries, follow these rules which will help avoid errors:
* Never "select *" from a table. Always select the columns you need.
* Always order your results by the most important columns first. If you aren't sure, order by the primary key.
* Always use case-insensitive queries to match strings in the query.
* Always trim strings in the query.
* Prefer LIKE queries over direct string matches or regex queries.
* Only join on columns that are indexed or the primary key. Do not join on arbitrary columns.
* ClickHouse is case-sensitive, so be careful with table and column names.
"""
async with await DatabaseEngine.get_engine(
context.get_secret("CLICKHOUSE_DATABASE_CONNECTION_STRING")
) as client:
try:
return await _execute_query(
client,
select_clause=select_clause,
from_clause=from_clause,
limit=limit,
offset=offset,
join_clause=join_clause,
where_clause=where_clause,
having_clause=having_clause,
group_by_clause=group_by_clause,
order_by_clause=order_by_clause,
with_clause=with_clause,
)
except Exception as e:
raise RetryableToolError(
f"Query failed: {e}",
developer_message=f"Query failed with parameters: select_clause={select_clause}, from_clause={from_clause}, limit={limit}, offset={offset}, join_clause={join_clause}, where_clause={where_clause}, having_clause={having_clause}, order_by_clause={order_by_clause}, with_clause={with_clause}.",
additional_prompt_content="Load the database schema <GetTableSchema> or use the <DiscoverTables> tool to discover the tables and try again.",
retry_after_ms=10,
) from e
async def _get_databases(client: Any) -> list[str]:
"""Get all the databases in ClickHouse"""
# ClickHouse uses SHOW DATABASES instead of information_schema
result = client.query("SHOW DATABASES")
databases = [row[0] for row in result.result_rows]
# Filter out system databases
system_databases = {
"system",
"information_schema",
"INFORMATION_SCHEMA",
"default",
"temporary_tables",
"temporary_tables_metadata",
}
databases = [db for db in databases if db not in system_databases]
databases.sort()
return databases
async def _get_tables(client: Any, database_name: str) -> list[str]:
"""Get all the tables in the specified ClickHouse database"""
# ClickHouse uses SHOW TABLES FROM database_name
result = client.query(f"SHOW TABLES FROM {database_name}")
tables = [row[0] for row in result.result_rows]
tables.sort()
return tables
async def _get_table_schema(client: Any, database_name: str, table_name: str) -> list[str]:
"""Get the schema of a ClickHouse table"""
# ClickHouse uses DESCRIBE TABLE database_name.table_name
result = client.query(f"DESCRIBE TABLE {database_name}.{table_name}")
columns = result.result_rows
# Get primary key information
# ClickHouse doesn't have traditional primary keys like PostgreSQL
# Instead, it has sorting keys and primary keys that are part of the table engine
try:
pk_result = client.query(f"SHOW CREATE TABLE {database_name}.{table_name}")
if pk_result.result_rows:
create_statement = pk_result.result_rows[0][0]
# Parse the CREATE statement to extract primary key information
primary_keys = _extract_primary_keys_from_create_statement(create_statement)
else:
primary_keys = set()
except Exception:
primary_keys = set()
results = []
for column in columns:
column_name = column[
0
] # ClickHouse DESCRIBE returns: name, type, default_type, default_expression, comment, codec_expression, ttl_expression
column_type = column[1]
# Build column description
description = f"{column_name}: {column_type}"
# Add primary key indicator
if column_name in primary_keys:
description += " (PRIMARY KEY)"
# Add default value if present
if len(column) > 3 and column[3]: # default_expression
description += f" DEFAULT {column[3]}"
# Add comment if present
if len(column) > 4 and column[4]: # comment
description += f" COMMENT '{column[4]}'"
results.append(description)
return results[:MAX_ROWS_RETURNED]
def _extract_primary_keys_from_create_statement(create_statement: str) -> set[str]:
"""Extract primary key columns from ClickHouse CREATE TABLE statement"""
primary_keys = set()
# Look for PRIMARY KEY clause
import re
pk_match = re.search(r"PRIMARY KEY\s*\(([^)]+)\)", create_statement, re.IGNORECASE)
if pk_match:
pk_columns = pk_match.group(1).split(",")
for col in pk_columns:
primary_keys.add(col.strip().strip("`"))
# Look for ORDER BY clause (which can also indicate primary key)
order_match = re.search(r"ORDER BY\s*\(([^)]+)\)", create_statement, re.IGNORECASE)
if order_match:
order_columns = order_match.group(1).split(",")
for col in order_columns:
primary_keys.add(col.strip().strip("`"))
return primary_keys
async def _execute_query(
client: Any,
select_clause: str,
from_clause: str,
limit: int,
offset: int,
join_clause: str | None,
where_clause: str | None,
having_clause: str | None,
group_by_clause: str | None,
order_by_clause: str | None,
with_clause: str | None,
) -> list[str]:
"""Execute a query and return the results."""
query, parameters = DatabaseEngine.sanitize_query(
select_clause=select_clause,
from_clause=from_clause,
limit=limit,
offset=offset,
join_clause=join_clause,
where_clause=where_clause,
having_clause=having_clause,
group_by_clause=group_by_clause,
order_by_clause=order_by_clause,
with_clause=with_clause,
)
print(f"Query: {query}")
print(f"Parameters: {parameters}")
# For clickhouse-connect, we need to substitute parameters manually
# since it doesn't use SQLAlchemy-style parameter binding
formatted_query = query
for param_name, param_value in parameters.items():
formatted_query = formatted_query.replace(f":{param_name}", str(param_value))
result = client.query(formatted_query)
rows = result.result_rows
results = [str(row) for row in rows]
return results[:MAX_ROWS_RETURNED]

View file

@ -0,0 +1,66 @@
[build-system]
requires = [ "hatchling",]
build-backend = "hatchling.build"
[project]
name = "arcade_clickhouse"
version = "0.1.0"
description = "Tools to query and explore a ClickHouse database"
requires-python = ">=3.10"
dependencies = [
"arcade-tdk>=2.0.0,<3.0.0",
"clickhouse-connect>=0.7.0",
"pydantic>=2.11.7",
"sqlalchemy>=2.0.41",
"clickhouse-sqlalchemy>=0.2.0",
"greenlet>=3.2.3",
"aiochsa>=0.1.0",
"setuptools>=80.9.0",
]
[[project.authors]]
name = "evantahler"
email = "support@arcade.dev"
[project.optional-dependencies]
dev = [
"arcade-ai[evals]>=2.0.0,<3.0.0",
"arcade-serve>=2.0.0,<3.0.0",
"pytest>=8.3.0,<8.4.0",
"pytest-cov>=4.0.0,<4.1.0",
"pytest-mock>=3.11.1,<3.12.0",
"pytest-asyncio>=0.24.0,<0.25.0",
"mypy>=1.5.1,<1.6.0",
"pre-commit>=3.4.0,<3.5.0",
"tox>=4.11.1,<4.12.0",
"ruff>=0.7.4,<0.8.0",
]
# Use local path sources for arcade libs when working locally
[tool.uv.sources]
arcade-ai = { path = "../../", editable = true }
arcade-serve = { path = "../../libs/arcade-serve/", editable = true }
arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true }
[tool.mypy]
files = [ "arcade_clickhouse/**/*.py",]
python_version = "3.10"
disallow_untyped_defs = "True"
disallow_any_unimported = "True"
no_implicit_optional = "True"
check_untyped_defs = "True"
warn_return_any = "True"
warn_unused_ignores = "True"
show_error_codes = "True"
ignore_missing_imports = "True"
[tool.pytest.ini_options]
testpaths = [ "tests",]
asyncio_default_fixture_loop_scope = "function"
[tool.coverage.report]
skip_empty = true
[tool.hatch.build.targets.wheel]
packages = [ "arcade_clickhouse",]

View file

View file

@ -0,0 +1,369 @@
-- ClickHouse test database setup
-- This file contains sample data for testing the ClickHouse toolkit
-- Create users table
CREATE TABLE IF NOT EXISTS default.users (
id UInt32,
name String,
email String,
password_hash String,
created_at DateTime,
updated_at DateTime,
status String
) ENGINE = MergeTree()
ORDER BY (id, created_at);
-- Create messages table
CREATE TABLE IF NOT EXISTS default.messages (
id UInt32,
body String,
user_id UInt32,
created_at DateTime,
updated_at DateTime
) ENGINE = MergeTree()
ORDER BY (id, created_at);
-- Insert sample data into users table
INSERT INTO default.users (
id,
name,
email,
password_hash,
created_at,
updated_at,
status
)
VALUES (
1,
'Alice',
'alice@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$tMg1Rd3IEDnp3iFKrqsF4Dsbw6/Cbf6seRB/H5bhaPg$zZj5yn4x3D3O3mDHcW2aczQNiYfAs3cw21XMEIgkF0E',
'2024-09-01 20:49:38',
'2024-09-02 03:49:39',
'active'
),
(
2,
'Bob',
'bob@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$CvOMK1WUd99R7kYXpiBPNYw4OQP53pYIgeMnwz92mrE$HPthId4phMoPT1TWuCRHHCr9BSQA8XoUkQuB1HZsqTY',
'2024-09-02 17:49:23',
'2024-09-02 17:49:23',
'active'
),
(
3,
'Charlie',
'charlie@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$paCAAD1HVZkncP/WvecuUO6zFXp2/8BISpgr5rXRxps$M5kBFc9JHHGNw9SXnPu2ggpJY0mFFCska7TXMrllndo',
'2024-09-03 10:30:15',
'2024-09-03 10:30:15',
'active'
),
(
4,
'Diana',
'diana@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$xyz123ABC456DEF789GHI$SampleHashForDiana123',
'2024-09-04 14:20:30',
'2024-09-04 14:20:30',
'active'
),
(
5,
'Evan',
'evan@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$evanHash123$EvanPasswordHash456',
'2024-09-05 09:15:45',
'2024-09-05 09:15:45',
'active'
),
(
6,
'Fiona',
'fiona@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$fionaHash456$FionaPasswordHash789',
'2024-09-06 16:45:12',
'2024-09-06 16:45:12',
'active'
),
(
7,
'George',
'george@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$georgeHash789$GeorgePasswordHash012',
'2024-09-07 11:30:25',
'2024-09-07 11:30:25',
'active'
),
(
8,
'Helen',
'helen@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$helenHash012$HelenPasswordHash345',
'2024-09-08 13:25:40',
'2024-09-08 13:25:40',
'active'
),
(
9,
'Ian',
'ian@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$ianHash345$IanPasswordHash678',
'2024-09-09 08:40:55',
'2024-09-09 08:40:55',
'active'
),
(
10,
'Julia',
'julia@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$juliaHash678$JuliaPasswordHash901',
'2024-09-10 15:55:18',
'2024-09-10 15:55:18',
'active'
);
-- Insert sample data into messages table
INSERT INTO default.messages (id, body, user_id, created_at, updated_at)
VALUES (
1,
'Hello everyone!',
1,
'2025-01-10 10:00:00',
'2025-01-10 10:00:00'
),
(
2,
'How is everyone doing today?',
1,
'2025-01-10 11:30:00',
'2025-01-10 11:30:00'
),
(
3,
'Great to see you all here!',
1,
'2025-01-10 14:15:00',
'2025-01-10 14:15:00'
),
(
4,
'Hi Alice! Doing well, thanks for asking.',
2,
'2025-01-10 11:35:00',
'2025-01-10 11:35:00'
),
(
5,
'Anyone up for a game later?',
2,
'2025-01-10 16:20:00',
'2025-01-10 16:20:00'
),
(
6,
'Count me in for the game!',
3,
'2025-01-10 16:25:00',
'2025-01-10 16:25:00'
),
(
7,
'What time works for everyone?',
3,
'2025-01-10 16:30:00',
'2025-01-10 16:30:00'
),
(
8,
'I can play around 8 PM',
3,
'2025-01-10 17:00:00',
'2025-01-10 17:00:00'
),
(
9,
'8 PM works for me too!',
4,
'2025-01-10 17:05:00',
'2025-01-10 17:05:00'
),
(
10,
'What game should we play?',
4,
'2025-01-10 17:10:00',
'2025-01-10 17:10:00'
),
(
11,
'I suggest we try the new arcade game!',
5,
'2025-01-10 17:15:00',
'2025-01-10 17:15:00'
),
(
12,
'It has great multiplayer features',
5,
'2025-01-10 17:20:00',
'2025-01-10 17:20:00'
),
(
13,
'Perfect timing for a weekend session',
5,
'2025-01-10 18:00:00',
'2025-01-10 18:00:00'
),
(
26,
'Just finished setting up the game server!',
5,
'2025-01-10 20:00:00',
'2025-01-10 20:00:00'
),
(
27,
'Everyone should be able to connect now',
5,
'2025-01-10 20:05:00',
'2025-01-10 20:05:00'
),
(
28,
'I added some custom maps too',
5,
'2025-01-10 20:10:00',
'2025-01-10 20:10:00'
),
(
29,
'The graphics look amazing on this new version',
5,
'2025-01-10 20:15:00',
'2025-01-10 20:15:00'
),
(
30,
'Hope you all enjoy the new features',
5,
'2025-01-10 20:20:00',
'2025-01-10 20:20:00'
),
(
31,
'I also set up a leaderboard system',
5,
'2025-01-10 20:25:00',
'2025-01-10 20:25:00'
),
(
32,
'We can track high scores now',
5,
'2025-01-10 20:30:00',
'2025-01-10 20:30:00'
),
(
33,
'The game supports up to 8 players simultaneously',
5,
'2025-01-10 20:35:00',
'2025-01-10 20:35:00'
),
(
34,
'I tested it earlier and it runs smoothly',
5,
'2025-01-10 20:40:00',
'2025-01-10 20:40:00'
),
(
35,
'Cannot wait to see everyone online tonight!',
5,
'2025-01-10 20:45:00',
'2025-01-10 20:45:00'
),
(
14,
'Sounds like fun! I love arcade games.',
6,
'2025-01-10 18:05:00',
'2025-01-10 18:05:00'
),
(
15,
'Should I bring snacks?',
6,
'2025-01-10 18:10:00',
'2025-01-10 18:10:00'
),
(
16,
'Snacks are always welcome!',
7,
'2025-01-10 18:15:00',
'2025-01-10 18:15:00'
),
(
17,
'I can bring some drinks',
7,
'2025-01-10 18:20:00',
'2025-01-10 18:20:00'
),
(
18,
'This is going to be awesome',
7,
'2025-01-10 19:00:00',
'2025-01-10 19:00:00'
),
(
19,
'I agree! Cannot wait for the game night.',
8,
'2025-01-10 19:05:00',
'2025-01-10 19:05:00'
),
(
20,
'Should we set up a Discord call?',
8,
'2025-01-10 19:10:00',
'2025-01-10 19:10:00'
),
(
21,
'Discord would be perfect for voice chat',
9,
'2025-01-10 19:15:00',
'2025-01-10 19:15:00'
),
(
22,
'I will create a server for us',
9,
'2025-01-10 19:20:00',
'2025-01-10 19:20:00'
),
(
23,
'Link will be shared in a few minutes',
9,
'2025-01-10 19:25:00',
'2025-01-10 19:25:00'
),
(
24,
'Thanks Ian! You are the best.',
10,
'2025-01-10 19:30:00',
'2025-01-10 19:30:00'
),
(
25,
'See you all at 8 PM!',
10,
'2025-01-10 19:35:00',
'2025-01-10 19:35:00'
);

View file

@ -0,0 +1,200 @@
import os
from os import environ
import pytest
import pytest_asyncio
from arcade_clickhouse.tools.clickhouse import (
DatabaseEngine,
discover_schemas,
discover_tables,
execute_select_query,
get_table_schema,
)
from arcade_tdk import ToolContext, ToolSecretItem
from arcade_tdk.errors import RetryableToolError
CLICKHOUSE_DATABASE_CONNECTION_STRING = (
environ.get("TEST_CLICKHOUSE_DATABASE_CONNECTION_STRING")
or "clickhouse+native://localhost:9000/default"
)
@pytest.fixture
def mock_context():
context = ToolContext()
context.secrets = []
context.secrets.append(
ToolSecretItem(
key="CLICKHOUSE_DATABASE_CONNECTION_STRING", value=CLICKHOUSE_DATABASE_CONNECTION_STRING
)
)
return context
# before the tests, restore the database from the dump
@pytest_asyncio.fixture(autouse=True)
async def restore_database():
import clickhouse_connect
# Create client for database setup
client = clickhouse_connect.get_client(host="localhost", port=8123)
# Clear existing tables first to avoid duplicates
client.command("DROP TABLE IF EXISTS default.messages")
client.command("DROP TABLE IF EXISTS default.users")
# Read and execute the dump file
with open(f"{os.path.dirname(__file__)}/dump.sql") as f:
queries = f.read().split(";")
for query in queries:
if query.strip():
client.command(query)
client.close()
@pytest_asyncio.fixture(autouse=True)
async def cleanup_engines():
"""Clean up database engines after each test to prevent connection leaks."""
yield
# Clean up all cached engines after each test
await DatabaseEngine.cleanup()
@pytest.mark.asyncio
async def test_discover_schemas(mock_context) -> None:
assert await discover_schemas(mock_context) == ["default"]
@pytest.mark.asyncio
async def test_discover_tables(mock_context) -> None:
tables = await discover_tables(mock_context)
assert sorted(tables) == ["messages", "users"]
@pytest.mark.asyncio
async def test_get_table_schema(mock_context) -> None:
users_schema = await get_table_schema(mock_context, "default", "users")
expected_users = [
"id: UInt32 (PRIMARY KEY)",
"name: String",
"email: String",
"password_hash: String",
"created_at: DateTime (PRIMARY KEY)",
"updated_at: DateTime",
"status: String",
]
assert users_schema == expected_users
messages_schema = await get_table_schema(mock_context, "default", "messages")
expected_messages = [
"id: UInt32 (PRIMARY KEY)",
"body: String",
"user_id: UInt32",
"created_at: DateTime (PRIMARY KEY)",
"updated_at: DateTime",
]
assert messages_schema == expected_messages
@pytest.mark.asyncio
async def test_execute_select_query(mock_context) -> None:
# Test specific user query with limit
result1 = await execute_select_query(
mock_context,
select_clause="id, name, email",
from_clause="users",
where_clause="id = 1",
limit=1,
)
assert result1 == ["(1, 'Alice', 'alice@example.com')"]
# Test query with offset
result2 = await execute_select_query(
mock_context,
select_clause="id, name, email",
from_clause="users",
order_by_clause="id",
limit=1,
offset=1,
)
assert result2 == ["(2, 'Bob', 'bob@example.com')"]
@pytest.mark.asyncio
async def test_execute_select_query_with_keywords(mock_context) -> None:
result = await execute_select_query(
mock_context,
select_clause="SELECT id, name, email",
from_clause="FROM users",
limit=1,
)
assert result == ["(1, 'Alice', 'alice@example.com')"]
@pytest.mark.asyncio
async def test_execute_select_query_with_join(mock_context) -> None:
result = await execute_select_query(
mock_context,
select_clause="u.id, u.name, u.email, m.id, m.body",
from_clause="users u",
join_clause="messages m ON u.id = m.user_id",
limit=1,
)
assert result == ["(1, 'Alice', 'alice@example.com', 1, 'Hello everyone!')"]
@pytest.mark.asyncio
async def test_execute_select_query_with_group_by(mock_context) -> None:
result = await execute_select_query(
mock_context,
select_clause="u.name, COUNT(m.id) AS message_count",
from_clause="messages m",
join_clause="users u ON m.user_id = u.id",
group_by_clause="u.name",
order_by_clause="message_count DESC",
limit=2,
)
assert result == [
"('Evan', 13)",
"('Alice', 3)",
]
@pytest.mark.asyncio
async def test_execute_select_query_with_no_results(mock_context) -> None:
# does not raise an error
assert (
await execute_select_query(
mock_context,
select_clause="id, name, email",
from_clause="users",
where_clause="id = 9999999999",
)
== []
)
@pytest.mark.asyncio
async def test_execute_select_query_with_problem(mock_context) -> None:
# 'foo' is not a valid id
with pytest.raises(RetryableToolError) as e:
await execute_select_query(
mock_context,
select_clause="*",
from_clause="users",
where_clause="id = 'foo'",
)
assert "Do not use * in the select clause" in str(e.value)
@pytest.mark.asyncio
async def test_execute_select_query_rejects_non_select(mock_context) -> None:
with pytest.raises(RetryableToolError) as e:
await execute_select_query(
mock_context,
select_clause="INSERT INTO users (name, email, password_hash) VALUES ('Luigi', 'luigi@example.com', 'password')",
from_clause="users",
)
assert "Only SELECT queries are allowed" in str(e.value)

View file

@ -0,0 +1,3 @@
#!/bin/bash
docker run -d --name some-clickhouse-server --ulimit nofile=262144:262144 -p 8123:8123 -p 8443:8443 -p 9000:9000 yandex/clickhouse-server

View file

@ -8,19 +8,19 @@ from sqlalchemy.ext.asyncio import AsyncEngine
from ..database_engine import MAX_ROWS_RETURNED, DatabaseEngine
@tool(requires_secrets=["DATABASE_CONNECTION_STRING"])
@tool(requires_secrets=["POSTGRES_DATABASE_CONNECTION_STRING"])
async def discover_schemas(
context: ToolContext,
) -> list[str]:
"""Discover all the schemas in the postgres database."""
async with await DatabaseEngine.get_engine(
context.get_secret("DATABASE_CONNECTION_STRING")
context.get_secret("POSTGRES_DATABASE_CONNECTION_STRING")
) as engine:
schemas = await _get_schemas(engine)
return schemas
@tool(requires_secrets=["DATABASE_CONNECTION_STRING"])
@tool(requires_secrets=["POSTGRES_DATABASE_CONNECTION_STRING"])
async def discover_tables(
context: ToolContext,
schema_name: Annotated[
@ -32,13 +32,13 @@ async def discover_tables(
ALWAYS use this tool before any other tool that requires a table name.
"""
async with await DatabaseEngine.get_engine(
context.get_secret("DATABASE_CONNECTION_STRING")
context.get_secret("POSTGRES_DATABASE_CONNECTION_STRING")
) as engine:
tables = await _get_tables(engine, schema_name)
return tables
@tool(requires_secrets=["DATABASE_CONNECTION_STRING"])
@tool(requires_secrets=["POSTGRES_DATABASE_CONNECTION_STRING"])
async def get_table_schema(
context: ToolContext,
schema_name: Annotated[str, "The database schema to get the table schema of"],
@ -50,12 +50,12 @@ async def get_table_schema(
This tool should ALWAYS be used before executing any query. All tables in the query must be discovered first using the <DiscoverTables> tool.
"""
async with await DatabaseEngine.get_engine(
context.get_secret("DATABASE_CONNECTION_STRING")
context.get_secret("POSTGRES_DATABASE_CONNECTION_STRING")
) as engine:
return await _get_table_schema(engine, schema_name, table_name)
@tool(requires_secrets=["DATABASE_CONNECTION_STRING"])
@tool(requires_secrets=["POSTGRES_DATABASE_CONNECTION_STRING"])
async def execute_select_query(
context: ToolContext,
select_clause: Annotated[
@ -116,7 +116,7 @@ async def execute_select_query(
* Only join on columns that are indexed or the primary key. Do not join on arbitrary columns.
"""
async with await DatabaseEngine.get_engine(
context.get_secret("DATABASE_CONNECTION_STRING")
context.get_secret("POSTGRES_DATABASE_CONNECTION_STRING")
) as engine:
try:
return await _execute_query(
@ -171,6 +171,8 @@ async def _get_tables(engine: AsyncEngine, schema_name: str) -> list[str]:
these_tables = await conn.run_sync(get_table_names)
tables.extend(these_tables)
tables.sort()
return tables

View file

@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project]
name = "arcade_postgres"
version = "0.2.0"
version = "0.3.0"
description = "Tools to query and explore a postgres database"
requires-python = ">=3.10"
dependencies = [

View file

@ -15,7 +15,7 @@ from arcade_tdk.errors import RetryableToolError
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine
DATABASE_CONNECTION_STRING = (
POSTGRES_DATABASE_CONNECTION_STRING = (
environ.get("TEST_POSTGRES_DATABASE_CONNECTION_STRING")
or "postgresql://evan@localhost:5432/postgres"
)
@ -26,7 +26,9 @@ def mock_context():
context = ToolContext()
context.secrets = []
context.secrets.append(
ToolSecretItem(key="DATABASE_CONNECTION_STRING", value=DATABASE_CONNECTION_STRING)
ToolSecretItem(
key="POSTGRES_DATABASE_CONNECTION_STRING", value=POSTGRES_DATABASE_CONNECTION_STRING
)
)
return context
@ -37,7 +39,9 @@ def mock_context():
async def restore_database():
with open(f"{os.path.dirname(__file__)}/dump.sql") as f:
engine = create_async_engine(
DATABASE_CONNECTION_STRING.replace("postgresql", "postgresql+asyncpg").split("?")[0]
POSTGRES_DATABASE_CONNECTION_STRING.replace("postgresql", "postgresql+asyncpg").split(
"?"
)[0]
)
async with engine.connect() as c:
queries = f.read().split(";")
@ -64,7 +68,7 @@ async def test_discover_schemas(mock_context) -> None:
@pytest.mark.asyncio
async def test_discover_tables(mock_context) -> None:
assert await discover_tables(mock_context) == ["users", "messages"]
assert await discover_tables(mock_context) == ["messages", "users"]
@pytest.mark.asyncio