Remove toolkits (#784)

This commit is contained in:
Eric Gustin 2026-02-26 09:09:46 -08:00 committed by GitHub
parent bcee0f556f
commit 830480de83
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
117 changed files with 5 additions and 12030 deletions

View file

@ -6,22 +6,6 @@ inputs:
required: false
description: "The python version to use"
default: "3.11"
is-toolkit:
required: false
description: "Whether this is a toolkit package"
default: "false"
is-contrib:
required: false
description: "Whether this is a contrib package"
default: "false"
is-lib:
required: false
description: "Whether this is a library package"
default: "false"
working-directory:
required: false
description: "Working directory for the installation (used for toolkits)"
default: "."
runs:
using: "composite"
@ -29,26 +13,8 @@ runs:
- name: Install uv
uses: astral-sh/setup-uv@v6
with:
working-directory: ${{ inputs.working-directory }}
python-version: ${{ inputs.python-version }}
- name: Install toolkit dependencies
if: inputs.is-toolkit == 'true'
working-directory: ${{ inputs.working-directory }}
run: |
echo "Installing dependencies for ${{ inputs.working-directory }}"
make install-local
shell: bash
- name: Install contrib dependencies
if: inputs.is-contrib == 'true'
working-directory: ${{ inputs.working-directory }}
run: |
echo "Installing dependencies for ${{ inputs.working-directory }}"
make install
shell: bash
- name: Install libs dependencies
if: inputs.is-toolkit != 'true'
- name: Install dependencies
run: uv sync --extra all --extra dev
shell: bash

View file

@ -83,14 +83,8 @@ jobs:
uses: ./.github/actions/setup-uv-env
with:
python-version: "3.10"
is-toolkit: ${{ startsWith(matrix.package, 'toolkits/') }}
is-contrib: ${{ startsWith(matrix.package, 'contrib/') }}
is-lib: ${{ startsWith(matrix.package, 'libs/') }}
working-directory: ${{ matrix.package }}
- name: Run tests
# Skip tests for toolkits - tests are run on every PR commit for toolkits
if: ${{ !startsWith(matrix.package, 'toolkits/') }}
working-directory: ${{ matrix.package }}
run: |
# Run tests if they exist

View file

@ -1,132 +0,0 @@
name: Test Toolkits
on:
push:
branches:
- main
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
env:
ARCADE_USAGE_TRACKING: "0"
jobs:
setup:
runs-on: ubuntu-latest
outputs:
toolkits_with_gha_secrets: ${{ steps.load_toolkits.outputs.toolkits_with_gha_secrets }}
toolkits_without_gha_secrets: ${{ steps.load_toolkits.outputs.toolkits_without_gha_secrets }}
steps:
- name: Check out
uses: actions/checkout@v4
- name: determine toolkits with and without GHA secrets
id: load_toolkits
run: |
# Find all directories in toolkits/ that have a pyproject.toml
TOOLKITS=$(find toolkits -maxdepth 1 -type d -not -name "toolkits" -exec test -f {}/pyproject.toml \; -exec basename {} \; | jq -R -s -c 'split("\n")[:-1]')
TOOLKITS_WITH_GHA_SECRETS='["postgres", "clickhouse", "mongodb"]'
TOOLKITS_WITHOUT_GHA_SECRETS=$(echo "$TOOLKITS" | jq -c --argjson with "$TOOLKITS_WITH_GHA_SECRETS" '[.[] | select(. as $t | $with | index($t) | not)]')
echo "Found toolkits: $TOOLKITS"
echo "Found toolkits without GHA secrets: $TOOLKITS_WITHOUT_GHA_SECRETS"
echo "Found toolkits with GHA secrets: $TOOLKITS_WITH_GHA_SECRETS"
echo "toolkits_without_gha_secrets=$TOOLKITS_WITHOUT_GHA_SECRETS" >> $GITHUB_OUTPUT
echo "toolkits_with_gha_secrets=$TOOLKITS_WITH_GHA_SECRETS" >> $GITHUB_OUTPUT
test-toolkits:
needs: setup
name: test-toolkits (${{ matrix.toolkit }}, ${{ matrix.os }})
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
toolkit: ${{ fromJson(needs.setup.outputs.toolkits_without_gha_secrets) }}
fail-fast: false
steps:
- name: Check out
uses: actions/checkout@v4
- name: Set up the environment
uses: ./.github/actions/setup-uv-env
with:
is-toolkit: "true"
working-directory: toolkits/${{ matrix.toolkit }}
- name: Install toolkit dependencies
working-directory: toolkits/${{ matrix.toolkit }}
shell: bash
run: uv pip install -e ".[dev]"
- name: Check toolkit
working-directory: toolkits/${{ matrix.toolkit }}
shell: bash
run: |
uv run --active pre-commit run -a
uv run --active mypy --config-file=pyproject.toml
- name: Test stand-alone toolkits (no secrets)
working-directory: toolkits/${{ matrix.toolkit }}
shell: bash
run: |
# Run pytest and capture exit code
uv run --active pytest -W ignore -v --cov=arcade_${{ matrix.toolkit }} --cov-report=xml || EXIT_CODE=$?
if [ "${EXIT_CODE:-0}" -eq 5 ]; then
echo "No tests found for toolkit ${{ matrix.toolkit }}, skipping..."
exit 0
elif [ "${EXIT_CODE:-0}" -ne 0 ]; then
exit ${EXIT_CODE}
fi
test-toolkits-with-gha-secrets:
needs: setup
# Linux-only: these toolkits bootstrap local DBs via docker/apt in tests/test_setup.sh.
runs-on: ubuntu-latest
strategy:
matrix:
toolkit: ${{ fromJson(needs.setup.outputs.toolkits_with_gha_secrets) }}
fail-fast: true
steps:
- name: Check out
uses: actions/checkout@v4
- name: Set up the environment
uses: ./.github/actions/setup-uv-env
with:
is-toolkit: "true"
working-directory: toolkits/${{ matrix.toolkit }}
- name: Install toolkit dependencies
working-directory: toolkits/${{ matrix.toolkit }}
run: uv pip install -e ".[dev]"
- name: Check toolkit
working-directory: toolkits/${{ matrix.toolkit }}
run: |
uv run --active pre-commit run -a
uv run --active mypy --config-file=pyproject.toml
- name: Test stand-alone toolkits (with secrets)
if: |
!github.event.pull_request.head.repo.fork
working-directory: toolkits/${{ matrix.toolkit }}
env:
TEST_POSTGRES_DATABASE_CONNECTION_STRING: ${{ secrets.TEST_POSTGRES_DATABASE_CONNECTION_STRING }} # TODO: dynamically only load the `TEST_${{ matrix.toolkit }}_DATABASE_CONNECTION_STRING secret`
TEST_CLICKHOUSE_DATABASE_CONNECTION_STRING: ${{ secrets.TEST_CLICKHOUSE_DATABASE_CONNECTION_STRING }}
TEST_MONGODB_CONNECTION_STRING: ${{ secrets.TEST_MONGODB_CONNECTION_STRING }}
run: |
# If there's a custom test_setup.sh file, run it
if [ -f tests/test_setup.sh ]; then
echo "Running custom test setup for ${{ matrix.toolkit }}..."
./tests/test_setup.sh
fi
# Run pytest and capture exit code
uv run --active pytest -W ignore -v --cov=arcade_${{ matrix.toolkit }} --cov-report=xml || EXIT_CODE=$?
if [ "${EXIT_CODE:-0}" -eq 5 ]; then
echo "No tests found for toolkit ${{ matrix.toolkit }}, skipping..."
exit 0
elif [ "${EXIT_CODE:-0}" -ne 0 ]; then
exit ${EXIT_CODE}
fi

View file

@ -4,7 +4,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
## What This Is
Arcade MCP is a Python tool-calling platform for building MCP (Model Context Protocol) servers. It's a monorepo containing 5 interdependent libraries, 30+ prebuilt toolkit integrations, and a CLI.
Arcade MCP is a Python tool-calling platform for building MCP (Model Context Protocol) servers. It's a monorepo containing 5 interdependent libraries and a CLI.
## Commands
@ -15,7 +15,6 @@ Arcade MCP is a Python tool-calling platform for building MCP (Model Context Pro
| Run a single test | `uv run pytest libs/tests/core/test_toolkit.py::TestClass::test_method` |
| Lint + type check | `make check` (pre-commit + mypy) |
| Build all wheels | `make build` |
| Run toolkit tests | `make test-toolkits` |
Package manager is **uv** — always use `uv run` to execute Python commands, never bare `pip` or `python`. Python 3.10+. Build system is Hatchling.

117
Makefile
View file

@ -6,37 +6,6 @@ install: ## Install the uv environment and all packages with dependencies
@uv run pre-commit install
@echo "✅ All packages and dependencies installed via uv workspace"
.PHONY: install-toolkits
install-toolkits: ## Install dependencies for all toolkits
@echo "🚀 Installing dependencies for all toolkits"
@failed=0; \
successful=0; \
for dir in toolkits/*/ ; do \
if [ -d "$$dir" ] && [ -f "$$dir/pyproject.toml" ]; then \
echo "📦 Installing dependencies for $$dir"; \
if (cd $$dir && uv pip install -e ".[dev]"); then \
successful=$$((successful + 1)); \
else \
echo "❌ Failed to install dependencies for $$dir"; \
failed=$$((failed + 1)); \
fi; \
else \
echo "⚠️ Skipping $$dir (no pyproject.toml found)"; \
fi; \
done; \
echo ""; \
echo "📊 Installation Summary:"; \
echo " ✅ Successful: $$successful toolkits"; \
echo " ❌ Failed: $$failed toolkits"; \
if [ $$failed -gt 0 ]; then \
echo ""; \
echo "⚠️ Some toolkit installations failed. Check the output above for details."; \
exit 1; \
else \
echo ""; \
echo "🎉 All toolkit dependencies installed successfully!"; \
fi
.PHONY: check
check: ## Run code quality tools.
@echo "🚀 Linting code: Running pre-commit"
@ -56,18 +25,6 @@ check-libs: ## Run code quality tools for each lib package
(cd $$lib && uv run mypy . || true); \
done
.PHONY: check-toolkits
check-toolkits: ## Run code quality tools for each toolkit that has a Makefile
@echo "🚀 Running 'make check' in each toolkit with a Makefile"
@for dir in toolkits/*/ ; do \
if [ -f "$$dir/Makefile" ]; then \
echo "🛠️ Checking toolkit $$dir"; \
(cd "$$dir" && uv run pre-commit run -a && uv run mypy --config-file=pyproject.toml); \
else \
echo "🛠️ Skipping toolkit $$dir (no Makefile found)"; \
fi; \
done
.PHONY: test
test: install ## Test the code with pytest
@echo "🚀 Testing libs: Running pytest"
@ -81,15 +38,6 @@ test-libs: ## Test each lib package individually
(cd $$lib && uv run pytest -W ignore -v || true); \
done
.PHONY: test-toolkits
test-toolkits: ## Iterate over all toolkits and run pytest on each one
@echo "🚀 Testing code in toolkits: Running pytest"
@for dir in toolkits/*/ ; do \
toolkit_name=$$(basename "$$dir"); \
echo "🧪 Testing $$toolkit_name toolkit"; \
(cd $$dir && uv run pytest -W ignore -v --cov=arcade_$$toolkit_name --cov-report=xml || exit 1); \
done
.PHONY: coverage
coverage: ## Generate coverage report
@echo "coverage report"
@ -107,38 +55,6 @@ build: clean-build ## Build wheel files using uv
fi; \
done
.PHONY: build-toolkits
build-toolkits: ## Build wheel files for all toolkits
@echo "🚀 Creating wheel files for all toolkits"
@failed=0; \
successful=0; \
for dir in toolkits/*/ ; do \
if [ -d "$$dir" ] && [ -f "$$dir/pyproject.toml" ]; then \
toolkit_name=$$(basename "$$dir"); \
echo "🛠️ Building toolkit $$toolkit_name"; \
if (cd $$dir && uv build); then \
successful=$$((successful + 1)); \
else \
echo "❌ Failed to build toolkit $$toolkit_name"; \
failed=$$((failed + 1)); \
fi; \
else \
echo "⚠️ Skipping $$dir (no pyproject.toml found)"; \
fi; \
done; \
echo ""; \
echo "📊 Build Summary:"; \
echo " ✅ Successful: $$successful toolkits"; \
echo " ❌ Failed: $$failed toolkits"; \
if [ $$failed -gt 0 ]; then \
echo ""; \
echo "⚠️ Some toolkit builds failed. Check the output above for details."; \
exit 1; \
else \
echo ""; \
echo "🎉 All toolkit wheels built successfully!"; \
fi
.PHONY: clean-build
clean-build: ## clean build artifacts
@echo "🗑️ Cleaning build artifacts"
@ -161,30 +77,19 @@ build-and-publish: build publish ## Build and publish.
.PHONY: docker
docker: ## Build and run the Docker container
@echo "🚀 Building lib packages and toolkit wheels..."
@echo "🚀 Building lib packages..."
@make full-dist
@echo "🚀 Building Docker image"
@cd docker && make docker-build
@cd docker && make docker-run
.PHONY: docker-base
docker-base: ## Build and run the Docker container
@echo "🚀 Building lib packages and toolkit wheels..."
@make full-dist
@echo "🚀 Building Docker image"
@cd docker && INSTALL_TOOLKITS=false make docker-build
@cd docker && INSTALL_TOOLKITS=false make docker-run
.PHONY: publish-ghcr
publish-ghcr: ## Publish to the GHCR
# Publish the base image - ghcr.io/arcadeai/worker-base
@cd docker && INSTALL_TOOLKITS=false make publish-ghcr
# Publish the image with toolkits - ghcr.io/arcadeai/worker
@cd docker && INSTALL_TOOLKITS=true make publish-ghcr
@cd docker && make publish-ghcr
.PHONY: full-dist
full-dist: clean-dist ## Build all projects and copy wheels to ./dist
@echo "🛠️ Building a full distribution with lib packages and toolkits"
@echo "🛠️ Building a full distribution with lib packages"
@echo "🛠️ Building all lib packages and copying wheels to ./dist"
@mkdir -p dist
@ -198,16 +103,6 @@ full-dist: clean-dist ## Build all projects and copy wheels to ./dist
@uv build
@rm -f dist/*.tar.gz
@echo "🛠️ Building all toolkit packages and copying wheels to ./dist"
@for dir in toolkits/*/ ; do \
if [ -d "$$dir" ] && [ -f "$$dir/pyproject.toml" ]; then \
toolkit_name=$$(basename "$$dir"); \
echo "🛠️ Building toolkit $$toolkit_name wheel..."; \
(cd $$dir && uv build); \
cp $$dir/dist/*.whl dist/; \
fi; \
done
.PHONY: clean-dist
clean-dist: ## Clean all built distributions
@echo "🗑️ Cleaning dist directory"
@ -216,12 +111,6 @@ clean-dist: ## Clean all built distributions
@for lib in libs/arcade*/ ; do \
rm -rf "$$lib"/dist; \
done
@echo "🗑️ Cleaning toolkits/*/dist directory"
@for toolkit_dir in toolkits/*; do \
if [ -d "$$toolkit_dir" ]; then \
rm -rf "$$toolkit_dir"/dist; \
fi; \
done
.PHONY: setup
setup: ## Run uv environment setup script

View file

@ -1,18 +0,0 @@
files: ^arcade_brightdata/.*
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: "v4.4.0"
hooks:
- id: check-case-conflict
- id: check-merge-conflict
- id: check-toml
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.7
hooks:
- id: ruff
args: [--fix]
- id: ruff-format

View file

@ -1,44 +0,0 @@
target-version = "py310"
line-length = 100
fix = true
[lint]
select = [
# flake8-2020
"YTT",
# flake8-bandit
"S",
# flake8-bugbear
"B",
# flake8-builtins
"A",
# flake8-comprehensions
"C4",
# flake8-debugger
"T10",
# flake8-simplify
"SIM",
# isort
"I",
# mccabe
"C90",
# pycodestyle
"E", "W",
# pyflakes
"F",
# pygrep-hooks
"PGH",
# pyupgrade
"UP",
# ruff
"RUF",
# tryceratops
"TRY",
]
[lint.per-file-ignores]
"**/tests/*" = ["S101"]
[format]
preview = true
skip-magic-trailing-comma = false

View file

@ -1,21 +0,0 @@
MIT License
Copyright (c) 2025, Arcade AI
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -1,55 +0,0 @@
.PHONY: help
help:
@echo "🛠️ github Commands:\n"
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
.PHONY: install
install: ## Install the uv environment and install all packages with dependencies
@echo "🚀 Creating virtual environment and installing all packages using uv"
@uv sync --active --all-extras --no-sources
@if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi
@echo "✅ All packages and dependencies installed via uv"
.PHONY: install-local
install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources
@echo "🚀 Creating virtual environment and installing all packages using uv"
@uv sync --active --all-extras
@if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi
@echo "✅ All packages and dependencies installed via uv"
.PHONY: build
build: clean-build ## Build wheel file using poetry
@echo "🚀 Creating wheel file"
uv build
.PHONY: clean-build
clean-build: ## clean build artifacts
@echo "🗑️ Cleaning dist directory"
rm -rf dist
.PHONY: test
test: ## Test the code with pytest
@echo "🚀 Testing code: Running pytest"
@uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml
.PHONY: coverage
coverage: ## Generate coverage report
@echo "coverage report"
@uv run --no-sources coverage report
@echo "Generating coverage report"
@uv run --no-sources coverage html
.PHONY: bump-version
bump-version: ## Bump the version in the pyproject.toml file by a patch version
@echo "🚀 Bumping version in pyproject.toml"
uv version --no-sources --bump patch
.PHONY: check
check: ## Run code quality tools.
@if [ -f .pre-commit-config.yaml ]; then\
echo "🚀 Linting code: Running pre-commit";\
uv run --no-sources pre-commit run -a;\
fi
@echo "🚀 Static type checking: Running mypy"
@uv run --no-sources mypy --config-file=pyproject.toml

View file

@ -1,3 +0,0 @@
from arcade_brightdata.tools import scrape_as_markdown, search_engine, web_data_feed
__all__ = ["scrape_as_markdown", "search_engine", "web_data_feed"]

View file

@ -1,28 +0,0 @@
import sys
from typing import cast
from arcade_mcp_server import MCPApp
from arcade_mcp_server.mcp_app import TransportType
import arcade_brightdata
app = MCPApp(
name="BrightData",
instructions=(
"Use this server when you need to interact with Bright Data to help users "
"scrape web pages, search the web, and extract structured data from websites."
),
)
app.add_tools_from_module(arcade_brightdata)
def main() -> None:
transport = sys.argv[1] if len(sys.argv) > 1 else "stdio"
host = sys.argv[2] if len(sys.argv) > 2 else "127.0.0.1"
port = int(sys.argv[3]) if len(sys.argv) > 3 else 8000
app.run(transport=cast(TransportType, transport), host=host, port=port)
if __name__ == "__main__":
main()

View file

@ -1,63 +0,0 @@
import json
from typing import ClassVar
from urllib.parse import quote
import requests
class BrightDataClient:
"""Engine for interacting with Bright Data API with connection management."""
_clients: ClassVar[dict[str, "BrightDataClient"]] = {}
def __init__(self, api_key: str, zone: str = "web_unlocker1") -> None:
"""
Initialize with API token and default zone.
Args:
api_key (str): Your Bright Data API token
zone (str): Bright Data zone name
"""
self.api_key = api_key
self.headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
self.zone = zone
self.endpoint = "https://api.brightdata.com/request"
@classmethod
def create_client(cls, api_key: str, zone: str = "web_unlocker1") -> "BrightDataClient":
"""Create or get cached client instance using API key only."""
if api_key not in cls._clients:
cls._clients[api_key] = cls(api_key, zone)
# Update zone for this request (user controls zone per request)
client = cls._clients[api_key]
client.zone = zone
return client
@classmethod
def clear_cache(cls) -> None:
"""Clear the client cache."""
cls._clients.clear()
def make_request(self, payload: dict) -> str:
"""
Make a request to Bright Data API.
Args:
payload (Dict): Request payload
Returns:
str: Response text
"""
response = requests.post(
self.endpoint, headers=self.headers, data=json.dumps(payload), timeout=30
)
response.raise_for_status()
result: str = response.text
return result
@staticmethod
def encode_query(query: str) -> str:
"""URL encode a search query."""
return quote(query)

View file

@ -1,7 +0,0 @@
from arcade_brightdata.tools.bright_data_tools import (
scrape_as_markdown,
search_engine,
web_data_feed,
)
__all__ = ["scrape_as_markdown", "search_engine", "web_data_feed"]

View file

@ -1,361 +0,0 @@
import json
import time
from enum import Enum
from typing import Annotated, Any, cast
import requests
from arcade_mcp_server import Context, tool
from arcade_mcp_server.exceptions import RetryableToolError
from arcade_mcp_server.metadata import (
Behavior,
Classification,
Operation,
ServiceDomain,
ToolMetadata,
)
from arcade_brightdata.bright_data_client import BrightDataClient
class DeviceType(str, Enum):
MOBILE = "mobile"
IOS = "ios"
IPHONE = "iphone"
IPAD = "ipad"
ANDROID = "android"
ANDROID_TABLET = "android_tablet"
class SearchEngine(str, Enum):
GOOGLE = "google"
BING = "bing"
YANDEX = "yandex"
class SearchType(str, Enum):
IMAGES = "images"
SHOPPING = "shopping"
NEWS = "news"
JOBS = "jobs"
class SourceType(str, Enum):
AMAZON_PRODUCT = "amazon_product"
AMAZON_PRODUCT_REVIEWS = "amazon_product_reviews"
LINKEDIN_PERSON_PROFILE = "linkedin_person_profile"
LINKEDIN_COMPANY_PROFILE = "linkedin_company_profile"
ZOOMINFO_COMPANY_PROFILE = "zoominfo_company_profile"
INSTAGRAM_PROFILES = "instagram_profiles"
INSTAGRAM_POSTS = "instagram_posts"
INSTAGRAM_REELS = "instagram_reels"
INSTAGRAM_COMMENTS = "instagram_comments"
FACEBOOK_POSTS = "facebook_posts"
FACEBOOK_MARKETPLACE_LISTINGS = "facebook_marketplace_listings"
FACEBOOK_COMPANY_REVIEWS = "facebook_company_reviews"
X_POSTS = "x_posts"
ZILLOW_PROPERTIES_LISTING = "zillow_properties_listing"
BOOKING_HOTEL_LISTINGS = "booking_hotel_listings"
YOUTUBE_VIDEOS = "youtube_videos"
@tool(
requires_secrets=["BRIGHTDATA_API_KEY", "BRIGHTDATA_ZONE"],
metadata=ToolMetadata(
classification=Classification(
service_domains=[ServiceDomain.WEB_SCRAPING],
),
behavior=Behavior(
operations=[Operation.READ],
read_only=True,
destructive=False,
idempotent=True,
open_world=True,
),
),
)
def scrape_as_markdown(
context: Context,
url: Annotated[str, "URL to scrape"],
) -> Annotated[str, "Scraped webpage content as Markdown"]:
"""
Scrape a webpage and return content in Markdown format using Bright Data.
Examples:
scrape_as_markdown("https://example.com") -> "# Example Page\n\nContent..."
scrape_as_markdown("https://news.ycombinator.com") -> "# Hacker News\n..."
"""
api_key = context.get_secret("BRIGHTDATA_API_KEY")
zone = context.get_secret("BRIGHTDATA_ZONE")
client = BrightDataClient.create_client(api_key=api_key, zone=zone)
payload = {"url": url, "zone": zone, "format": "raw", "data_format": "markdown"}
return client.make_request(payload)
@tool(
requires_secrets=["BRIGHTDATA_API_KEY", "BRIGHTDATA_ZONE"],
metadata=ToolMetadata(
classification=Classification(
service_domains=[ServiceDomain.WEB_SCRAPING],
),
behavior=Behavior(
operations=[Operation.READ],
read_only=True,
destructive=False,
idempotent=True,
open_world=True,
),
),
)
def search_engine( # noqa: C901
context: Context,
query: Annotated[str, "Search query"],
engine: Annotated[SearchEngine, "Search engine to use"] = SearchEngine.GOOGLE,
language: Annotated[str | None, "Two-letter language code"] = None,
country_code: Annotated[str | None, "Two-letter country code"] = None,
search_type: Annotated[SearchType | None, "Type of search"] = None,
start: Annotated[int | None, "Results pagination offset"] = None,
num_results: Annotated[int, "Number of results to return. The default is 10"] = 10,
location: Annotated[str | None, "Location for search results"] = None,
device: Annotated[DeviceType | None, "Device type"] = None,
return_json: Annotated[bool, "Return JSON instead of Markdown"] = False,
) -> Annotated[str, "Search results as Markdown or JSON"]:
"""
Search using Google, Bing, or Yandex with advanced parameters using Bright Data.
Examples:
search_engine("climate change") -> "# Search Results\n\n## Climate Change - Wikipedia\n..."
search_engine("Python tutorials", engine="bing", num_results=5) -> "# Bing Results\n..."
search_engine("cats", search_type="images", country_code="us") -> "# Image Results\n..."
"""
api_key = context.get_secret("BRIGHTDATA_API_KEY")
zone = context.get_secret("BRIGHTDATA_ZONE")
client = BrightDataClient.create_client(api_key=api_key, zone=zone)
encoded_query = BrightDataClient.encode_query(query)
base_urls = {
SearchEngine.GOOGLE: f"https://www.google.com/search?q={encoded_query}",
SearchEngine.BING: f"https://www.bing.com/search?q={encoded_query}",
SearchEngine.YANDEX: f"https://yandex.com/search/?text={encoded_query}",
}
search_url = base_urls[engine]
if engine == SearchEngine.GOOGLE:
params = []
if language:
params.append(f"hl={language}")
if country_code:
params.append(f"gl={country_code}")
if search_type:
if search_type == SearchType.JOBS:
params.append("ibp=htl;jobs")
else:
search_types = {
SearchType.IMAGES: "isch",
SearchType.SHOPPING: "shop",
SearchType.NEWS: "nws",
}
tbm_value = search_types.get(search_type, search_type)
params.append(f"tbm={tbm_value}")
if start is not None:
params.append(f"start={start}")
if num_results:
params.append(f"num={num_results}")
if location:
params.append(f"uule={BrightDataClient.encode_query(location)}")
if device:
device_value = "1"
if device.value in ["ios", "iphone"]:
device_value = "ios"
elif device.value == "ipad":
device_value = "ios_tablet"
elif device.value == "android":
device_value = "android"
elif device.value == "android_tablet":
device_value = "android_tablet"
params.append(f"brd_mobile={device_value}")
if return_json:
params.append("brd_json=1")
if params:
search_url += "&" + "&".join(params)
payload = {
"url": search_url,
"zone": zone,
"format": "raw",
"data_format": "markdown" if not return_json else "raw",
}
return client.make_request(payload)
@tool(
requires_secrets=["BRIGHTDATA_API_KEY"],
metadata=ToolMetadata(
classification=Classification(
service_domains=[ServiceDomain.WEB_SCRAPING],
),
behavior=Behavior(
operations=[Operation.READ],
read_only=True,
destructive=False,
idempotent=False,
open_world=True,
),
),
)
def web_data_feed(
context: Context,
source_type: Annotated[SourceType, "Type of data source"],
url: Annotated[str, "URL of the web resource to extract data from"],
num_of_reviews: Annotated[
int | None,
(
"Number of reviews to retrieve. Only applicable for "
"facebook_company_reviews. Default is None"
),
] = None,
timeout: Annotated[int, "Maximum time in seconds to wait for data retrieval"] = 600,
polling_interval: Annotated[int, "Time in seconds between polling attempts"] = 1,
) -> Annotated[str, "Structured data from the requested source as JSON"]:
"""
Extract structured data from various websites like LinkedIn, Amazon, Instagram, etc.
NEVER MADE UP LINKS - IF LINKS ARE NEEDED, EXECUTE search_engine FIRST.
Supported source types:
- amazon_product, amazon_product_reviews
- linkedin_person_profile, linkedin_company_profile
- zoominfo_company_profile
- instagram_profiles, instagram_posts, instagram_reels, instagram_comments
- facebook_posts, facebook_marketplace_listings, facebook_company_reviews
- x_posts
- zillow_properties_listing
- booking_hotel_listings
- youtube_videos
Examples:
web_data_feed("amazon_product", "https://amazon.com/dp/B08N5WRWNW")
-> "{\"title\": \"Product Name\", ...}"
web_data_feed("linkedin_person_profile", "https://linkedin.com/in/johndoe")
-> "{\"name\": \"John Doe\", ...}"
web_data_feed(
"facebook_company_reviews", "https://facebook.com/company", num_of_reviews=50
) -> "[{\"review\": \"...\", ...}]"
"""
api_key = context.get_secret("BRIGHTDATA_API_KEY")
client = BrightDataClient.create_client(api_key=api_key)
if num_of_reviews is not None and source_type != SourceType.FACEBOOK_COMPANY_REVIEWS:
msg = (
f"num_of_reviews parameter is only applicable for facebook_company_reviews, "
f"not for {source_type.value}"
)
prompt = (
"The num_of_reviews parameter should only be used with "
"facebook_company_reviews source type."
)
raise RetryableToolError(msg, additional_prompt_content=prompt)
data = _extract_structured_data(
client=client,
source_type=source_type,
url=url,
num_of_reviews=num_of_reviews,
timeout=timeout,
polling_interval=polling_interval,
)
return json.dumps(data, indent=2)
def _extract_structured_data(
client: BrightDataClient,
source_type: SourceType,
url: str,
num_of_reviews: int | None = None,
timeout: int = 600,
polling_interval: int = 1,
) -> dict[str, Any]:
"""
Extract structured data from various sources.
"""
datasets = {
SourceType.AMAZON_PRODUCT: "gd_l7q7dkf244hwjntr0",
SourceType.AMAZON_PRODUCT_REVIEWS: "gd_le8e811kzy4ggddlq",
SourceType.LINKEDIN_PERSON_PROFILE: "gd_l1viktl72bvl7bjuj0",
SourceType.LINKEDIN_COMPANY_PROFILE: "gd_l1vikfnt1wgvvqz95w",
SourceType.ZOOMINFO_COMPANY_PROFILE: "gd_m0ci4a4ivx3j5l6nx",
SourceType.INSTAGRAM_PROFILES: "gd_l1vikfch901nx3by4",
SourceType.INSTAGRAM_POSTS: "gd_lk5ns7kz21pck8jpis",
SourceType.INSTAGRAM_REELS: "gd_lyclm20il4r5helnj",
SourceType.INSTAGRAM_COMMENTS: "gd_ltppn085pokosxh13",
SourceType.FACEBOOK_POSTS: "gd_lyclm1571iy3mv57zw",
SourceType.FACEBOOK_MARKETPLACE_LISTINGS: "gd_lvt9iwuh6fbcwmx1a",
SourceType.FACEBOOK_COMPANY_REVIEWS: "gd_m0dtqpiu1mbcyc2g86",
SourceType.X_POSTS: "gd_lwxkxvnf1cynvib9co",
SourceType.ZILLOW_PROPERTIES_LISTING: "gd_lfqkr8wm13ixtbd8f5",
SourceType.BOOKING_HOTEL_LISTINGS: "gd_m5mbdl081229ln6t4a",
SourceType.YOUTUBE_VIDEOS: "gd_m5mbdl081229ln6t4a",
}
dataset_id = datasets[source_type]
request_data = {"url": url}
if source_type == SourceType.FACEBOOK_COMPANY_REVIEWS and num_of_reviews is not None:
request_data["num_of_reviews"] = str(num_of_reviews)
trigger_response = requests.post(
"https://api.brightdata.com/datasets/v3/trigger",
params={"dataset_id": dataset_id, "include_errors": "true"},
headers=client.headers,
json=[request_data],
timeout=30,
)
trigger_data = trigger_response.json()
if not trigger_data.get("snapshot_id"):
msg = "No snapshot ID returned from trigger request"
prompt = "Invalid input provided, use search_engine to get the relevant data first"
raise RetryableToolError(msg, additional_prompt_content=prompt)
snapshot_id = trigger_data["snapshot_id"]
attempts = 0
max_attempts = timeout
while attempts < max_attempts:
try:
snapshot_response = requests.get(
f"https://api.brightdata.com/datasets/v3/snapshot/{snapshot_id}",
params={"format": "json"},
headers=client.headers,
timeout=30,
)
snapshot_data = cast(dict[str, Any], snapshot_response.json())
if isinstance(snapshot_data, dict) and snapshot_data.get("status") in (
"running",
"building",
):
attempts += 1
time.sleep(polling_interval)
continue
else:
return snapshot_data
except Exception:
attempts += 1
time.sleep(polling_interval)
msg = f"Timeout after {max_attempts} seconds waiting for {source_type.value} data"
raise TimeoutError(msg)

View file

@ -1,62 +0,0 @@
[build-system]
requires = [ "hatchling",]
build-backend = "hatchling.build"
[project]
name = "arcade_brightdata"
version = "0.4.0"
description = "Search, Crawl and Scrape any site, at scale, without getting blocked"
requires-python = ">=3.10"
dependencies = [
"arcade-mcp-server>=1.17.0,<2.0.0",
"requests>=2.32.5",
]
[[project.authors]]
name = "meirk-brd"
email = "meirk@brightdata.com"
[project.scripts]
arcade-brightdata = "arcade_brightdata.__main__:main"
arcade_brightdata = "arcade_brightdata.__main__:main"
[project.optional-dependencies]
dev = [
"arcade-mcp[all]>=1.2.0,<2.0.0",
"pytest>=8.3.0,<8.4.0",
"pytest-cov>=4.0.0,<4.1.0",
"pytest-mock>=3.11.1,<3.12.0",
"pytest-asyncio>=0.24.0,<0.25.0",
"mypy>=1.5.1,<1.6.0",
"pre-commit>=3.4.0,<3.5.0",
"tox>=4.11.1,<4.12.0",
"ruff>=0.7.4,<0.8.0",
"types-requests>=2.32.0",
]
# Tell Arcade.dev that this package is a toolkit
[project.entry-points.arcade_toolkits]
toolkit_name = "arcade_brightdata"
[tool.mypy]
files = [ "arcade_brightdata/**/*.py",]
python_version = "3.10"
disallow_untyped_defs = "True"
disallow_any_unimported = "True"
no_implicit_optional = "True"
check_untyped_defs = "True"
warn_return_any = "True"
warn_unused_ignores = "True"
show_error_codes = "True"
ignore_missing_imports = "True"
[tool.uv.sources]
arcade-mcp = { path = "../../", editable = true }
arcade-mcp-server = { path = "../../libs/arcade-mcp-server/", editable = true }
[tool.pytest.ini_options]
testpaths = [ "tests",]
[tool.coverage.report]
skip_empty = true
[tool.hatch.build.targets.wheel]
packages = [ "arcade_brightdata",]

View file

@ -1,418 +0,0 @@
from os import environ
from unittest.mock import MagicMock as _MagicMock
from unittest.mock import Mock, patch
import pytest
import requests
from arcade_mcp_server import Context
from arcade_mcp_server.exceptions import ToolExecutionError
from arcade_brightdata.bright_data_client import BrightDataClient
from arcade_brightdata.tools.bright_data_tools import (
DeviceType,
SourceType,
scrape_as_markdown,
search_engine,
web_data_feed,
)
BRIGHTDATA_API_KEY = environ.get("TEST_BRIGHTDATA_API_KEY") or "api-key"
BRIGHTDATA_ZONE = environ.get("TEST_BRIGHTDATA_ZONE") or "unblocker"
@pytest.fixture
def mock_context():
context = _MagicMock(spec=Context)
context.get_secret = _MagicMock(
side_effect=lambda key: {
"BRIGHTDATA_API_KEY": BRIGHTDATA_API_KEY,
"BRIGHTDATA_ZONE": BRIGHTDATA_ZONE,
}[key]
)
return context
@pytest.fixture(autouse=True)
def cleanup_engines():
"""Clean up bright data clients after each test to prevent connection leaks."""
yield
BrightDataClient.clear_cache()
class TestBrightDataClient:
def test_get_instance_creates_new_client(self):
client1 = BrightDataClient.create_client("test_key_1", "zone1")
client2 = BrightDataClient.create_client("test_key_2", "zone2")
assert client1 != client2
assert client1.api_key == "test_key_1"
assert client1.zone == "zone1"
assert client2.api_key == "test_key_2"
assert client2.zone == "zone2"
def test_get_instance_returns_cached_client(self):
client1 = BrightDataClient.create_client("test_key", "zone1")
client2 = BrightDataClient.create_client("test_key", "zone1")
assert client1 is client2
def test_clear_cache(self):
client1 = BrightDataClient.create_client("test_key", "zone1")
BrightDataClient.clear_cache()
client2 = BrightDataClient.create_client("test_key", "zone1")
assert client1 is not client2
def test_encode_query(self):
result = BrightDataClient.encode_query("hello world test")
assert result == "hello%20world%20test"
@patch("requests.post")
def test_make_request_success(self, mock_post):
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = "Success response"
mock_post.return_value = mock_response
client = BrightDataClient("test_key", "test_zone")
result = client.make_request({"url": "https://example.com"})
assert result == "Success response"
mock_post.assert_called_once()
@patch("requests.post")
def test_make_request_failure(self, mock_post):
mock_response = Mock()
mock_response.status_code = 400
mock_response.text = "Bad Request"
mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
"400 Client Error"
)
mock_post.return_value = mock_response
client = BrightDataClient("test_key", "test_zone")
with pytest.raises(requests.exceptions.HTTPError):
client.make_request({"url": "https://example.com"})
class TestScrapeAsMarkdown:
@patch("arcade_brightdata.tools.bright_data_tools.BrightDataClient")
def test_scrape_as_markdown_success(self, mock_engine_class, mock_context):
mock_client = Mock()
mock_client.make_request.return_value = "# Test Page\n\nContent here"
mock_engine_class.create_client.return_value = mock_client
result = scrape_as_markdown(mock_context, "https://example.com")
assert result == "# Test Page\n\nContent here"
mock_engine_class.create_client.assert_called_once_with(
api_key=BRIGHTDATA_API_KEY, zone=BRIGHTDATA_ZONE
)
mock_client.make_request.assert_called_once_with({
"url": "https://example.com",
"zone": BRIGHTDATA_ZONE,
"format": "raw",
"data_format": "markdown",
})
class TestSearchEngine:
@patch("arcade_brightdata.tools.bright_data_tools.BrightDataClient")
def test_search_engine_google_basic(self, mock_engine_class, mock_context):
mock_client = Mock()
mock_client.make_request.return_value = "# Search Results\n\nResult 1\nResult 2"
mock_engine_class.create_client.return_value = mock_client
mock_engine_class.encode_query.return_value = "test%20query"
result = search_engine(mock_context, "test query")
assert result == "# Search Results\n\nResult 1\nResult 2"
mock_engine_class.create_client.assert_called_once_with(
api_key=BRIGHTDATA_API_KEY, zone=BRIGHTDATA_ZONE
)
@patch("arcade_brightdata.tools.bright_data_tools.BrightDataClient")
def test_search_engine_bing(self, mock_engine_class, mock_context):
mock_client = Mock()
mock_client.make_request.return_value = "# Bing Results"
mock_engine_class.create_client.return_value = mock_client
mock_engine_class.encode_query.return_value = "test%20query"
result = search_engine(mock_context, "test query", engine="bing")
assert result == "# Bing Results"
expected_payload = {
"url": "https://www.bing.com/search?q=test%20query",
"zone": BRIGHTDATA_ZONE,
"format": "raw",
"data_format": "markdown",
}
mock_client.make_request.assert_called_once_with(expected_payload)
@patch("arcade_brightdata.tools.bright_data_tools.BrightDataClient")
def test_search_engine_google_with_parameters(self, mock_engine_class, mock_context):
mock_client = Mock()
mock_client.make_request.return_value = "# Google Results with params"
mock_engine_class.create_client.return_value = mock_client
mock_engine_class.encode_query.side_effect = lambda x: x.replace(" ", "%20")
result = search_engine(
mock_context,
"test query",
language="en",
country_code="us",
search_type="images",
start=10,
num_results=20,
location="New York",
device=DeviceType.MOBILE,
return_json=True,
)
assert result == "# Google Results with params"
call_args = mock_client.make_request.call_args[0][0]
assert "hl=en" in call_args["url"]
assert "gl=us" in call_args["url"]
assert "tbm=isch" in call_args["url"]
assert "start=10" in call_args["url"]
assert "num=20" in call_args["url"]
assert "brd_mobile=1" in call_args["url"]
assert "brd_json=1" in call_args["url"]
assert call_args["data_format"] == "raw"
def test_search_engine_invalid_engine(self, mock_context):
with pytest.raises(ToolExecutionError):
search_engine(mock_context, "test query", engine="invalid_engine")
@patch("arcade_brightdata.tools.bright_data_tools.BrightDataClient")
def test_search_engine_google_jobs(self, mock_engine_class, mock_context):
mock_client = Mock()
mock_client.make_request.return_value = "# Job Results"
mock_engine_class.create_client.return_value = mock_client
mock_engine_class.encode_query.return_value = "python%20developer"
result = search_engine(mock_context, "python developer", search_type="jobs")
assert result == "# Job Results"
call_args = mock_client.make_request.call_args[0][0]
assert "ibp=htl;jobs" in call_args["url"]
class TestWebDataFeed:
@patch("arcade_brightdata.tools.bright_data_tools._extract_structured_data")
@patch("arcade_brightdata.tools.bright_data_tools.BrightDataClient")
def test_web_data_feed_success(self, mock_engine_class, mock_extract, mock_context):
mock_client = Mock()
mock_engine_class.create_client.return_value = mock_client
mock_extract.return_value = {"title": "Test Product", "price": "$19.99"}
result = web_data_feed(mock_context, "amazon_product", "https://amazon.com/dp/B08N5WRWNW")
expected_json = '{\n "title": "Test Product",\n "price": "$19.99"\n}'
assert result == expected_json
mock_engine_class.create_client.assert_called_once_with(api_key=BRIGHTDATA_API_KEY)
mock_extract.assert_called_once_with(
client=mock_client,
source_type=SourceType.AMAZON_PRODUCT,
url="https://amazon.com/dp/B08N5WRWNW",
num_of_reviews=None,
timeout=600,
polling_interval=1,
)
@patch("arcade_brightdata.tools.bright_data_tools._extract_structured_data")
@patch("arcade_brightdata.tools.bright_data_tools.BrightDataClient")
def test_web_data_feed_with_reviews(self, mock_engine_class, mock_extract, mock_context):
mock_client = Mock()
mock_engine_class.create_client.return_value = mock_client
mock_extract.return_value = [{"review": "Great product!", "rating": 5}]
result = web_data_feed(
mock_context,
"facebook_company_reviews",
"https://facebook.com/company",
num_of_reviews=50,
timeout=300,
polling_interval=2,
)
expected_json = '[\n {\n "review": "Great product!",\n "rating": 5\n }\n]'
assert result == expected_json
mock_extract.assert_called_once_with(
client=mock_client,
source_type=SourceType.FACEBOOK_COMPANY_REVIEWS,
url="https://facebook.com/company",
num_of_reviews=50,
timeout=300,
polling_interval=2,
)
class TestExtractStructuredData:
@patch("requests.get")
@patch("requests.post")
def test_extract_structured_data_success(self, mock_post, mock_get):
from arcade_brightdata.tools.bright_data_tools import _extract_structured_data
client = BrightDataClient("test_key", "test_zone")
mock_trigger_response = Mock()
mock_trigger_response.json.return_value = {"snapshot_id": "snap_123"}
mock_post.return_value = mock_trigger_response
mock_snapshot_response = Mock()
mock_snapshot_response.json.return_value = {"data": "extracted_data"}
mock_get.return_value = mock_snapshot_response
result = _extract_structured_data(
client=client,
source_type=SourceType.AMAZON_PRODUCT,
url="https://amazon.com/dp/TEST",
timeout=10,
polling_interval=0.1,
)
assert result == {"data": "extracted_data"}
mock_post.assert_called_once()
trigger_call = mock_post.call_args
assert "gd_l7q7dkf244hwjntr0" in str(trigger_call) # Amazon product dataset ID
mock_get.assert_called_once()
snapshot_call = mock_get.call_args
assert "snap_123" in str(snapshot_call)
@patch("requests.get")
@patch("requests.post")
def test_extract_structured_data_with_polling(self, mock_post, mock_get):
from arcade_brightdata.tools.bright_data_tools import _extract_structured_data
client = BrightDataClient("test_key", "test_zone")
mock_trigger_response = Mock()
mock_trigger_response.json.return_value = {"snapshot_id": "snap_123"}
mock_post.return_value = mock_trigger_response
running_response = Mock()
running_response.json.return_value = {"status": "running"}
complete_response = Mock()
complete_response.json.return_value = {"data": "final_data"}
mock_get.side_effect = [running_response, complete_response]
result = _extract_structured_data(
client=client,
source_type=SourceType.LINKEDIN_PERSON_PROFILE,
url="https://linkedin.com/in/test",
timeout=10,
polling_interval=0.1,
)
assert result == {"data": "final_data"}
assert mock_get.call_count == 2
@patch("requests.post")
def test_extract_structured_data_invalid_source_type(self, mock_post):
from arcade_brightdata.tools.bright_data_tools import _extract_structured_data
client = BrightDataClient("test_key", "test_zone")
# Create a mock SourceType that doesn't exist in the datasets dict
class InvalidSourceType:
value = "invalid_source"
with pytest.raises(KeyError):
_extract_structured_data(
client=client, source_type=InvalidSourceType(), url="https://example.com"
)
@patch("requests.get")
@patch("requests.post")
def test_extract_structured_data_no_snapshot_id(self, mock_post, mock_get):
from arcade_brightdata.tools.bright_data_tools import _extract_structured_data
client = BrightDataClient("test_key", "test_zone")
# Mock trigger response without snapshot_id
mock_trigger_response = Mock()
mock_trigger_response.json.return_value = {}
mock_post.return_value = mock_trigger_response
with pytest.raises(Exception) as exc_info:
_extract_structured_data(
client=client,
source_type=SourceType.AMAZON_PRODUCT,
url="https://amazon.com/dp/TEST",
)
assert "No snapshot ID returned from trigger request" in str(exc_info.value)
@patch("requests.get")
@patch("requests.post")
@patch("time.sleep")
def test_extract_structured_data_timeout(self, mock_sleep, mock_post, mock_get):
from arcade_brightdata.tools.bright_data_tools import _extract_structured_data
client = BrightDataClient("test_key", "test_zone")
# Mock trigger response
mock_trigger_response = Mock()
mock_trigger_response.json.return_value = {"snapshot_id": "snap_123"}
mock_post.return_value = mock_trigger_response
# Mock snapshot response that always returns running
mock_snapshot_response = Mock()
mock_snapshot_response.json.return_value = {"status": "running"}
mock_get.return_value = mock_snapshot_response
with pytest.raises(TimeoutError) as exc_info:
_extract_structured_data(
client=client,
source_type=SourceType.AMAZON_PRODUCT,
url="https://amazon.com/dp/TEST",
timeout=2,
polling_interval=0.1,
)
assert "Timeout after 2 seconds waiting for amazon_product data" in str(exc_info.value)
class TestIntegration:
"""Integration tests that test the full flow without mocking internal components."""
@patch("requests.post")
def test_scrape_as_markdown_integration(self, mock_post, mock_context):
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = "# Integration Test\n\nThis is a test page"
mock_post.return_value = mock_response
result = scrape_as_markdown(mock_context, "https://example.com")
assert result == "# Integration Test\n\nThis is a test page"
# Verify the request was made correctly
call_args = mock_post.call_args
assert call_args[1]["headers"]["Authorization"] == f"Bearer {BRIGHTDATA_API_KEY}"
assert "https://api.brightdata.com/request" in str(call_args)
@patch("requests.post")
def test_search_engine_integration(self, mock_post, mock_context):
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = "# Search Results\n\n1. First result\n2. Second result"
mock_post.return_value = mock_response
result = search_engine(mock_context, "test query", engine="google")
assert result == "# Search Results\n\n1. First result\n2. Second result"
call_args = mock_post.call_args
payload = call_args[1]["data"]
assert '"url": "https://www.google.com/search?q=test%20query' in payload
assert '"data_format": "markdown"' in payload

View file

@ -1,53 +0,0 @@
.PHONY: help
help:
@echo "🛠️ clickhouse Commands:\n"
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
.PHONY: install
install: ## Install the uv environment and install all packages with dependencies
@echo "🚀 Creating virtual environment and installing all packages using uv"
@uv sync --active --all-extras --no-sources
@uv run pre-commit install
@echo "✅ All packages and dependencies installed via uv"
.PHONY: install-local
install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources
@echo "🚀 Creating virtual environment and installing all packages using uv"
@uv sync --active --all-extras
@uv run pre-commit install
@echo "✅ All packages and dependencies installed via uv"
.PHONY: build
build: clean-build ## Build wheel file using poetry
@echo "🚀 Creating wheel file"
uv build
.PHONY: clean-build
clean-build: ## clean build artifacts
@echo "🗑️ Cleaning dist directory"
rm -rf dist
.PHONY: test
test: ## Test the code with pytest
@echo "🚀 Testing code: Running pytest"
@uv run pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml
.PHONY: coverage
coverage: ## Generate coverage report
@echo "coverage report"
coverage report
@echo "Generating coverage report"
coverage html
.PHONY: bump-version
bump-version: ## Bump the version in the pyproject.toml file by a patch version
@echo "🚀 Bumping version in pyproject.toml"
uv version --bump patch
.PHONY: check
check: ## Run code quality tools.
@echo "🚀 Linting code: Running pre-commit"
@uv run pre-commit run -a
@echo "🚀 Static type checking: Running mypy"
@uv run mypy --config-file=pyproject.toml

View file

@ -1,29 +0,0 @@
import sys
from typing import cast
from arcade_mcp_server import MCPApp
from arcade_mcp_server.mcp_app import TransportType
import arcade_clickhouse
app = MCPApp(
name="ClickHouse",
instructions=(
"Use this server when you need to interact with ClickHouse to help users "
"query, explore, and manage their ClickHouse databases."
),
)
app.add_tools_from_module(arcade_clickhouse)
def main() -> None:
transport = sys.argv[1] if len(sys.argv) > 1 else "stdio"
host = sys.argv[2] if len(sys.argv) > 2 else "127.0.0.1"
port = int(sys.argv[3]) if len(sys.argv) > 3 else 8000
app.run(transport=cast(TransportType, transport), host=host, port=port)
if __name__ == "__main__":
main()

View file

@ -1,209 +0,0 @@
import contextlib
from typing import Any, ClassVar
from urllib.parse import urlparse
import clickhouse_connect
from arcade_mcp_server.exceptions import RetryableToolError
MAX_ROWS_RETURNED = 1000
TEST_QUERY = "SELECT 1"
class DatabaseEngine:
_instance: ClassVar[None] = None
_clients: ClassVar[dict[str, Any]] = {}
@classmethod
async def get_instance(cls, connection_string: str) -> Any:
parsed_url = urlparse(connection_string)
# Extract connection parameters from the URL
host = parsed_url.hostname or "localhost"
port = parsed_url.port
database = parsed_url.path.lstrip("/") or "default"
username = parsed_url.username
password = parsed_url.password
# Handle different ClickHouse protocols
# clickhouse-connect only supports HTTP and HTTPS interfaces
if parsed_url.scheme in ["clickhouse+native"]:
# Convert native protocol to HTTP for clickhouse-connect compatibility
# Convert native port 9000 to HTTP port 8123
port = 8123 if port == 9000 else port or 8123
interface = "http"
elif parsed_url.scheme in ["clickhouse+https"]:
# For HTTPS protocol
port = port or 8443
interface = "https"
else:
# For HTTP or unspecified, use port 8123 by default
port = port or 8123
interface = "http"
key = f"{interface}://{host}:{port}/{database}"
if key not in cls._clients:
try:
# Create ClickHouse client
client_args: dict[str, Any] = {
"host": host,
"port": port,
"database": database,
"interface": interface,
}
if username:
client_args["username"] = username
if password:
client_args["password"] = password
client = clickhouse_connect.get_client(**client_args)
cls._clients[key] = client
# Test the connection
client.command(TEST_QUERY)
except Exception as e:
# Remove failed client from cache
cls._clients.pop(key, None)
raise RetryableToolError(
f"Connection failed: {e}",
developer_message="Connection to ClickHouse failed.",
additional_prompt_content="Check the connection string and try again.",
) from e
return cls._clients[key]
@classmethod
async def get_engine(cls, connection_string: str) -> Any:
client = await cls.get_instance(connection_string)
class ConnectionContextManager:
def __init__(self, client: Any) -> None:
self.client = client
async def __aenter__(self) -> Any:
return self.client
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
# Connection cleanup is handled by clickhouse-connect
pass
return ConnectionContextManager(client)
@classmethod
async def cleanup(cls) -> None:
"""Clean up all cached clients. Call this when shutting down."""
for client in cls._clients.values():
with contextlib.suppress(Exception):
client.close()
cls._clients.clear()
@classmethod
def clear_cache(cls) -> None:
"""Clear the client cache without disposing clients. Use with caution."""
cls._clients.clear()
@classmethod
def sanitize_query( # noqa: C901
cls,
select_clause: str,
from_clause: str,
limit: int,
offset: int,
join_clause: str | None,
where_clause: str | None,
having_clause: str | None,
group_by_clause: str | None,
order_by_clause: str | None,
with_clause: str | None,
) -> tuple[str, dict[str, Any]]:
# Remove the leading keywords from the clauses if they are present
if select_clause.strip().split(" ")[0].upper() == "SELECT":
select_clause = select_clause.strip()[6:]
if from_clause.strip().split(" ")[0].upper() == "FROM":
from_clause = from_clause.strip()[4:]
if join_clause and join_clause.strip().split(" ")[0].upper() == "JOIN":
join_clause = join_clause.strip()[4:]
if where_clause and where_clause.strip().split(" ")[0].upper() == "WHERE":
where_clause = where_clause.strip()[5:]
if group_by_clause and group_by_clause.strip().split(" ")[0].upper() == "GROUP BY":
group_by_clause = group_by_clause.strip()[8:]
if order_by_clause and order_by_clause.strip().split(" ")[0].upper() == "ORDER BY":
order_by_clause = order_by_clause.strip()[8:]
if having_clause and having_clause.strip().split(" ")[0].upper() == "HAVING":
having_clause = having_clause.strip()[6:]
first_select_word = select_clause.strip().split(" ")[0].upper()
if first_select_word in [
"INSERT",
"UPDATE",
"DELETE",
"CREATE",
"ALTER",
"DROP",
"TRUNCATE",
"REINDEX",
"VACUUM",
"ANALYZE",
"COMMENT",
"OPTIMIZE", # ClickHouse-specific
"SYSTEM", # ClickHouse-specific
]:
raise RetryableToolError(
"Only SELECT queries are allowed.",
)
if select_clause.strip() == "*":
raise RetryableToolError(
"Do not use * in the select clause. Use a comma separated list of columns you wish to return.",
)
if limit > MAX_ROWS_RETURNED:
raise RetryableToolError(
f"Limit is too high. Maximum is {MAX_ROWS_RETURNED}.",
)
if offset < 0:
raise RetryableToolError(
"Offset must be greater than or equal to 0.",
developer_message="Offset must be greater than or equal to 0.",
)
if limit <= 0:
raise RetryableToolError(
"Limit must be greater than 0.",
developer_message="Limit must be greater than 0.",
)
# Build query with identifiers directly interpolated, but use parameters for values
parts = []
if with_clause:
parts.append(f"WITH {with_clause}")
parts.append(f"SELECT {select_clause} FROM {from_clause}") # noqa: S608
if join_clause:
parts.append(f"JOIN {join_clause}")
if where_clause:
parts.append(f"WHERE {where_clause}")
if group_by_clause:
parts.append(f"GROUP BY {group_by_clause}")
if having_clause:
parts.append(f"HAVING {having_clause}")
if order_by_clause:
parts.append(f"ORDER BY {order_by_clause}")
parts.append("LIMIT :limit OFFSET :offset")
query = " ".join(parts)
# Only use parameters for values, not identifiers
parameters = {
"limit": limit,
"offset": offset,
}
return query, parameters

View file

@ -1,347 +0,0 @@
from typing import Annotated, Any
from arcade_mcp_server import Context, tool
from arcade_mcp_server.exceptions import RetryableToolError
from arcade_mcp_server.metadata import Behavior, Operation, ToolMetadata
from ..database_engine import MAX_ROWS_RETURNED, DatabaseEngine
@tool(
requires_secrets=["CLICKHOUSE_DATABASE_CONNECTION_STRING"],
metadata=ToolMetadata(
behavior=Behavior(
operations=[Operation.READ],
read_only=True,
destructive=False,
idempotent=True,
open_world=True,
),
),
)
async def discover_schemas(
context: Context,
) -> list[str]:
"""Discover all the schemas in the ClickHouse database.
Note: ClickHouse doesn't have schemas like PostgreSQL, so this returns a default schema name.
"""
# ClickHouse doesn't have schemas like PostgreSQL, but we return a default for compatibility
return ["default"]
@tool(
requires_secrets=["CLICKHOUSE_DATABASE_CONNECTION_STRING"],
metadata=ToolMetadata(
behavior=Behavior(
operations=[Operation.READ],
read_only=True,
destructive=False,
idempotent=True,
open_world=True,
),
),
)
async def discover_databases(
context: Context,
) -> list[str]:
"""Discover all the databases in the ClickHouse database."""
async with await DatabaseEngine.get_engine(
context.get_secret("CLICKHOUSE_DATABASE_CONNECTION_STRING")
) as client:
databases = await _get_databases(client)
return databases
@tool(
requires_secrets=["CLICKHOUSE_DATABASE_CONNECTION_STRING"],
metadata=ToolMetadata(
behavior=Behavior(
operations=[Operation.READ],
read_only=True,
destructive=False,
idempotent=True,
open_world=True,
),
),
)
async def discover_tables(
context: Context,
) -> list[str]:
"""Discover all the tables in the ClickHouse database when the list of tables is not known.
ALWAYS use this tool before any other tool that requires a table name.
"""
async with await DatabaseEngine.get_engine(
context.get_secret("CLICKHOUSE_DATABASE_CONNECTION_STRING")
) as client:
tables = await _get_tables(client, "default")
return tables
@tool(
requires_secrets=["CLICKHOUSE_DATABASE_CONNECTION_STRING"],
metadata=ToolMetadata(
behavior=Behavior(
operations=[Operation.READ],
read_only=True,
destructive=False,
idempotent=True,
open_world=True,
),
),
)
async def get_table_schema(
context: Context,
schema_name: Annotated[str, "The schema to get the table schema of"],
table_name: Annotated[str, "The table to get the schema of"],
) -> list[str]:
"""
Get the schema/structure of a ClickHouse table in the ClickHouse database when the schema is not known, and the name of the table is provided.
This tool should ALWAYS be used before executing any query. All tables in the query must be discovered first using the <DiscoverTables> tool.
"""
async with await DatabaseEngine.get_engine(
context.get_secret("CLICKHOUSE_DATABASE_CONNECTION_STRING")
) as client:
return await _get_table_schema(client, "default", table_name)
@tool(
requires_secrets=["CLICKHOUSE_DATABASE_CONNECTION_STRING"],
metadata=ToolMetadata(
behavior=Behavior(
operations=[Operation.READ],
read_only=True,
destructive=False,
idempotent=True,
open_world=True,
),
),
)
async def execute_select_query(
context: Context,
select_clause: Annotated[
str,
"This is the part of the SQL query that comes after the SELECT keyword wish a comma separated list of columns you wish to return. Do not include the SELECT keyword.",
],
from_clause: Annotated[
str,
"This is the part of the SQL query that comes after the FROM keyword. Do not include the FROM keyword.",
],
limit: Annotated[
int,
"The maximum number of rows to return. This is the LIMIT clause of the query. Default: 100.",
] = 100,
offset: Annotated[
int, "The number of rows to skip. This is the OFFSET clause of the query. Default: 0."
] = 0,
join_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the JOIN keyword. Do not include the JOIN keyword. If no join is needed, leave this blank.",
] = None,
where_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the WHERE keyword. Do not include the WHERE keyword. If no where clause is needed, leave this blank.",
] = None,
having_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the HAVING keyword. Do not include the HAVING keyword. If no having clause is needed, leave this blank.",
] = None,
group_by_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the GROUP BY keyword. Do not include the GROUP BY keyword. If no group by clause is needed, leave this blank.",
] = None,
order_by_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the ORDER BY keyword. Do not include the ORDER BY keyword. If no order by clause is needed, leave this blank.",
] = None,
with_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the WITH keyword when basing the query on a virtual table. If no WITH clause is needed, leave this blank.",
] = None,
) -> list[str]:
"""
You have a connection to a ClickHouse database.
Execute a SELECT query and return the results against the ClickHouse database. No other queries (INSERT, UPDATE, DELETE, etc.) are allowed.
ONLY use this tool if you have already loaded the schema of the tables you need to query. Use the <GetTableSchema> tool to load the schema if not already known.
The final query will be constructed as follows:
SELECT {select_query_part} FROM {from_clause} JOIN {join_clause} WHERE {where_clause} HAVING {having_clause} ORDER BY {order_by_clause} LIMIT {limit} OFFSET {offset}
When running queries, follow these rules which will help avoid errors:
* Never "select *" from a table. Always select the columns you need.
* Always order your results by the most important columns first. If you aren't sure, order by the primary key.
* Always use case-insensitive queries to match strings in the query.
* Always trim strings in the query.
* Prefer LIKE queries over direct string matches or regex queries.
* Only join on columns that are indexed or the primary key. Do not join on arbitrary columns.
* ClickHouse is case-sensitive, so be careful with table and column names.
"""
async with await DatabaseEngine.get_engine(
context.get_secret("CLICKHOUSE_DATABASE_CONNECTION_STRING")
) as client:
try:
return await _execute_query(
client,
select_clause=select_clause,
from_clause=from_clause,
limit=limit,
offset=offset,
join_clause=join_clause,
where_clause=where_clause,
having_clause=having_clause,
group_by_clause=group_by_clause,
order_by_clause=order_by_clause,
with_clause=with_clause,
)
except Exception as e:
raise RetryableToolError(
f"Query failed: {e}",
developer_message=f"Query failed with parameters: select_clause={select_clause}, from_clause={from_clause}, limit={limit}, offset={offset}, join_clause={join_clause}, where_clause={where_clause}, having_clause={having_clause}, order_by_clause={order_by_clause}, with_clause={with_clause}.",
additional_prompt_content="Load the database schema <GetTableSchema> or use the <DiscoverTables> tool to discover the tables and try again.",
retry_after_ms=10,
) from e
async def _get_databases(client: Any) -> list[str]:
"""Get all the databases in ClickHouse"""
# ClickHouse uses SHOW DATABASES instead of information_schema
result = client.query("SHOW DATABASES")
databases = [row[0] for row in result.result_rows]
# Filter out system databases
system_databases = {
"system",
"information_schema",
"INFORMATION_SCHEMA",
"default",
"temporary_tables",
"temporary_tables_metadata",
}
databases = [db for db in databases if db not in system_databases]
databases.sort()
return databases
async def _get_tables(client: Any, database_name: str) -> list[str]:
"""Get all the tables in the specified ClickHouse database"""
# ClickHouse uses SHOW TABLES FROM database_name
result = client.query(f"SHOW TABLES FROM {database_name}")
tables = [row[0] for row in result.result_rows]
tables.sort()
return tables
async def _get_table_schema(client: Any, database_name: str, table_name: str) -> list[str]:
"""Get the schema of a ClickHouse table"""
# ClickHouse uses DESCRIBE TABLE database_name.table_name
result = client.query(f"DESCRIBE TABLE {database_name}.{table_name}")
columns = result.result_rows
# Get primary key information
# ClickHouse doesn't have traditional primary keys like PostgreSQL
# Instead, it has sorting keys and primary keys that are part of the table engine
try:
pk_result = client.query(f"SHOW CREATE TABLE {database_name}.{table_name}")
if pk_result.result_rows:
create_statement = pk_result.result_rows[0][0]
# Parse the CREATE statement to extract primary key information
primary_keys = _extract_primary_keys_from_create_statement(create_statement)
else:
primary_keys = set()
except Exception:
primary_keys = set()
results = []
for column in columns:
column_name = column[
0
] # ClickHouse DESCRIBE returns: name, type, default_type, default_expression, comment, codec_expression, ttl_expression
column_type = column[1]
# Build column description
description = f"{column_name}: {column_type}"
# Add primary key indicator
if column_name in primary_keys:
description += " (PRIMARY KEY)"
# Add default value if present
if len(column) > 3 and column[3]: # default_expression
description += f" DEFAULT {column[3]}"
# Add comment if present
if len(column) > 4 and column[4]: # comment
description += f" COMMENT '{column[4]}'"
results.append(description)
return results[:MAX_ROWS_RETURNED]
def _extract_primary_keys_from_create_statement(create_statement: str) -> set[str]:
"""Extract primary key columns from ClickHouse CREATE TABLE statement"""
primary_keys = set()
# Look for PRIMARY KEY clause
import re
pk_match = re.search(r"PRIMARY KEY\s*\(([^)]+)\)", create_statement, re.IGNORECASE)
if pk_match:
pk_columns = pk_match.group(1).split(",")
for col in pk_columns:
primary_keys.add(col.strip().strip("`"))
# Look for ORDER BY clause (which can also indicate primary key)
order_match = re.search(r"ORDER BY\s*\(([^)]+)\)", create_statement, re.IGNORECASE)
if order_match:
order_columns = order_match.group(1).split(",")
for col in order_columns:
primary_keys.add(col.strip().strip("`"))
return primary_keys
async def _execute_query(
client: Any,
select_clause: str,
from_clause: str,
limit: int,
offset: int,
join_clause: str | None,
where_clause: str | None,
having_clause: str | None,
group_by_clause: str | None,
order_by_clause: str | None,
with_clause: str | None,
) -> list[str]:
"""Execute a query and return the results."""
query, parameters = DatabaseEngine.sanitize_query(
select_clause=select_clause,
from_clause=from_clause,
limit=limit,
offset=offset,
join_clause=join_clause,
where_clause=where_clause,
having_clause=having_clause,
group_by_clause=group_by_clause,
order_by_clause=order_by_clause,
with_clause=with_clause,
)
print(f"Query: {query}")
print(f"Parameters: {parameters}")
# For clickhouse-connect, we need to substitute parameters manually
# since it doesn't use SQLAlchemy-style parameter binding
formatted_query = query
for param_name, param_value in parameters.items():
formatted_query = formatted_query.replace(f":{param_name}", str(param_value))
result = client.query(formatted_query)
rows = result.result_rows
results = [str(row) for row in rows]
return results[:MAX_ROWS_RETURNED]

View file

@ -1,66 +0,0 @@
[build-system]
requires = [ "hatchling",]
build-backend = "hatchling.build"
[project]
name = "arcade_clickhouse"
version = "0.3.0"
description = "Tools to query and explore a ClickHouse database"
requires-python = ">=3.10"
dependencies = [
"arcade-mcp-server>=1.17.0,<2.0.0",
"clickhouse-connect>=0.7.0",
"pydantic>=2.11.7",
"sqlalchemy>=2.0.41",
"clickhouse-sqlalchemy>=0.2.0",
"greenlet>=3.2.3",
"aiochsa>=0.1.0",
"setuptools>=80.9.0",
]
[[project.authors]]
name = "evantahler"
email = "support@arcade.dev"
[project.optional-dependencies]
dev = [
"arcade-mcp[all]>=1.2.0,<2.0.0",
"pytest>=8.3.0,<8.4.0",
"pytest-cov>=4.0.0,<4.1.0",
"pytest-mock>=3.11.1,<3.12.0",
"pytest-asyncio>=0.24.0,<0.25.0",
"mypy>=1.5.1,<1.6.0",
"pre-commit>=3.4.0,<3.5.0",
"tox>=4.11.1,<4.12.0",
"ruff>=0.7.4,<0.8.0",
]
[project.scripts]
arcade-clickhouse = "arcade_clickhouse.__main__:main"
arcade_clickhouse = "arcade_clickhouse.__main__:main"
# Use local path sources for arcade libs when working locally
[tool.uv.sources]
arcade-mcp = { path = "../../", editable = true }
arcade-mcp-server = { path = "../../libs/arcade-mcp-server/", editable = true }
[tool.mypy]
files = [ "arcade_clickhouse/**/*.py",]
python_version = "3.10"
disallow_untyped_defs = "True"
disallow_any_unimported = "True"
no_implicit_optional = "True"
check_untyped_defs = "True"
warn_return_any = "True"
warn_unused_ignores = "True"
show_error_codes = "True"
ignore_missing_imports = "True"
[tool.pytest.ini_options]
testpaths = [ "tests",]
asyncio_default_fixture_loop_scope = "function"
[tool.coverage.report]
skip_empty = true
[tool.hatch.build.targets.wheel]
packages = [ "arcade_clickhouse",]

View file

@ -1,369 +0,0 @@
-- ClickHouse test database setup
-- This file contains sample data for testing the ClickHouse toolkit
-- Create users table
CREATE TABLE IF NOT EXISTS default.users (
id UInt32,
name String,
email String,
password_hash String,
created_at DateTime,
updated_at DateTime,
status String
) ENGINE = MergeTree()
ORDER BY (id, created_at);
-- Create messages table
CREATE TABLE IF NOT EXISTS default.messages (
id UInt32,
body String,
user_id UInt32,
created_at DateTime,
updated_at DateTime
) ENGINE = MergeTree()
ORDER BY (id, created_at);
-- Insert sample data into users table
INSERT INTO default.users (
id,
name,
email,
password_hash,
created_at,
updated_at,
status
)
VALUES (
1,
'Alice',
'alice@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$tMg1Rd3IEDnp3iFKrqsF4Dsbw6/Cbf6seRB/H5bhaPg$zZj5yn4x3D3O3mDHcW2aczQNiYfAs3cw21XMEIgkF0E',
'2024-09-01 20:49:38',
'2024-09-02 03:49:39',
'active'
),
(
2,
'Bob',
'bob@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$CvOMK1WUd99R7kYXpiBPNYw4OQP53pYIgeMnwz92mrE$HPthId4phMoPT1TWuCRHHCr9BSQA8XoUkQuB1HZsqTY',
'2024-09-02 17:49:23',
'2024-09-02 17:49:23',
'active'
),
(
3,
'Charlie',
'charlie@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$paCAAD1HVZkncP/WvecuUO6zFXp2/8BISpgr5rXRxps$M5kBFc9JHHGNw9SXnPu2ggpJY0mFFCska7TXMrllndo',
'2024-09-03 10:30:15',
'2024-09-03 10:30:15',
'active'
),
(
4,
'Diana',
'diana@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$xyz123ABC456DEF789GHI$SampleHashForDiana123',
'2024-09-04 14:20:30',
'2024-09-04 14:20:30',
'active'
),
(
5,
'Evan',
'evan@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$evanHash123$EvanPasswordHash456',
'2024-09-05 09:15:45',
'2024-09-05 09:15:45',
'active'
),
(
6,
'Fiona',
'fiona@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$fionaHash456$FionaPasswordHash789',
'2024-09-06 16:45:12',
'2024-09-06 16:45:12',
'active'
),
(
7,
'George',
'george@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$georgeHash789$GeorgePasswordHash012',
'2024-09-07 11:30:25',
'2024-09-07 11:30:25',
'active'
),
(
8,
'Helen',
'helen@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$helenHash012$HelenPasswordHash345',
'2024-09-08 13:25:40',
'2024-09-08 13:25:40',
'active'
),
(
9,
'Ian',
'ian@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$ianHash345$IanPasswordHash678',
'2024-09-09 08:40:55',
'2024-09-09 08:40:55',
'active'
),
(
10,
'Julia',
'julia@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$juliaHash678$JuliaPasswordHash901',
'2024-09-10 15:55:18',
'2024-09-10 15:55:18',
'active'
);
-- Insert sample data into messages table
INSERT INTO default.messages (id, body, user_id, created_at, updated_at)
VALUES (
1,
'Hello everyone!',
1,
'2025-01-10 10:00:00',
'2025-01-10 10:00:00'
),
(
2,
'How is everyone doing today?',
1,
'2025-01-10 11:30:00',
'2025-01-10 11:30:00'
),
(
3,
'Great to see you all here!',
1,
'2025-01-10 14:15:00',
'2025-01-10 14:15:00'
),
(
4,
'Hi Alice! Doing well, thanks for asking.',
2,
'2025-01-10 11:35:00',
'2025-01-10 11:35:00'
),
(
5,
'Anyone up for a game later?',
2,
'2025-01-10 16:20:00',
'2025-01-10 16:20:00'
),
(
6,
'Count me in for the game!',
3,
'2025-01-10 16:25:00',
'2025-01-10 16:25:00'
),
(
7,
'What time works for everyone?',
3,
'2025-01-10 16:30:00',
'2025-01-10 16:30:00'
),
(
8,
'I can play around 8 PM',
3,
'2025-01-10 17:00:00',
'2025-01-10 17:00:00'
),
(
9,
'8 PM works for me too!',
4,
'2025-01-10 17:05:00',
'2025-01-10 17:05:00'
),
(
10,
'What game should we play?',
4,
'2025-01-10 17:10:00',
'2025-01-10 17:10:00'
),
(
11,
'I suggest we try the new arcade game!',
5,
'2025-01-10 17:15:00',
'2025-01-10 17:15:00'
),
(
12,
'It has great multiplayer features',
5,
'2025-01-10 17:20:00',
'2025-01-10 17:20:00'
),
(
13,
'Perfect timing for a weekend session',
5,
'2025-01-10 18:00:00',
'2025-01-10 18:00:00'
),
(
26,
'Just finished setting up the game server!',
5,
'2025-01-10 20:00:00',
'2025-01-10 20:00:00'
),
(
27,
'Everyone should be able to connect now',
5,
'2025-01-10 20:05:00',
'2025-01-10 20:05:00'
),
(
28,
'I added some custom maps too',
5,
'2025-01-10 20:10:00',
'2025-01-10 20:10:00'
),
(
29,
'The graphics look amazing on this new version',
5,
'2025-01-10 20:15:00',
'2025-01-10 20:15:00'
),
(
30,
'Hope you all enjoy the new features',
5,
'2025-01-10 20:20:00',
'2025-01-10 20:20:00'
),
(
31,
'I also set up a leaderboard system',
5,
'2025-01-10 20:25:00',
'2025-01-10 20:25:00'
),
(
32,
'We can track high scores now',
5,
'2025-01-10 20:30:00',
'2025-01-10 20:30:00'
),
(
33,
'The game supports up to 8 players simultaneously',
5,
'2025-01-10 20:35:00',
'2025-01-10 20:35:00'
),
(
34,
'I tested it earlier and it runs smoothly',
5,
'2025-01-10 20:40:00',
'2025-01-10 20:40:00'
),
(
35,
'Cannot wait to see everyone online tonight!',
5,
'2025-01-10 20:45:00',
'2025-01-10 20:45:00'
),
(
14,
'Sounds like fun! I love arcade games.',
6,
'2025-01-10 18:05:00',
'2025-01-10 18:05:00'
),
(
15,
'Should I bring snacks?',
6,
'2025-01-10 18:10:00',
'2025-01-10 18:10:00'
),
(
16,
'Snacks are always welcome!',
7,
'2025-01-10 18:15:00',
'2025-01-10 18:15:00'
),
(
17,
'I can bring some drinks',
7,
'2025-01-10 18:20:00',
'2025-01-10 18:20:00'
),
(
18,
'This is going to be awesome',
7,
'2025-01-10 19:00:00',
'2025-01-10 19:00:00'
),
(
19,
'I agree! Cannot wait for the game night.',
8,
'2025-01-10 19:05:00',
'2025-01-10 19:05:00'
),
(
20,
'Should we set up a Discord call?',
8,
'2025-01-10 19:10:00',
'2025-01-10 19:10:00'
),
(
21,
'Discord would be perfect for voice chat',
9,
'2025-01-10 19:15:00',
'2025-01-10 19:15:00'
),
(
22,
'I will create a server for us',
9,
'2025-01-10 19:20:00',
'2025-01-10 19:20:00'
),
(
23,
'Link will be shared in a few minutes',
9,
'2025-01-10 19:25:00',
'2025-01-10 19:25:00'
),
(
24,
'Thanks Ian! You are the best.',
10,
'2025-01-10 19:30:00',
'2025-01-10 19:30:00'
),
(
25,
'See you all at 8 PM!',
10,
'2025-01-10 19:35:00',
'2025-01-10 19:35:00'
);

View file

@ -1,195 +0,0 @@
import os
from os import environ
from unittest.mock import MagicMock
import pytest
import pytest_asyncio
from arcade_clickhouse.tools.clickhouse import (
DatabaseEngine,
discover_schemas,
discover_tables,
execute_select_query,
get_table_schema,
)
from arcade_mcp_server import Context
from arcade_mcp_server.exceptions import RetryableToolError
CLICKHOUSE_DATABASE_CONNECTION_STRING = (
environ.get("TEST_CLICKHOUSE_DATABASE_CONNECTION_STRING")
or "clickhouse+native://localhost:9000/default"
)
@pytest.fixture
def mock_context():
context = MagicMock(spec=Context)
context.get_secret = MagicMock(return_value=CLICKHOUSE_DATABASE_CONNECTION_STRING)
return context
# before the tests, restore the database from the dump
@pytest_asyncio.fixture(autouse=True)
async def restore_database():
import clickhouse_connect
# Create client for database setup
client = clickhouse_connect.get_client(host="localhost", port=8123)
# Clear existing tables first to avoid duplicates
client.command("DROP TABLE IF EXISTS default.messages")
client.command("DROP TABLE IF EXISTS default.users")
# Read and execute the dump file
with open(f"{os.path.dirname(__file__)}/dump.sql") as f:
queries = f.read().split(";")
for query in queries:
if query.strip():
client.command(query)
client.close()
@pytest_asyncio.fixture(autouse=True)
async def cleanup_engines():
"""Clean up database engines after each test to prevent connection leaks."""
yield
# Clean up all cached engines after each test
await DatabaseEngine.cleanup()
@pytest.mark.asyncio
async def test_discover_schemas(mock_context) -> None:
assert await discover_schemas(mock_context) == ["default"]
@pytest.mark.asyncio
async def test_discover_tables(mock_context) -> None:
tables = await discover_tables(mock_context)
assert sorted(tables) == ["messages", "users"]
@pytest.mark.asyncio
async def test_get_table_schema(mock_context) -> None:
users_schema = await get_table_schema(mock_context, "default", "users")
expected_users = [
"id: UInt32 (PRIMARY KEY)",
"name: String",
"email: String",
"password_hash: String",
"created_at: DateTime (PRIMARY KEY)",
"updated_at: DateTime",
"status: String",
]
assert users_schema == expected_users
messages_schema = await get_table_schema(mock_context, "default", "messages")
expected_messages = [
"id: UInt32 (PRIMARY KEY)",
"body: String",
"user_id: UInt32",
"created_at: DateTime (PRIMARY KEY)",
"updated_at: DateTime",
]
assert messages_schema == expected_messages
@pytest.mark.asyncio
async def test_execute_select_query(mock_context) -> None:
# Test specific user query with limit
result1 = await execute_select_query(
mock_context,
select_clause="id, name, email",
from_clause="users",
where_clause="id = 1",
limit=1,
)
assert result1 == ["(1, 'Alice', 'alice@example.com')"]
# Test query with offset
result2 = await execute_select_query(
mock_context,
select_clause="id, name, email",
from_clause="users",
order_by_clause="id",
limit=1,
offset=1,
)
assert result2 == ["(2, 'Bob', 'bob@example.com')"]
@pytest.mark.asyncio
async def test_execute_select_query_with_keywords(mock_context) -> None:
result = await execute_select_query(
mock_context,
select_clause="SELECT id, name, email",
from_clause="FROM users",
limit=1,
)
assert result == ["(1, 'Alice', 'alice@example.com')"]
@pytest.mark.asyncio
async def test_execute_select_query_with_join(mock_context) -> None:
result = await execute_select_query(
mock_context,
select_clause="u.id, u.name, u.email, m.id, m.body",
from_clause="users u",
join_clause="messages m ON u.id = m.user_id",
limit=1,
)
assert result == ["(1, 'Alice', 'alice@example.com', 1, 'Hello everyone!')"]
@pytest.mark.asyncio
async def test_execute_select_query_with_group_by(mock_context) -> None:
result = await execute_select_query(
mock_context,
select_clause="u.name, COUNT(m.id) AS message_count",
from_clause="messages m",
join_clause="users u ON m.user_id = u.id",
group_by_clause="u.name",
order_by_clause="message_count DESC",
limit=2,
)
assert result == [
"('Evan', 13)",
"('Alice', 3)",
]
@pytest.mark.asyncio
async def test_execute_select_query_with_no_results(mock_context) -> None:
# does not raise an error
assert (
await execute_select_query(
mock_context,
select_clause="id, name, email",
from_clause="users",
where_clause="id = 9999999999",
)
== []
)
@pytest.mark.asyncio
async def test_execute_select_query_with_problem(mock_context) -> None:
# 'foo' is not a valid id
with pytest.raises(RetryableToolError) as e:
await execute_select_query(
mock_context,
select_clause="*",
from_clause="users",
where_clause="id = 'foo'",
)
assert "Do not use * in the select clause" in str(e.value)
@pytest.mark.asyncio
async def test_execute_select_query_rejects_non_select(mock_context) -> None:
with pytest.raises(RetryableToolError) as e:
await execute_select_query(
mock_context,
select_clause="INSERT INTO users (name, email, password_hash) VALUES ('Luigi', 'luigi@example.com', 'password')",
from_clause="users",
)
assert "Only SELECT queries are allowed" in str(e.value)

View file

@ -1,3 +0,0 @@
#!/bin/bash
docker run -d --name some-clickhouse-server --ulimit nofile=262144:262144 -p 8123:8123 -p 8443:8443 -p 9000:9000 yandex/clickhouse-server

View file

@ -1,18 +0,0 @@
files: ^.*/linkedin/.*
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: "v4.4.0"
hooks:
- id: check-case-conflict
- id: check-merge-conflict
- id: check-toml
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.7
hooks:
- id: ruff
args: [--fix]
- id: ruff-format

View file

@ -1,46 +0,0 @@
target-version = "py310"
line-length = 100
fix = true
[lint]
select = [
# flake8-2020
"YTT",
# flake8-bandit
"S",
# flake8-bugbear
"B",
# flake8-builtins
"A",
# flake8-comprehensions
"C4",
# flake8-debugger
"T10",
# flake8-simplify
"SIM",
# isort
"I",
# mccabe
"C90",
# pycodestyle
"E", "W",
# pyflakes
"F",
# pygrep-hooks
"PGH",
# pyupgrade
"UP",
# ruff
"RUF",
# tryceratops
"TRY",
]
[lint.per-file-ignores]
"*" = ["TRY003", "B904"]
"**/tests/*" = ["S101", "E501"]
"**/evals/*" = ["S101", "E501"]
[format]
preview = true
skip-magic-trailing-comma = false

View file

@ -1,21 +0,0 @@
MIT License
Copyright (c) 2025, Arcade AI
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -1,55 +0,0 @@
.PHONY: help
help:
@echo "🛠️ github Commands:\n"
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
.PHONY: install
install: ## Install the uv environment and install all packages with dependencies
@echo "🚀 Creating virtual environment and installing all packages using uv"
@uv sync --active --all-extras --no-sources
@if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi
@echo "✅ All packages and dependencies installed via uv"
.PHONY: install-local
install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources
@echo "🚀 Creating virtual environment and installing all packages using uv"
@uv sync --active --all-extras
@if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi
@echo "✅ All packages and dependencies installed via uv"
.PHONY: build
build: clean-build ## Build wheel file using poetry
@echo "🚀 Creating wheel file"
uv build
.PHONY: clean-build
clean-build: ## clean build artifacts
@echo "🗑️ Cleaning dist directory"
rm -rf dist
.PHONY: test
test: ## Test the code with pytest
@echo "🚀 Testing code: Running pytest"
@uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml
.PHONY: coverage
coverage: ## Generate coverage report
@echo "coverage report"
@uv run --no-sources coverage report
@echo "Generating coverage report"
@uv run --no-sources coverage html
.PHONY: bump-version
bump-version: ## Bump the version in the pyproject.toml file by a patch version
@echo "🚀 Bumping version in pyproject.toml"
uv version --no-sources --bump patch
.PHONY: check
check: ## Run code quality tools.
@if [ -f .pre-commit-config.yaml ]; then\
echo "🚀 Linting code: Running pre-commit";\
uv run --no-sources pre-commit run -a;\
fi
@echo "🚀 Static type checking: Running mypy"
@uv run --no-sources mypy --config-file=pyproject.toml

View file

@ -1,28 +0,0 @@
import sys
from typing import cast
from arcade_mcp_server import MCPApp
from arcade_mcp_server.mcp_app import TransportType
import arcade_linkedin
app = MCPApp(
name="LinkedIn",
instructions=(
"Use this server when you need to interact with LinkedIn to help users "
"create and share posts on their LinkedIn profile."
),
)
app.add_tools_from_module(arcade_linkedin)
def main() -> None:
transport = sys.argv[1] if len(sys.argv) > 1 else "stdio"
host = sys.argv[2] if len(sys.argv) > 2 else "127.0.0.1"
port = int(sys.argv[3]) if len(sys.argv) > 3 else 8000
app.run(transport=cast(TransportType, transport), host=host, port=port)
if __name__ == "__main__":
main()

View file

@ -1 +0,0 @@
LINKEDIN_BASE_URL = "https://api.linkedin.com/v2"

View file

@ -1,76 +0,0 @@
from typing import Annotated
from arcade_mcp_server import Context, tool
from arcade_mcp_server.auth import LinkedIn
from arcade_mcp_server.exceptions import ToolExecutionError
from arcade_mcp_server.metadata import (
Behavior,
Classification,
Operation,
ServiceDomain,
ToolMetadata,
)
from arcade_linkedin.tools.utils import _handle_linkedin_api_error, _send_linkedin_request
@tool(
requires_auth=LinkedIn(
scopes=["w_member_social"],
),
metadata=ToolMetadata(
classification=Classification(
service_domains=[ServiceDomain.SOCIAL_MEDIA],
),
behavior=Behavior(
operations=[Operation.CREATE],
read_only=False,
destructive=False,
idempotent=False,
open_world=True,
),
),
)
async def create_text_post(
context: Context,
text: Annotated[str, "The text content of the post"],
) -> Annotated[str, "URL of the shared post"]:
"""Share a new text post to LinkedIn."""
endpoint = "/ugcPosts"
# The LinkedIn user ID is required to create a post, even though we're using
# the user's access token.
# Arcade Engine gets the current user's info from LinkedIn and automatically
# populates context.authorization.user_info.
# LinkedIn calls the user ID "sub" in their user_info data payload. See:
# https://learn.microsoft.com/en-us/linkedin/consumer/integrations/self-serve/sign-in-with-linkedin-v2#api-request-to-retreive-member-details
user_id = context.authorization.user_info.get("sub") if context.authorization else None
if not user_id:
raise ToolExecutionError(
"User ID not found.",
developer_message="User ID not found in `context.authorization.user_info.sub`",
)
author_id = f"urn:li:person:{user_id}"
payload = {
"author": author_id,
"lifecycleState": "PUBLISHED",
"specificContent": {
"com.linkedin.ugc.ShareContent": {
"shareCommentary": {"text": text},
"shareMediaCategory": "NONE",
}
},
"visibility": {"com.linkedin.ugc.MemberNetworkVisibility": "PUBLIC"},
}
response = await _send_linkedin_request(context, "POST", endpoint, json_data=payload)
if response.status_code >= 200 and response.status_code < 300:
share_id = response.json().get("id")
return f"https://www.linkedin.com/feed/update/{share_id}/"
_handle_linkedin_api_error(response)
return ""

View file

@ -1,68 +0,0 @@
import httpx
from arcade_mcp_server import Context
from arcade_mcp_server.exceptions import ToolExecutionError
from arcade_linkedin.tools.constants import LINKEDIN_BASE_URL
async def _send_linkedin_request(
context: Context,
method: str,
endpoint: str,
params: dict | None = None,
json_data: dict | None = None,
) -> httpx.Response:
"""
Send an asynchronous request to the LinkedIn API.
Args:
context: The tool context containing the authorization token.
method: The HTTP method (GET, POST, PUT, DELETE, etc.).
endpoint: The API endpoint path (e.g., "/ugcPosts").
params: Query parameters to include in the request.
json_data: JSON data to include in the request body.
Returns:
The response object from the API request.
Raises:
ToolExecutionError: If the request fails for any reason.
"""
url = f"{LINKEDIN_BASE_URL}{endpoint}"
token = (
context.authorization.token if context.authorization and context.authorization.token else ""
)
headers = {"Authorization": f"Bearer {token}"}
async with httpx.AsyncClient() as client:
try:
response = await client.request(
method, url, headers=headers, params=params, json=json_data
)
response.raise_for_status()
except httpx.RequestError as e:
raise ToolExecutionError(f"Failed to send request to LinkedIn API: {e}")
return response
def _handle_linkedin_api_error(response: httpx.Response) -> None:
"""
Handle errors from the LinkedIn API by mapping common status codes to ToolExecutionErrors.
Args:
response: The response object from the API request.
Raises:
ToolExecutionError: If the response contains an error status code.
"""
status_code_map = {
401: ToolExecutionError("Unauthorized: Invalid or expired token"),
403: ToolExecutionError("Forbidden: User does not have Spotify Premium"),
429: ToolExecutionError("Too Many Requests: Rate limit exceeded"),
}
if response.status_code in status_code_map:
raise status_code_map[response.status_code]
elif response.status_code >= 400:
raise ToolExecutionError(f"Error: {response.status_code} - {response.text}")

View file

@ -1,24 +0,0 @@
from unittest.mock import MagicMock
import pytest
from arcade_mcp_server import Context
@pytest.fixture
def tool_context():
"""Fixture for the tool Context with mock authorization."""
context = MagicMock(spec=Context)
authorization = MagicMock()
authorization.token = "test_token" # noqa: S105
authorization.user_info = {"sub": "test_user"}
context.authorization = authorization
return context
@pytest.fixture
def mock_httpx_client(mocker):
"""Fixture to mock the httpx.AsyncClient."""
# Mock the AsyncClient context manager
mock_client = mocker.patch("httpx.AsyncClient", autospec=True)
async_mock_client = mock_client.return_value.__aenter__.return_value
return async_mock_client

View file

@ -1,48 +0,0 @@
from arcade_core import ToolCatalog
from arcade_evals import (
EvalRubric,
EvalSuite,
ExpectedToolCall,
SimilarityCritic,
tool_eval,
)
import arcade_linkedin
from arcade_linkedin.tools.share import create_text_post
rubric = EvalRubric(
fail_threshold=0.85,
warn_threshold=0.95,
)
catalog = ToolCatalog()
catalog.add_module(arcade_linkedin)
@tool_eval()
def linkedin_eval_suite() -> EvalSuite:
suite = EvalSuite(
name="LinkedIn Tools Evaluation",
system_message="You are an AI assistant with access to LinkedIn tools. Use them to help the user with their tasks.",
catalog=catalog,
rubric=rubric,
)
suite.add_case(
name="Run code",
user_message="post this transcription to linkedin. there may be some things that you need to clean up since it was spoken.: 'It is with great pleasure that I announce that I am now a member of the LinkedIn community! I'd like to thank the LinkedIn team for their support and encouragement in my journey to success. hash tag Y2K'",
expected_tool_calls=[
ExpectedToolCall(
func=create_text_post,
args={
"text": "It is with great pleasure that I announce that I am now a member of the LinkedIn community! I'd like to thank the LinkedIn team for their support and encouragement in my journey to success. #Y2K",
},
)
],
critics=[
SimilarityCritic(critic_field="text", weight=1.0),
],
)
return suite

View file

@ -1,59 +0,0 @@
[build-system]
requires = [ "hatchling",]
build-backend = "hatchling.build"
[project]
name = "arcade_linkedin"
version = "0.3.0"
description = "Arcade.dev LLM tools for LinkedIn"
requires-python = ">=3.10"
dependencies = [
"arcade-mcp-server>=1.17.0,<2.0.0",
"httpx>=0.27.2,<1.0.0",
]
[[project.authors]]
name = "Arcade"
email = "dev@arcade.dev"
[project.scripts]
arcade-linkedin = "arcade_linkedin.__main__:main"
arcade_linkedin = "arcade_linkedin.__main__:main"
[project.optional-dependencies]
dev = [
"arcade-mcp[all]>=1.2.0,<2.0.0",
"pytest>=8.3.0,<8.4.0",
"pytest-cov>=4.0.0,<4.1.0",
"pytest-asyncio>=0.24.0,<0.25.0",
"pytest-mock>=3.11.1,<3.12.0",
"mypy>=1.5.1,<1.6.0",
"pre-commit>=3.4.0,<3.5.0",
"tox>=4.11.1,<4.12.0",
"ruff>=0.7.4,<0.8.0",
]
# Use local path sources for arcade libs when working locally
[tool.uv.sources]
arcade-mcp = {path = "../../", editable = true}
arcade-mcp-server = { path = "../../libs/arcade-mcp-server/", editable = true }
[tool.mypy]
files = [ "arcade_linkedin/**/*.py",]
python_version = "3.10"
disallow_untyped_defs = "True"
disallow_any_unimported = "True"
no_implicit_optional = "True"
check_untyped_defs = "True"
warn_return_any = "True"
warn_unused_ignores = "True"
show_error_codes = "True"
ignore_missing_imports = "True"
[tool.pytest.ini_options]
testpaths = [ "tests",]
[tool.coverage.report]
skip_empty = true
[tool.hatch.build.targets.wheel]
packages = [ "arcade_linkedin",]

View file

@ -1,35 +0,0 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
from arcade_mcp_server.exceptions import ToolExecutionError
from arcade_linkedin.tools.share import create_text_post
@pytest.mark.asyncio
async def test_create_text_post_success(tool_context, mock_httpx_client):
"""Test successful creation of a LinkedIn text post."""
# Mock response for a successful post creation
mock_response = MagicMock()
mock_response.status_code = 201
mock_response.json.return_value = {"id": "1234567890"}
# Ensure the mock is awaited properly
mock_httpx_client.request = AsyncMock(return_value=mock_response)
post_text = "Hello, LinkedIn!"
result = await create_text_post(tool_context, post_text)
expected_url = "https://www.linkedin.com/feed/update/1234567890/"
assert result == expected_url
mock_httpx_client.request.assert_called_once()
@pytest.mark.asyncio
async def test_create_text_post_no_user_id(tool_context):
"""Test error when user ID is not found in the context."""
# Simulate missing user ID in the context
tool_context.authorization.user_info = {}
post_text = "Hello, LinkedIn!"
with pytest.raises(ToolExecutionError, match="User ID not found"):
await create_text_post(tool_context, post_text)

View file

@ -1,18 +0,0 @@
files: ^.*/math/.*
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: "v4.4.0"
hooks:
- id: check-case-conflict
- id: check-merge-conflict
- id: check-toml
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.7
hooks:
- id: ruff
args: [--fix]
- id: ruff-format

View file

@ -1,47 +0,0 @@
target-version = "py310"
line-length = 100
fix = true
[lint]
select = [
# flake8-2020
"YTT",
# flake8-bandit
"S",
# flake8-bugbear
"B",
# flake8-builtins
"A",
# flake8-comprehensions
"C4",
# flake8-debugger
"T10",
# flake8-simplify
"SIM",
# isort
"I",
# mccabe
"C90",
# pycodestyle
"E", "W",
# pyflakes
"F",
# pygrep-hooks
"PGH",
# pyupgrade
"UP",
# ruff
"RUF",
# tryceratops
"TRY",
]
[lint.per-file-ignores]
"*" = ["TRY003", "B904"]
"**/tests/*" = ["S101", "E501"]
"**/evals/*" = ["S101", "E501"]
[format]
preview = true
skip-magic-trailing-comma = false

View file

@ -1,21 +0,0 @@
MIT License
Copyright (c) 2025, Arcade AI
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -1,55 +0,0 @@
.PHONY: help
help:
@echo "🛠️ github Commands:\n"
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
.PHONY: install
install: ## Install the uv environment and install all packages with dependencies
@echo "🚀 Creating virtual environment and installing all packages using uv"
@uv sync --active --all-extras --no-sources
@if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi
@echo "✅ All packages and dependencies installed via uv"
.PHONY: install-local
install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources
@echo "🚀 Creating virtual environment and installing all packages using uv"
@uv sync --active --all-extras
@if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi
@echo "✅ All packages and dependencies installed via uv"
.PHONY: build
build: clean-build ## Build wheel file using poetry
@echo "🚀 Creating wheel file"
uv build
.PHONY: clean-build
clean-build: ## clean build artifacts
@echo "🗑️ Cleaning dist directory"
rm -rf dist
.PHONY: test
test: ## Test the code with pytest
@echo "🚀 Testing code: Running pytest"
@uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml
.PHONY: coverage
coverage: ## Generate coverage report
@echo "coverage report"
@uv run --no-sources coverage report
@echo "Generating coverage report"
@uv run --no-sources coverage html
.PHONY: bump-version
bump-version: ## Bump the version in the pyproject.toml file by a patch version
@echo "🚀 Bumping version in pyproject.toml"
uv version --no-sources --bump patch
.PHONY: check
check: ## Run code quality tools.
@if [ -f .pre-commit-config.yaml ]; then\
echo "🚀 Linting code: Running pre-commit";\
uv run --no-sources pre-commit run -a;\
fi
@echo "🚀 Static type checking: Running mypy"
@uv run --no-sources mypy --config-file=pyproject.toml

View file

@ -1,29 +0,0 @@
import sys
from typing import cast
from arcade_mcp_server import MCPApp
from arcade_mcp_server.mcp_app import TransportType
import arcade_math
app = MCPApp(
name="Math",
instructions=(
"Use this server when you need to perform mathematical calculations to help users "
"with arithmetic, trigonometry, statistics, exponents, rounding, and other math operations."
),
)
app.add_tools_from_module(arcade_math)
def main() -> None:
transport = sys.argv[1] if len(sys.argv) > 1 else "stdio"
host = sys.argv[2] if len(sys.argv) > 2 else "127.0.0.1"
port = int(sys.argv[3]) if len(sys.argv) > 3 else 8000
app.run(transport=cast(TransportType, transport), host=host, port=port)
if __name__ == "__main__":
main()

View file

@ -1,65 +0,0 @@
from arcade_math.tools.arithmetic import (
add,
divide,
mod,
multiply,
subtract,
sum_list,
sum_range,
)
from arcade_math.tools.exponents import (
log,
power,
)
from arcade_math.tools.miscellaneous import (
abs_val,
factorial,
sqrt,
)
from arcade_math.tools.random import (
generate_random_float,
generate_random_int,
)
from arcade_math.tools.rational import (
gcd,
lcm,
)
from arcade_math.tools.rounding import (
ceil,
floor,
round_num,
)
from arcade_math.tools.statistics import (
avg,
median,
)
from arcade_math.tools.trigonometry import (
deg_to_rad,
rad_to_deg,
)
__all__ = [
"abs_val",
"add",
"avg",
"ceil",
"deg_to_rad",
"divide",
"factorial",
"floor",
"gcd",
"generate_random_float",
"generate_random_int",
"lcm",
"log",
"median",
"mod",
"multiply",
"power",
"rad_to_deg",
"round_num",
"sqrt",
"subtract",
"sum_list",
"sum_range",
]

View file

@ -1,161 +0,0 @@
import decimal
from decimal import Decimal
from typing import Annotated
from arcade_mcp_server import tool
from arcade_mcp_server.metadata import Behavior, ToolMetadata
decimal.getcontext().prec = 100
@tool(
metadata=ToolMetadata(
behavior=Behavior(
read_only=True,
destructive=False,
idempotent=True,
open_world=False,
),
),
)
def add(
a: Annotated[str, "The first number as a string"],
b: Annotated[str, "The second number as a string"],
) -> Annotated[str, "The sum of the two numbers as a string"]:
"""
Add two numbers together
"""
# Use Decimal for arbitrary precision
a_decimal = Decimal(a)
b_decimal = Decimal(b)
return str(a_decimal + b_decimal)
@tool(
metadata=ToolMetadata(
behavior=Behavior(
read_only=True,
destructive=False,
idempotent=True,
open_world=False,
),
),
)
def subtract(
a: Annotated[str, "The first number as a string"],
b: Annotated[str, "The second number as a string"],
) -> Annotated[str, "The difference of the two numbers as a string"]:
"""
Subtract two numbers
"""
# Use Decimal for arbitrary precision
a_decimal = Decimal(a)
b_decimal = Decimal(b)
return str(a_decimal - b_decimal)
@tool(
metadata=ToolMetadata(
behavior=Behavior(
read_only=True,
destructive=False,
idempotent=True,
open_world=False,
),
),
)
def multiply(
a: Annotated[str, "The first number as a string"],
b: Annotated[str, "The second number as a string"],
) -> Annotated[str, "The product of the two numbers as a string"]:
"""
Multiply two numbers together
"""
# Use Decimal for arbitrary precision
a_decimal = Decimal(a)
b_decimal = Decimal(b)
return str(a_decimal * b_decimal)
@tool(
metadata=ToolMetadata(
behavior=Behavior(
read_only=True,
destructive=False,
idempotent=True,
open_world=False,
),
),
)
def divide(
a: Annotated[str, "The first number as a string"],
b: Annotated[str, "The second number as a string"],
) -> Annotated[str, "The quotient of the two numbers as a string"]:
"""
Divide two numbers
"""
# Use Decimal for arbitrary precision
a_decimal = Decimal(a)
b_decimal = Decimal(b)
return str(a_decimal / b_decimal)
@tool(
metadata=ToolMetadata(
behavior=Behavior(
read_only=True,
destructive=False,
idempotent=True,
open_world=False,
),
),
)
def sum_list(
numbers: Annotated[list[str], "The list of numbers as strings"],
) -> Annotated[str, "The sum of the numbers in the list as a string"]:
"""
Sum all numbers in a list
"""
# Use Decimal for arbitrary precision
return str(sum([Decimal(n) for n in numbers]))
@tool(
metadata=ToolMetadata(
behavior=Behavior(
read_only=True,
destructive=False,
idempotent=True,
open_world=False,
),
),
)
def sum_range(
start: Annotated[str, "The start of the range to sum as a string"],
end: Annotated[str, "The end of the range to sum as a string"],
) -> Annotated[str, "The sum of the numbers in the list as a string"]:
"""
Sum all numbers from start through end
"""
return str(sum(list(range(int(start), int(end) + 1))))
@tool(
metadata=ToolMetadata(
behavior=Behavior(
read_only=True,
destructive=False,
idempotent=True,
open_world=False,
),
),
)
def mod(
a: Annotated[str, "The dividend as a string"],
b: Annotated[str, "The divisor as a string"],
) -> Annotated[str, "The remainder after dividing a by b as a string"]:
"""
Calculate the remainder (modulus) of one number divided by another
"""
# Use Decimal for arbitrary precision
return str(Decimal(a) % Decimal(b))

View file

@ -1,51 +0,0 @@
import decimal
import math
from decimal import Decimal
from typing import Annotated
from arcade_mcp_server import tool
from arcade_mcp_server.metadata import Behavior, ToolMetadata
decimal.getcontext().prec = 100
@tool(
metadata=ToolMetadata(
behavior=Behavior(
read_only=True,
destructive=False,
idempotent=True,
open_world=False,
),
),
)
def log(
a: Annotated[str, "The number to take the logarithm of as a string"],
base: Annotated[str, "The logarithmic base as a string"],
) -> Annotated[str, "The logarithm of the number with the specified base as a string"]:
"""
Calculate the logarithm of a number with a given base
"""
# Use Decimal for arbitrary precision
return str(math.log(Decimal(a), Decimal(base)))
@tool(
metadata=ToolMetadata(
behavior=Behavior(
read_only=True,
destructive=False,
idempotent=True,
open_world=False,
),
),
)
def power(
a: Annotated[str, "The base number as a string"],
b: Annotated[str, "The exponent as a string"],
) -> Annotated[str, "The result of raising a to the power of b as a string"]:
"""
Calculate one number raised to the power of another
"""
# Use Decimal for arbitrary precision
return str(Decimal(a) ** Decimal(b))

View file

@ -1,70 +0,0 @@
import decimal
import math
from decimal import Decimal
from typing import Annotated
from arcade_mcp_server import tool
from arcade_mcp_server.metadata import Behavior, ToolMetadata
decimal.getcontext().prec = 100
@tool(
metadata=ToolMetadata(
behavior=Behavior(
read_only=True,
destructive=False,
idempotent=True,
open_world=False,
),
),
)
def abs_val(
a: Annotated[str, "The number as a string"],
) -> Annotated[str, "The absolute value of the number as a string"]:
"""
Calculate the absolute value of a number
"""
# Use Decimal for arbitrary precision
return str(abs(Decimal(a)))
@tool(
metadata=ToolMetadata(
behavior=Behavior(
read_only=True,
destructive=False,
idempotent=True,
open_world=False,
),
),
)
def factorial(
a: Annotated[str, "The non-negative integer to compute the factorial for as a string"],
) -> Annotated[str, "The factorial of the number as a string"]:
"""
Compute the factorial of a non-negative integer
Returns "1" for "0"
"""
return str(math.factorial(int(a)))
@tool(
metadata=ToolMetadata(
behavior=Behavior(
read_only=True,
destructive=False,
idempotent=True,
open_world=False,
),
),
)
def sqrt(
a: Annotated[str, "The number to square root as a string"],
) -> Annotated[str, "The square root of the number as a string"]:
"""
Get the square root of a number
"""
# Use Decimal for arbitrary precision
a_decimal = Decimal(a)
return str(a_decimal.sqrt())

View file

@ -1,57 +0,0 @@
import random
from typing import Annotated
from arcade_mcp_server import tool
from arcade_mcp_server.metadata import Behavior, ToolMetadata
@tool(
metadata=ToolMetadata(
behavior=Behavior(
read_only=True,
destructive=False,
idempotent=False,
open_world=False,
),
),
)
def generate_random_int(
min_value: Annotated[str, "The minimum value of the random integer as a string"],
max_value: Annotated[str, "The maximum value of the random integer as a string"],
seed: Annotated[
str | None,
"The seed for the random number generator as a string."
" If None, the current system time is used.",
] = None,
) -> Annotated[str, "A random integer between min_value and max_value as a string"]:
"""Generate a random integer between min_value and max_value (inclusive)."""
if seed is not None:
random.seed(int(seed))
return str(random.randint(int(min_value), int(max_value))) # noqa: S311
@tool(
metadata=ToolMetadata(
behavior=Behavior(
read_only=True,
destructive=False,
idempotent=False,
open_world=False,
),
),
)
def generate_random_float(
min_value: Annotated[str, "The minimum value of the random float as a string"],
max_value: Annotated[str, "The maximum value of the random float as a string"],
seed: Annotated[
str | None,
"The seed for the random number generator as a string."
" If None, the current system time is used.",
] = None,
) -> Annotated[str, "A random float between min_value and max_value as a string"]:
"""Generate a random float between min_value and max_value."""
if seed is not None:
random.seed(int(seed))
return str(random.uniform(float(min_value), float(max_value))) # noqa: S311

View file

@ -1,49 +0,0 @@
import math
from typing import Annotated
from arcade_mcp_server import tool
from arcade_mcp_server.metadata import Behavior, ToolMetadata
@tool(
metadata=ToolMetadata(
behavior=Behavior(
read_only=True,
destructive=False,
idempotent=True,
open_world=False,
),
),
)
def gcd(
a: Annotated[str, "First integer as a string"],
b: Annotated[str, "Second integer as a string"],
) -> Annotated[str, "The greatest common divisor of a and b as a string"]:
"""
Calculate the greatest common divisor (GCD) of two integers.
"""
return str(math.gcd(int(a), int(b)))
@tool(
metadata=ToolMetadata(
behavior=Behavior(
read_only=True,
destructive=False,
idempotent=True,
open_world=False,
),
),
)
def lcm(
a: Annotated[str, "First integer as a string"],
b: Annotated[str, "Second integer as a string"],
) -> Annotated[str, "The least common multiple of a and b as a string"]:
"""
Calculate the least common multiple (LCM) of two integers.
Returns "0" if either integer is 0.
"""
a_int, b_int = int(a), int(b)
if a_int == 0 or b_int == 0:
return "0"
return str(abs(a_int * b_int) // math.gcd(a_int, b_int))

View file

@ -1,75 +0,0 @@
import decimal
import math
from decimal import Decimal
from typing import Annotated
from arcade_mcp_server import tool
from arcade_mcp_server.metadata import Behavior, ToolMetadata
decimal.getcontext().prec = 100
@tool(
metadata=ToolMetadata(
behavior=Behavior(
read_only=True,
destructive=False,
idempotent=True,
open_world=False,
),
),
)
def ceil(
a: Annotated[str, "The number to round up as a string"],
) -> Annotated[str, "The smallest integer greater than or equal to the number as a string"]:
"""
Return the ceiling of a number
"""
# Use Decimal for arbitrary precision
return str(math.ceil(Decimal(a)))
@tool(
metadata=ToolMetadata(
behavior=Behavior(
read_only=True,
destructive=False,
idempotent=True,
open_world=False,
),
),
)
def floor(
a: Annotated[str, "The number to round down as a string"],
) -> Annotated[str, "The largest integer less than or equal to the number as a string"]:
"""
Return the floor of a number
"""
# Use Decimal for arbitrary precision
return str(math.floor(Decimal(a)))
@tool(
metadata=ToolMetadata(
behavior=Behavior(
read_only=True,
destructive=False,
idempotent=True,
open_world=False,
),
),
)
def round_num(
value: Annotated[str, "The number to round as a string"],
ndigits: Annotated[str, "The number of digits after the decimal point as a string"],
) -> Annotated[str, "The number rounded to the specified number of digits as a string"]:
"""
Round a number to a specified number of positive digits
"""
ndigits_int = int(ndigits)
if ndigits_int >= 0:
# Use Decimal for arbitrary precision
return str(round(Decimal(value), int(ndigits_int)))
# cast value from str -> float -> int here because rounding with negative
# decimals is only useful for weird math
return str(round(int(float(value)), int(ndigits_int)))

View file

@ -1,53 +0,0 @@
import decimal
from decimal import Decimal
from statistics import median as stats_median
from typing import Annotated
from arcade_mcp_server import tool
from arcade_mcp_server.metadata import Behavior, ToolMetadata
decimal.getcontext().prec = 100
@tool(
metadata=ToolMetadata(
behavior=Behavior(
read_only=True,
destructive=False,
idempotent=True,
open_world=False,
),
),
)
def avg(
numbers: Annotated[list[str], "The list of numbers as strings"],
) -> Annotated[str, "The average (mean) of the numbers in the list as a string"]:
"""
Calculate the average (mean) of a list of numbers.
Returns "0.0" if the list is empty.
"""
# Use Decimal for arbitrary precision
d_numbers = [Decimal(n) for n in numbers]
return str(sum(d_numbers) / len(d_numbers)) if d_numbers else "0.0"
@tool(
metadata=ToolMetadata(
behavior=Behavior(
read_only=True,
destructive=False,
idempotent=True,
open_world=False,
),
),
)
def median(
numbers: Annotated[list[str], "A list of numbers as strings"],
) -> Annotated[str, "The median value of the numbers in the list as a string"]:
"""
Calculate the median of a list of numbers.
Returns "0.0" if the list is empty.
"""
# Use Decimal for arbitrary precision
d_numbers = [Decimal(n) for n in numbers]
return str(stats_median(d_numbers)) if d_numbers else "0.0"

View file

@ -1,49 +0,0 @@
import decimal
import math
from decimal import Decimal
from typing import Annotated
from arcade_mcp_server import tool
from arcade_mcp_server.metadata import Behavior, ToolMetadata
decimal.getcontext().prec = 100
@tool(
metadata=ToolMetadata(
behavior=Behavior(
read_only=True,
destructive=False,
idempotent=True,
open_world=False,
),
),
)
def deg_to_rad(
degrees: Annotated[str, "Angle in degrees as a string"],
) -> Annotated[str, "Angle in radians as a string"]:
"""
Convert an angle from degrees to radians.
"""
# Use Decimal for arbitrary precision
return str(math.radians(Decimal(degrees)))
@tool(
metadata=ToolMetadata(
behavior=Behavior(
read_only=True,
destructive=False,
idempotent=True,
open_world=False,
),
),
)
def rad_to_deg(
radians: Annotated[str, "Angle in radians as a string"],
) -> Annotated[str, "Angle in degrees as a string"]:
"""
Convert an angle from radians to degrees.
"""
# Use Decimal for arbitrary precision
return str(math.degrees(Decimal(radians)))

View file

@ -1,137 +0,0 @@
from collections.abc import Callable
from typing import Any
from arcade_core import ToolCatalog
from arcade_evals import (
BinaryCritic,
EvalRubric,
EvalSuite,
ExpectedToolCall,
tool_eval,
)
import arcade_math
from arcade_math.tools.arithmetic import (
add,
divide,
mod,
multiply,
subtract,
sum_list,
sum_range,
)
from arcade_math.tools.exponents import (
log,
power,
)
from arcade_math.tools.miscellaneous import (
abs_val,
factorial,
sqrt,
)
from arcade_math.tools.rational import (
gcd,
lcm,
)
from arcade_math.tools.rounding import (
ceil,
floor,
round_num,
)
from arcade_math.tools.statistics import (
avg,
median,
)
from arcade_math.tools.trigonometry import (
deg_to_rad,
rad_to_deg,
)
# Type alias for test case tuples: (function, prompt_template, params)
TestCase = tuple[Callable[..., Any], str, dict[str, Any]]
# Evaluation rubric
rubric = EvalRubric(
fail_threshold=0.85,
warn_threshold=0.95,
)
catalog = ToolCatalog()
catalog.add_module(arcade_math)
@tool_eval()
def math_eval_suite() -> EvalSuite:
suite = EvalSuite(
name="Math Tools Evaluation",
system_message="You're an AI assistant with access to math tools. Use them to help the user with their math-related tasks.",
catalog=catalog,
rubric=rubric,
)
list_param = ["1", "2", "3", "4", "5"]
funcs_to_expression_and_params: list[TestCase] = [
# unary
(sqrt, "What's the square root of {a}?", {"a": "25"}),
(abs_val, "What's the absolute value of {a}?", {"a": "-10"}),
(factorial, "What's the factorial of {a}?", {"a": "5"}),
(deg_to_rad, "Convert {degrees} from degrees to radians", {"degrees": "180"}),
(rad_to_deg, "Convert {radians} from radias to degrees", {"radians": "3.14"}),
(ceil, "Compute the ceiling of {a}", {"a": "3.14"}),
(floor, "Compute the floor of {a}", {"a": "3.14"}),
# binary
(add, "Add {a} and {b}", {"a": "12345", "b": "987654321"}),
(subtract, "Subtract {b} from {a}", {"a": "987654321", "b": "12345"}),
(multiply, "Multiply {a} and {b}", {"a": "12345", "b": "567890"}),
(divide, "What is {a} divided by {b}?", {"a": "1234123479", "b": "123"}),
(
sum_range,
"What's the sum of all numbers from {start} to {end}?",
{"start": "10", "end": "345"},
),
(mod, "What's the remainder of dividing {a} by {b}?", {"a": "234", "b": "17"}),
(power, "Raise {a} to the power of {b}", {"a": "2", "b": "8"}),
(log, "What's the logarithm of {a} with base {base}?", {"a": "8", "base": "2"}),
(
round_num,
"Round {value} to {ndigits} decimal places",
{"value": "12.23746234", "ndigits": "3"},
),
(gcd, "Find the greatest common divisor of {a} and {b}", {"a": "50", "b": "10"}),
(lcm, "FInd the least common multiple of {a} and {b}", {"a": "7", "b": "13"}),
# n-nary
(
sum_list,
f"Calculate the sum of these numbers: {' '.join(list_param)}",
{"numbers": list_param},
),
(
avg,
f"Find the average of these numbers: {' '.join(list_param)}",
{"numbers": list_param},
),
(
median,
f"Find the median of these numbers: {' '.join(list_param)}",
{"numbers": list_param},
),
]
for func, expression, params in funcs_to_expression_and_params:
parametrized_expression = expression.format(**params)
num_params = len(params)
suite.add_case(
name=parametrized_expression,
user_message=parametrized_expression,
expected_tool_calls=[
ExpectedToolCall(
func=func,
args=params,
)
],
rubric=rubric,
critics=[BinaryCritic(critic_field=param, weight=1.0 / num_params) for param in params],
)
return suite

View file

@ -1,58 +0,0 @@
[build-system]
requires = [ "hatchling",]
build-backend = "hatchling.build"
[project]
name = "arcade_math"
version = "1.2.0"
description = "Arcade.dev LLM tools for doing math"
requires-python = ">=3.10"
dependencies = [
"arcade-mcp-server>=1.17.0,<2.0.0",
]
[[project.authors]]
name = "Arcade"
email = "dev@arcade.dev"
[project.optional-dependencies]
dev = [
"arcade-mcp[all]>=1.2.0,<2.0.0",
"pytest>=8.3.0,<8.4.0",
"pytest-cov>=4.0.0,<4.1.0",
"pytest-asyncio>=0.24.0,<0.25.0",
"pytest-mock>=3.11.1,<3.12.0",
"mypy>=1.5.1,<1.6.0",
"pre-commit>=3.4.0,<3.5.0",
"tox>=4.11.1,<4.12.0",
"ruff>=0.7.4,<0.8.0",
]
[project.scripts]
arcade-math = "arcade_math.__main__:main"
arcade_math = "arcade_math.__main__:main"
# Use local path sources for arcade libs when working locally
[tool.uv.sources]
arcade-mcp = {path = "../../", editable = true}
arcade-mcp-server = { path = "../../libs/arcade-mcp-server/", editable = true }
[tool.mypy]
files = [ "arcade_math/**/*.py",]
python_version = "3.10"
disallow_untyped_defs = "True"
disallow_any_unimported = "True"
no_implicit_optional = "True"
check_untyped_defs = "True"
warn_return_any = "True"
warn_unused_ignores = "True"
show_error_codes = "True"
ignore_missing_imports = "True"
[tool.pytest.ini_options]
testpaths = [ "tests",]
[tool.coverage.report]
skip_empty = true
[tool.hatch.build.targets.wheel]
packages = [ "arcade_math",]

View file

@ -1,147 +0,0 @@
import pytest
from arcade_mcp_server.exceptions import ToolExecutionError
from arcade_math.tools.arithmetic import (
add,
divide,
mod,
multiply,
subtract,
sum_list,
sum_range,
)
@pytest.mark.parametrize(
"a, b, expected",
[
("1", "2", "3"),
("-1", "1", "0"),
("0.5", "10.9", "11.4"),
# Big ints
("12345678901234567890", "9876543210987654321", "22222222112222222211"),
# Big floats
(
"12345678901234567890.120",
"9876543210987654321.987",
"22222222112222222212.107",
),
],
)
def test_add(a, b, expected):
assert add(a, b) == expected
@pytest.mark.parametrize(
"a, b, expected",
[
("1", "2", "-1"),
("-1", "1", "-2"),
("0.5", "10.9", "-10.4"),
# Big ints
("12345678901234567890", "12323456679012345668", "22222222222222222"),
# Big floats
(
"12345678901234567890.120",
"12343557689113355768.9079",
"2121212121212121.2121",
),
],
)
def test_subtract(a, b, expected):
assert subtract(a, b) == expected
@pytest.mark.parametrize(
"a, b, expected",
[
("-1", "2", "-2"),
("-10", "0", "-0"),
("0.5", "10.9", "5.45"),
# Big ints
(
"12345678901234567890",
"18000000162000001474380013420000",
"222222222222222222222222222261233060226101083800000",
),
# Big floats
(
"12345678901234567890.120",
"12345678901234567890.120",
"152415787532388367504868162811315348393.614400",
),
],
)
def test_multiply(a, b, expected):
assert multiply(a, b) == expected
@pytest.mark.parametrize(
"a, b, expected",
[
("-1", "2", "-0.5"),
("-10", "1", "-10"),
(
"0.5",
"10.9",
"0.0458715596330275229357798165137614678899082568807339"
"4495412844036697247706422018348623853211009174312",
),
# Big ints
("152407406035740740602050", "12345678901234567890", "12345"),
# Big floats
(
"152407406035740740603531.400",
"12345678901234567890.120",
"12345",
),
],
)
def test_divide(a, b, expected):
assert divide(a, b) == expected
def text_zero_division():
with pytest.raises(ToolExecutionError):
divide("1", "0")
with pytest.raises(ToolExecutionError):
divide("1", "0.0")
with pytest.raises(ToolExecutionError):
divide("1", "0.000000")
def test_sum_list():
assert sum_list(["1", "2", "3", "4", "5", "6"]) == "21"
assert sum_list([]) == "0"
assert sum_list(["-1", "-2", "-3", "-4", "-5", "-6"]) == "-21"
assert sum_list(["0.1", "0.2", "0.3", "0.3", "0.5", "0.7"]) == "2.1"
def test_sum_range():
assert sum_range("8", "2") == "0"
assert sum_range("-8", "2") == "-33"
assert sum_range("8", "-2") == "0"
assert sum_range("2", "3") == "5"
assert sum_range("0", "10") == "55"
with pytest.raises(ToolExecutionError):
sum_range("2", "0.5")
with pytest.raises(ToolExecutionError):
sum_range("-1", "0.5")
with pytest.raises(ToolExecutionError):
sum_range("2.", "0.5")
with pytest.raises(ToolExecutionError):
sum_range("-1", "0.5")
def test_mod():
assert mod("-1", "0.5") == "-0.0"
assert mod("-8", "2") == "-0"
assert mod("0", "10") == "0"
assert mod("2", "0.5") == "0.0"
assert mod("2", "3") == "2"
assert mod("2.", "-0.5") == "0.0"
assert mod("2.1234", "0.6") == "0.3234"
assert mod("2.1234", "1") == "0.1234"
assert mod("2.1234", "3") == "2.1234"
assert mod("8", "-2") == "0"
assert mod("8", "2") == "0"

View file

@ -1,41 +0,0 @@
import pytest
from arcade_mcp_server.exceptions import ToolExecutionError
from arcade_math.tools.exponents import (
log,
power,
)
def test_log():
assert log("8", "2") == "3.0"
assert log("2", "3") == "0.6309297535714574"
assert log("2", "0.5") == "-1.0"
with pytest.raises(ToolExecutionError):
log("-1", "0.5")
with pytest.raises(ToolExecutionError):
log("0", "10")
def test_power():
assert power("-8", "2") == "64"
assert power("0", "10") == "0"
assert (
power("2", "0.5") == "1.41421356237309504880168872420969807856"
"9671875376948073176679737990732478462107038850387534327641573"
)
assert power("2", "3") == "8"
assert (
power("2.", "-0.5") == "0.707106781186547524400844362104849039"
"2848359376884740365883398689953662392310535194251937671638207864"
)
assert (
power("2.1234", "0.6") == "1.571155202490495156807227174573016145"
"282682479346448636509576776014844055570115193494685328114403375"
)
assert power("2.1234", "1") == "2.1234"
assert power("2.1234", "3") == "9.574044440904"
assert power("8", "-2") == "0.015625"
assert power("8", "2") == "64"
with pytest.raises(ToolExecutionError):
power("-1", "0.5")

View file

@ -1,81 +0,0 @@
import pytest
from arcade_mcp_server.exceptions import ToolExecutionError
from arcade_math.tools.miscellaneous import (
abs_val,
factorial,
sqrt,
)
def test_abs_val():
assert abs_val("2") == "2"
assert abs_val("-1") == "1"
assert abs_val("-1.12341234") == "1.12341234"
def test_factorial():
assert factorial("1") == "1"
assert factorial("0") == "1"
assert factorial("-0") == "1"
assert factorial("23") == "25852016738884976640000"
assert factorial("24") == "620448401733239439360000"
assert factorial("10") == "3628800"
with pytest.raises(ToolExecutionError):
factorial("-1")
with pytest.raises(ToolExecutionError):
factorial("-10")
with pytest.raises(ToolExecutionError):
factorial("0.0000")
with pytest.raises(ToolExecutionError):
factorial("-0.0")
with pytest.raises(ToolExecutionError):
factorial("1.0")
with pytest.raises(ToolExecutionError):
factorial("-1.0")
with pytest.raises(ToolExecutionError):
factorial("23.0")
def test_sqrt():
assert sqrt("1") == "1"
assert sqrt("0") == "0"
assert sqrt("-0") == "-0"
assert (
sqrt("23") == "4.79583152331271954159743806416269391999670704190"
"4129346485309114448257235907464082492191446436918861"
)
assert (
sqrt("24") == "4.89897948556635619639456814941178278393189496131"
"3340256865385134501920754914630053079718866209280470"
)
assert (
sqrt("10") == "3.16227766016837933199889354443271853371955513932"
"5216826857504852792594438639238221344248108379300295"
)
assert sqrt("0.0") == "0.0"
assert sqrt("0.0000") == "0.00"
assert sqrt("-0.0") == "-0.0"
assert sqrt("1.0") == "1.0"
assert (
sqrt("3.14") == "1.772004514666935040199112509753631525073608516"
"162942966817771970290992972348902551472561151153909188"
)
assert (
sqrt("0.4") == "0.6324555320336758663997787088865437067439110278"
"650433653715009705585188877278476442688496216758600590"
)
assert (
sqrt("10.0") == "3.162277660168379331998893544432718533719555139"
"325216826857504852792594438639238221344248108379300295"
)
with pytest.raises(ToolExecutionError):
sqrt("-1")
with pytest.raises(ToolExecutionError):
sqrt("-10")
with pytest.raises(ToolExecutionError):
sqrt("-1.0")
with pytest.raises(ToolExecutionError):
sqrt("-1.3")
with pytest.raises(ToolExecutionError):
sqrt("-10.0")

View file

@ -1,31 +0,0 @@
import pytest
from arcade_mcp_server.exceptions import ToolExecutionError
from arcade_math.tools.rational import (
gcd,
lcm,
)
def test_gcd():
assert gcd("-15", "-5") == "5"
assert gcd("15", "0") == "15"
assert gcd("15", "-2") == "1"
assert gcd("15", "-0") == "15"
assert gcd("15", "5") == "5"
assert gcd("7", "13") == "1"
assert gcd("-13", "13") == "13"
with pytest.raises(ToolExecutionError):
gcd("15.0", "5.0")
def test_lcm():
assert lcm("-15", "-5") == "15"
assert lcm("15", "0") == "0"
assert lcm("15", "-2") == "30"
assert lcm("15", "-0") == "0"
assert lcm("15", "5") == "15"
assert lcm("7", "13") == "91"
assert lcm("-13", "13") == "13"
with pytest.raises(ToolExecutionError):
lcm("15.0", "5.0")

View file

@ -1,54 +0,0 @@
from arcade_math.tools.rounding import (
ceil,
floor,
round_num,
)
def test_ceil():
assert ceil("1") == "1"
assert ceil("-1") == "-1"
assert ceil("0") == "0"
assert ceil("-0") == "0"
assert ceil("0.0") == "0"
assert ceil("0.0000") == "0"
assert ceil("-0.0") == "0"
assert ceil("1.0") == "1"
assert ceil("-1.0") == "-1"
assert ceil("3.14") == "4"
assert ceil("0.4") == "1"
assert ceil("-1.3") == "-1"
def test_floor():
assert floor("1") == "1"
assert floor("-1") == "-1"
assert floor("0") == "0"
assert floor("-0") == "0"
assert floor("10") == "10"
assert floor("0.0") == "0"
assert floor("0.0000") == "0"
assert floor("-0.0") == "0"
assert floor("1.0") == "1"
assert floor("-1.0") == "-1"
assert floor("3.14") == "3"
assert floor("0.4") == "0"
assert floor("-1.3") == "-2"
def test_round_num():
# TODO(mateo): ok with scientific notatin? ok with negative round digits?
assert round_num("1.2345", "-2") == "0"
assert round_num("1.2345", "-1") == "0"
assert round_num("1.2345", "0") == "1"
assert round_num("1.2345", "1") == "1.2"
assert round_num("1.2345", "2") == "1.23"
assert round_num("1.2345", "3") == "1.234"
assert round_num("1.2345", "8") == "1.23450000"
assert round_num("1.654321", "-2") == "0"
assert round_num("1.654321", "-1") == "0"
assert round_num("1.654321", "0") == "2"
assert round_num("1.654321", "1") == "1.7"
assert round_num("1.654321", "2") == "1.65"
assert round_num("1.654321", "3") == "1.654"
assert round_num("1.654321", "8") == "1.65432100"

View file

@ -1,18 +0,0 @@
from arcade_math.tools.statistics import (
avg,
median,
)
def test_avg():
assert avg(["1", "2", "3", "4", "5", "6"]) == "3.5"
assert avg([]) == "0.0"
assert avg(["-1", "-2", "-3", "-4", "-5", "-6"]) == "-3.5"
assert avg(["0.1", "0.2", "0.3", "0.3", "0.5", "0.7"]) == "0.35"
def test_median():
assert median(["1", "2", "3", "4", "5", "6"]) == "3.5"
assert median([]) == "0.0"
assert median(["-1", "-2", "-3", "-4", "-5", "-6"]) == "-3.5"
assert median(["0.1", "0.2", "0.3", "0.3", "0.5", "0.7"]) == "0.3"

View file

@ -1,45 +0,0 @@
from arcade_math.tools.trigonometry import (
deg_to_rad,
rad_to_deg,
)
def test_deg_to_rad():
assert deg_to_rad("1") == "0.017453292519943295"
assert deg_to_rad("-1") == "-0.017453292519943295"
assert deg_to_rad("0") == "0.0"
assert deg_to_rad("-0") == "-0.0"
assert deg_to_rad("23") == "0.4014257279586958"
assert deg_to_rad("24") == "0.4188790204786391"
assert deg_to_rad("-10") == "-0.17453292519943295"
assert deg_to_rad("10") == "0.17453292519943295"
assert deg_to_rad("180") == "3.141592653589793"
assert deg_to_rad("0.0") == "0.0"
assert deg_to_rad("0.0000") == "0.0"
assert deg_to_rad("-0.0") == "-0.0"
assert deg_to_rad("1.0") == "0.017453292519943295"
assert deg_to_rad("-1.0") == "-0.017453292519943295"
assert deg_to_rad("23.0") == "0.4014257279586958"
assert deg_to_rad("0.4") == "0.006981317007977318"
assert deg_to_rad("-10.0") == "-0.17453292519943295"
assert deg_to_rad("10.0") == "0.17453292519943295"
def test_rad_to_deg():
assert rad_to_deg("1") == "57.29577951308232"
assert rad_to_deg("-1") == "-57.29577951308232"
assert rad_to_deg("0") == "0.0"
assert rad_to_deg("-0") == "-0.0"
assert rad_to_deg("23") == "1317.8029288008934"
assert rad_to_deg("24") == "1375.0987083139757"
assert rad_to_deg("-10") == "-572.9577951308232"
assert rad_to_deg("10") == "572.9577951308232"
assert rad_to_deg("0.0") == "0.0"
assert rad_to_deg("0.0000") == "0.0"
assert rad_to_deg("-0.0") == "-0.0"
assert rad_to_deg("1.0") == "57.29577951308232"
assert rad_to_deg("-1.0") == "-57.29577951308232"
assert rad_to_deg("3.14") == "179.9087476710785"
assert rad_to_deg("0.4") == "22.918311805232932"
assert rad_to_deg("-10.0") == "-572.9577951308232"
assert rad_to_deg("10.0") == "572.9577951308232"

View file

@ -1,53 +0,0 @@
.PHONY: help
help:
@echo "🛠️ github Commands:\n"
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
.PHONY: install
install: ## Install the uv environment and install all packages with dependencies
@echo "🚀 Creating virtual environment and installing all packages using uv"
@uv sync --active --all-extras --no-sources
@uv run pre-commit install
@echo "✅ All packages and dependencies installed via uv"
.PHONY: install-local
install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources
@echo "🚀 Creating virtual environment and installing all packages using uv"
@uv sync --active --all-extras
@uv run pre-commit install
@echo "✅ All packages and dependencies installed via uv"
.PHONY: build
build: clean-build ## Build wheel file using poetry
@echo "🚀 Creating wheel file"
uv build
.PHONY: clean-build
clean-build: ## clean build artifacts
@echo "🗑️ Cleaning dist directory"
rm -rf dist
.PHONY: test
test: ## Test the code with pytest
@echo "🚀 Testing code: Running pytest"
@uv run pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml
.PHONY: coverage
coverage: ## Generate coverage report
@echo "coverage report"
coverage report
@echo "Generating coverage report"
coverage html
.PHONY: bump-version
bump-version: ## Bump the version in the pyproject.toml file by a patch version
@echo "🚀 Bumping version in pyproject.toml"
uv version --bump patch
.PHONY: check
check: ## Run code quality tools.
@echo "🚀 Linting code: Running pre-commit"
@uv run pre-commit run -a
@echo "🚀 Static type checking: Running mypy"
@uv run mypy --config-file=pyproject.toml

View file

@ -1,29 +0,0 @@
import sys
from typing import cast
from arcade_mcp_server import MCPApp
from arcade_mcp_server.mcp_app import TransportType
import arcade_mongodb
app = MCPApp(
name="MongoDB",
instructions=(
"Use this server when you need to interact with MongoDB to help users "
"query, explore, and manage their MongoDB databases and collections."
),
)
app.add_tools_from_module(arcade_mongodb)
def main() -> None:
transport = sys.argv[1] if len(sys.argv) > 1 else "stdio"
host = sys.argv[2] if len(sys.argv) > 2 else "127.0.0.1"
port = int(sys.argv[3]) if len(sys.argv) > 3 else 8000
app.run(transport=cast(TransportType, transport), host=host, port=port)
if __name__ == "__main__":
main()

View file

@ -1,118 +0,0 @@
from typing import Any, ClassVar
from arcade_mcp_server.exceptions import RetryableToolError
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
from pymongo.errors import ServerSelectionTimeoutError
MAX_RECORDS_RETURNED = 1000
TEST_QUERY = {"ping": 1}
class DatabaseEngine:
_instance: ClassVar[None] = None
_clients: ClassVar[dict[str, AsyncIOMotorClient]] = {}
@classmethod
async def get_instance(cls, connection_string: str) -> AsyncIOMotorClient:
key = connection_string
if key not in cls._clients:
cls._clients[key] = AsyncIOMotorClient(connection_string)
# try a simple query to see if the connection is valid
try:
admin_db = cls._clients[key].admin
await admin_db.command(TEST_QUERY)
return cls._clients[key]
except ServerSelectionTimeoutError:
# close and try again
cls._clients[key].close()
cls._clients[key] = AsyncIOMotorClient(connection_string)
try:
admin_db = cls._clients[key].admin
await admin_db.command(TEST_QUERY)
return cls._clients[key]
except Exception as e:
raise RetryableToolError(
f"Connection failed: {e}",
developer_message="Connection to MongoDB failed.",
additional_prompt_content="Check the connection string and try again.",
) from e
@classmethod
async def get_database(cls, connection_string: str, database_name: str) -> Any:
client = await cls.get_instance(connection_string)
class DatabaseContextManager:
def __init__(self, client: AsyncIOMotorClient, database_name: str) -> None:
self.client = client
self.database_name = database_name
self.database = client[database_name]
async def __aenter__(self) -> AsyncIOMotorDatabase:
return self.database
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
# Connection cleanup is handled by the client cache
pass
return DatabaseContextManager(client, database_name)
@classmethod
async def cleanup(cls) -> None:
"""Clean up all cached clients. Call this when shutting down."""
for client in cls._clients.values():
client.close()
cls._clients.clear()
@classmethod
def clear_cache(cls) -> None:
"""Clear the client cache without closing clients. Use with caution."""
cls._clients.clear()
@classmethod
def sanitize_query_params(
cls,
database_name: str,
collection_name: str,
filter_dict: dict[str, Any] | None,
projection: dict[str, Any] | None,
sort: list[dict[str, Any]] | None,
limit: int,
skip: int,
) -> tuple[
str, str, dict[str, Any], dict[str, Any] | None, list[dict[str, Any]] | None, int, int
]:
if not database_name:
raise RetryableToolError(
"Database name is required.",
developer_message="Database name cannot be empty.",
)
if not collection_name:
raise RetryableToolError(
"Collection name is required.",
developer_message="Collection name cannot be empty.",
)
if filter_dict is None:
filter_dict = {}
if limit > MAX_RECORDS_RETURNED:
raise RetryableToolError(
f"Limit is too high. Maximum is {MAX_RECORDS_RETURNED}.",
)
if skip < 0:
raise RetryableToolError(
"Skip must be greater than or equal to 0.",
developer_message="Skip must be greater than or equal to 0.",
)
if limit <= 0:
raise RetryableToolError(
"Limit must be greater than 0.",
developer_message="Limit must be greater than 0.",
)
return database_name, collection_name, filter_dict, projection, sort, limit, skip

View file

@ -1,434 +0,0 @@
import json
from typing import Annotated, Any
from arcade_mcp_server import Context, tool
from arcade_mcp_server.exceptions import RetryableToolError
from arcade_mcp_server.metadata import Behavior, Operation, ToolMetadata
from ..database_engine import MAX_RECORDS_RETURNED, DatabaseEngine
from .utils import (
_infer_schema_from_docs,
_parse_json_list_parameter,
_parse_json_parameter,
_serialize_document,
)
# class UserStatus(str, Enum):
# """User status enumeration."""
# ACTIVE = "active"
# INACTIVE = "inactive"
# BANNED = "banned"
@tool(
requires_secrets=["MONGODB_CONNECTION_STRING"],
metadata=ToolMetadata(
behavior=Behavior(
operations=[Operation.READ],
read_only=True,
destructive=False,
idempotent=True,
open_world=True,
),
),
)
async def discover_databases(
context: Context,
) -> list[str]:
"""Discover all the databases in the MongoDB instance."""
client = await DatabaseEngine.get_instance(context.get_secret("MONGODB_CONNECTION_STRING"))
databases = await client.list_database_names()
# Filter out admin and config databases by default
databases = [db for db in databases if db not in ["admin", "config", "local"]]
return databases
@tool(
requires_secrets=["MONGODB_CONNECTION_STRING"],
metadata=ToolMetadata(
behavior=Behavior(
operations=[Operation.READ],
read_only=True,
destructive=False,
idempotent=True,
open_world=True,
),
),
)
async def discover_collections(
context: Context,
database_name: Annotated[str, "The database name to discover collections in"],
) -> list[str]:
"""Discover all the collections in the MongoDB database when the list of collections is not known.
ALWAYS use this tool before any other tool that requires a collection name.
"""
async with await DatabaseEngine.get_database(
context.get_secret("MONGODB_CONNECTION_STRING"), database_name
) as db:
collections = await db.list_collection_names()
return list(collections)
@tool(
requires_secrets=["MONGODB_CONNECTION_STRING"],
metadata=ToolMetadata(
behavior=Behavior(
operations=[Operation.READ],
read_only=True,
destructive=False,
idempotent=True,
open_world=True,
),
),
)
async def get_collection_schema(
context: Context,
database_name: Annotated[str, "The database name to get the collection schema of"],
collection_name: Annotated[str, "The collection to get the schema of"],
sample_size: Annotated[
int,
f"The number of documents to sample for schema discovery (default: {MAX_RECORDS_RETURNED})",
] = MAX_RECORDS_RETURNED,
) -> dict[str, Any]:
"""
Get the schema/structure of a MongoDB collection by sampling documents.
Since MongoDB is schema-less, this tool samples a configurable number of documents
to infer the schema structure and data types.
This tool should ALWAYS be used before executing any query. All collections in the query must be discovered first using the <discover_collections> tool.
"""
async with await DatabaseEngine.get_database(
context.get_secret("MONGODB_CONNECTION_STRING"), database_name
) as db:
collection = db[collection_name]
# Sample documents at random to infer schema
# Use MongoDB's $sample aggregation to get random documents
sample_docs = []
async for doc in collection.aggregate([{"$sample": {"size": sample_size}}]):
sample_docs.append(doc)
if not sample_docs:
return {"message": "Collection is empty", "schema": {}}
# Infer schema from sampled documents
schema = _infer_schema_from_docs(sample_docs)
return {
"total_documents_sampled": len(sample_docs),
"sample_size_requested": sample_size,
"schema": schema,
}
@tool(
requires_secrets=["MONGODB_CONNECTION_STRING"],
metadata=ToolMetadata(
behavior=Behavior(
operations=[Operation.READ],
read_only=True,
destructive=False,
idempotent=True,
open_world=True,
),
),
)
async def find_documents(
context: Context,
database_name: Annotated[str, "The database name to query"],
collection_name: Annotated[str, "The collection name to query"],
filter_dict: Annotated[
str | None,
'MongoDB filter/query as JSON string. Leave None for no filter (find all documents). Example: \'{"status": "active", "age": {"$gte": 18}}\'',
] = None,
projection: Annotated[
str | None,
'Fields to include/exclude as JSON string. Use 1 to include, 0 to exclude. Example: \'{"name": 1, "email": 1, "_id": 0}\'. Leave None to include all fields.',
] = None,
sort: Annotated[
list[str] | None,
'Sort criteria as list of JSON strings, each containing \'field\' and \'direction\' keys. Use 1 for ascending, -1 for descending. Example: [\'{"field": "name", "direction": 1}\', \'{"field": "created_at", "direction": -1}\']',
] = None,
limit: Annotated[
int,
f"The maximum number of documents to return. Default: {MAX_RECORDS_RETURNED}.",
] = MAX_RECORDS_RETURNED,
skip: Annotated[int, "The number of documents to skip. Default: 0."] = 0,
) -> list[str]:
"""
Find documents in a MongoDB collection.
ONLY use this tool if you have already loaded the schema of the collection you need to query.
Use the <get_collection_schema> tool to load the schema if not already known.
Returns a list of JSON strings, where each string represents a document from the collection (tools cannot return complex types).
When running queries, follow these rules which will help avoid errors:
* Always specify projection to limit fields returned if you don't need all data.
* Always sort your results by the most important fields first. If you aren't sure, sort by '_id'.
* Use appropriate MongoDB query operators for complex filtering ($gte, $lte, $in, $regex, etc.).
* Be mindful of case sensitivity when querying string fields.
* Use indexes when possible (typically on _id and commonly queried fields).
"""
# Initialize variables to avoid UnboundLocalError in exception handler
parsed_filter = None
parsed_projection = None
parsed_sort = None
try:
# Parse JSON string inputs
parsed_filter = _parse_json_parameter(filter_dict, "filter_dict")
parsed_projection = _parse_json_parameter(projection, "projection")
parsed_sort = _parse_json_list_parameter(sort, "sort")
(
database_name,
collection_name,
parsed_filter,
parsed_projection,
parsed_sort,
limit,
skip,
) = DatabaseEngine.sanitize_query_params(
database_name=database_name,
collection_name=collection_name,
filter_dict=parsed_filter,
projection=parsed_projection,
sort=parsed_sort,
limit=limit,
skip=skip,
)
async with await DatabaseEngine.get_database(
context.get_secret("MONGODB_CONNECTION_STRING"), database_name
) as db:
collection = db[collection_name]
# Build the query
cursor = collection.find(parsed_filter, parsed_projection)
if parsed_sort:
# Convert list of dicts to list of tuples for MongoDB sort
sort_tuples = [(str(item["field"]), int(item["direction"])) for item in parsed_sort]
cursor = cursor.sort(sort_tuples)
cursor = cursor.skip(skip).limit(limit)
# Execute query and collect results
documents = []
async for doc in cursor:
# Convert ObjectId and other non-serializable types to strings
doc = _serialize_document(doc)
documents.append(json.dumps(doc))
return documents
except RetryableToolError:
# Re-raise RetryableToolError as-is to preserve JSON validation messages
raise
except Exception as e:
raise RetryableToolError(
f"Query failed: {e}",
developer_message=f"Query failed with parameters: database_name={database_name}, collection_name={collection_name}, filter_dict={parsed_filter}, projection={parsed_projection}, sort={parsed_sort}, limit={limit}, skip={skip}.",
additional_prompt_content="Load the collection schema <get_collection_schema> or use the <discover_collections> tool to discover the collections and try again.",
retry_after_ms=10,
) from e
@tool(
requires_secrets=["MONGODB_CONNECTION_STRING"],
metadata=ToolMetadata(
behavior=Behavior(
operations=[Operation.READ],
read_only=True,
destructive=False,
idempotent=True,
open_world=True,
),
),
)
async def count_documents(
context: Context,
database_name: Annotated[str, "The database name to query"],
collection_name: Annotated[str, "The collection name to query"],
filter_dict: Annotated[
str | None,
'MongoDB filter/query as JSON string. Leave None for no filter (count all documents). Example: \'{"status": "active"}\'',
] = None,
) -> int:
"""Count documents in a MongoDB collection matching the given filter."""
parsed_filter = None
try:
# Parse JSON string input
parsed_filter = _parse_json_parameter(filter_dict, "filter_dict") or {}
async with await DatabaseEngine.get_database(
context.get_secret("MONGODB_CONNECTION_STRING"), database_name
) as db:
collection = db[collection_name]
count = await collection.count_documents(parsed_filter)
return int(count)
except RetryableToolError:
# Re-raise RetryableToolError as-is to preserve JSON validation messages
raise
except Exception as e:
raise RetryableToolError(
f"Count query failed: {e}",
developer_message=f"Count query failed with parameters: database_name={database_name}, collection_name={collection_name}, filter_dict={parsed_filter}.",
additional_prompt_content="Check the collection name and filter criteria and try again.",
retry_after_ms=10,
) from e
@tool(
requires_secrets=["MONGODB_CONNECTION_STRING"],
metadata=ToolMetadata(
behavior=Behavior(
operations=[Operation.READ],
read_only=True,
destructive=False,
idempotent=True,
open_world=True,
),
),
)
async def aggregate_documents(
context: Context,
database_name: Annotated[str, "The database name to query"],
collection_name: Annotated[str, "The collection name to query"],
pipeline: Annotated[
list[str],
'MongoDB aggregation pipeline as a list of JSON strings, each representing a stage. Example: [\'{"$match": {"status": "active"}}\', \'{"$group": {"_id": "$category", "count": {"$sum": 1}}}\']',
],
limit: Annotated[
int,
f"The maximum number of results to return from the aggregation. Default: {MAX_RECORDS_RETURNED}.",
] = MAX_RECORDS_RETURNED,
) -> list[str]:
"""
Execute a MongoDB aggregation pipeline on a collection.
ONLY use this tool if you have already loaded the schema of the collection you need to query.
Use the <get_collection_schema> tool to load the schema if not already known.
Returns a list of JSON strings, where each string represents a result document from the aggregation (tools cannot return complex types).
Aggregation pipelines allow for complex data processing including:
* $match - filter documents
* $group - group documents and perform calculations
* $project - reshape documents
* $sort - sort documents
* $limit - limit results
* $lookup - join with other collections
* And many more stages
"""
parsed_pipeline = None
try:
# Parse JSON string inputs
parsed_pipeline = _parse_json_list_parameter(pipeline, "pipeline")
if parsed_pipeline is None:
raise RetryableToolError( # noqa: TRY301
"Pipeline cannot be empty",
developer_message="The pipeline parameter is required and cannot be None",
)
async with await DatabaseEngine.get_database(
context.get_secret("MONGODB_CONNECTION_STRING"), database_name
) as db:
collection = db[collection_name]
# Add limit to pipeline if not already present
pipeline_with_limit = parsed_pipeline.copy()
has_limit = any("$limit" in stage for stage in pipeline_with_limit)
if not has_limit:
pipeline_with_limit.append({"$limit": limit})
# Execute aggregation
cursor = collection.aggregate(pipeline_with_limit)
documents = []
async for doc in cursor:
# Convert ObjectId and other non-serializable types to strings
doc = _serialize_document(doc)
documents.append(json.dumps(doc))
return documents
except RetryableToolError:
# Re-raise RetryableToolError as-is to preserve JSON validation messages
raise
except Exception as e:
raise RetryableToolError(
f"Aggregation query failed: {e}",
developer_message=f"Aggregation query failed with parameters: database_name={database_name}, collection_name={collection_name}, pipeline={parsed_pipeline}, limit={limit}.",
additional_prompt_content="Check the aggregation pipeline syntax and collection schema, then try again.",
retry_after_ms=10,
) from e
# @tool(requires_secrets=["MONGODB_CONNECTION_STRING"])
# async def update_user_status(
# context: ToolContext,
# database_name: Annotated[str, "The database name containing the users collection"],
# collection_name: Annotated[str, "The collection name containing user documents"],
# user_id: Annotated[str, "The _id of the user to update"],
# status: Annotated[UserStatus, "The new status for the user"],
# ) -> dict[str, Any]:
# """
# [CUSTOM TOOL]
# Update the status of a user in the MongoDB collection.
# This tool updates a user document by setting the status field to the specified value.
# The status must be one of: active, inactive, or banned.
# Returns information about the update operation including the number of documents modified.
# """
# try:
# async with await DatabaseEngine.get_database(
# context.get_secret("MONGODB_CONNECTION_STRING"), database_name
# ) as db:
# collection = db[collection_name]
# # cast the user_id to int if it looks like an integer
# if isinstance(user_id, str) and user_id.isdigit():
# user_id = int(user_id)
# result = await collection.update_one(
# {"_id": user_id}, {"$set": {"status": status.value}}
# )
# print(result)
# if result.matched_count == 0:
# return {
# "success": False,
# "message": f"No user found with _id: {user_id}",
# "matched_count": 0,
# "modified_count": 0,
# }
# return {
# "success": True,
# "message": f"User status updated to '{status.value}'",
# "user_id": user_id,
# "new_status": status.value,
# "matched_count": result.matched_count,
# "modified_count": result.modified_count,
# }
# except Exception as e:
# raise RetryableToolError(
# f"Failed to update user status: {e}",
# developer_message=f"Update operation failed with parameters: database_name={database_name}, collection_name={collection_name}, user_id={user_id}, status={status}.",
# additional_prompt_content="Check the database name, collection name, and user ID, then try again.",
# retry_after_ms=10,
# ) from e

View file

@ -1,281 +0,0 @@
import json
from datetime import datetime
from typing import Any
from arcade_mcp_server.exceptions import RetryableToolError
from bson import ObjectId
def _validate_no_write_operations(obj: Any, parameter_name: str, path: str = "") -> None:
"""
Recursively validate that an object doesn't contain MongoDB write operations.
Args:
obj: The object to validate
parameter_name: Name of the parameter for error messages
path: Current path in the object (for nested validation)
Raises:
RetryableToolError: If write operations are detected
"""
# MongoDB write/update operators that should be blocked
WRITE_OPERATORS = {
# Update operators
"$set",
"$unset",
"$inc",
"$mul",
"$rename",
"$min",
"$max",
"$currentDate",
"$addToSet",
"$pop",
"$pull",
"$push",
"$pullAll",
"$each",
"$slice",
"$sort",
"$position",
"$bit",
"$isolated",
# Array update operators
"$",
"$[]",
"$[<identifier>]",
# Pipeline update operators
"$addFields",
"$replaceRoot",
"$replaceWith",
# Aggregation stages that can modify (in case they're misused)
"$out",
"$merge",
# Other potentially dangerous operators
"$where", # Can execute JavaScript
}
if isinstance(obj, dict):
for key, value in obj.items():
current_path = f"{path}.{key}" if path else key
# Special check for $where operator which can execute JavaScript (check this first)
if key == "$where":
raise RetryableToolError(
f"JavaScript execution operator '$where' not allowed in {parameter_name}",
developer_message=f"Found '$where' operator at path '{current_path}' in parameter '{parameter_name}'. JavaScript execution is not allowed for security reasons.",
additional_prompt_content=f"The {parameter_name} parameter cannot use the $where operator. Use other query operators instead.",
)
# Check if this key is a write operator
if key in WRITE_OPERATORS:
raise RetryableToolError(
f"Write operation '{key}' not allowed in {parameter_name}",
developer_message=f"Found write operation '{key}' at path '{current_path}' in parameter '{parameter_name}'. Only read operations are allowed.",
additional_prompt_content=f"The {parameter_name} parameter cannot contain write operations like '{key}'. Use only query/read operations such as $match, $gte, $lte, $in, $regex, etc.",
)
# Recursively validate nested objects
_validate_no_write_operations(value, parameter_name, current_path)
elif isinstance(obj, list):
for i, item in enumerate(obj):
current_path = f"{path}[{i}]" if path else f"[{i}]"
_validate_no_write_operations(item, parameter_name, current_path)
def _parse_json_parameter(
json_string: str | None, parameter_name: str, validate_read_only: bool = True
) -> Any | None:
"""
Parse a JSON string parameter with proper error handling and optional write operation validation.
Args:
json_string: The JSON string to parse (can be None)
parameter_name: Name of the parameter for error messages
validate_read_only: Whether to validate that no write operations are present
Returns:
Parsed JSON object or None if json_string is None
Raises:
RetryableToolError: If JSON parsing fails or write operations are detected
"""
if json_string is None:
return None
try:
parsed_obj = json.loads(json_string)
# Validate that no write operations are present
if validate_read_only and parsed_obj is not None:
_validate_no_write_operations(parsed_obj, parameter_name)
except json.JSONDecodeError as e:
raise RetryableToolError(
f"Invalid JSON in {parameter_name}: {e}",
developer_message=f"Failed to parse JSON string for parameter '{parameter_name}': {json_string}. Error: {e}",
additional_prompt_content=f"Please provide valid JSON for the {parameter_name} parameter. Check for proper escaping of quotes and valid JSON syntax.",
) from e
else:
return parsed_obj
def _validate_aggregation_pipeline(pipeline: list[Any], parameter_name: str) -> None:
"""
Validate that an aggregation pipeline only contains read operations.
Args:
pipeline: The aggregation pipeline to validate
parameter_name: Name of the parameter for error messages
Raises:
RetryableToolError: If write operations are detected in the pipeline
"""
# MongoDB aggregation stages that can modify data
WRITE_STAGES = {
"$out",
"$merge", # These stages write to collections
}
# Aggregation stages that are potentially dangerous
DANGEROUS_STAGES = {
"$where", # Can execute JavaScript
}
for i, stage in enumerate(pipeline):
if isinstance(stage, dict):
for stage_name in stage:
if stage_name in WRITE_STAGES:
raise RetryableToolError(
f"Write stage '{stage_name}' not allowed in {parameter_name}",
developer_message=f"Found write stage '{stage_name}' at pipeline index {i} in parameter '{parameter_name}'. Only read operations are allowed.",
additional_prompt_content=f"The {parameter_name} parameter cannot contain write stages like '{stage_name}'. Use only read stages such as $match, $group, $project, $sort, $limit, etc.",
)
if stage_name in DANGEROUS_STAGES:
raise RetryableToolError(
f"Dangerous stage '{stage_name}' not allowed in {parameter_name}",
developer_message=f"Found dangerous stage '{stage_name}' at pipeline index {i} in parameter '{parameter_name}'. JavaScript execution is not allowed for security reasons.",
additional_prompt_content=f"The {parameter_name} parameter cannot use the {stage_name} stage. Use other aggregation stages instead.",
)
# Also validate the stage content for write operations
_validate_no_write_operations(
stage[stage_name], f"{parameter_name}[{i}].{stage_name}"
)
def _parse_json_list_parameter(
json_strings: list[str] | None, parameter_name: str, validate_read_only: bool = True
) -> list[Any] | None:
"""
Parse a list of JSON strings with proper error handling and optional write operation validation.
Args:
json_strings: List of JSON strings to parse (can be None)
parameter_name: Name of the parameter for error messages
validate_read_only: Whether to validate that no write operations are present
Returns:
List of parsed JSON objects or None if json_strings is None
Raises:
RetryableToolError: If JSON parsing fails for any string or write operations are detected
"""
if json_strings is None:
return None
try:
parsed_list = [json.loads(json_str) for json_str in json_strings]
# Validate that no write operations are present
if validate_read_only and parsed_list is not None:
# Special handling for pipeline parameters
if parameter_name == "pipeline":
_validate_aggregation_pipeline(parsed_list, parameter_name)
else:
# For non-pipeline lists, validate each item
for i, item in enumerate(parsed_list):
_validate_no_write_operations(item, f"{parameter_name}[{i}]")
except json.JSONDecodeError as e:
raise RetryableToolError(
f"Invalid JSON in {parameter_name}: {e}",
developer_message=f"Failed to parse JSON string list for parameter '{parameter_name}': {json_strings}. Error: {e}",
additional_prompt_content=f"Please provide valid JSON strings for the {parameter_name} parameter. Each string must be valid JSON with proper escaping of quotes.",
) from e
else:
return parsed_list
def _infer_schema_from_docs(docs: list[dict[str, Any]]) -> dict[str, Any]:
"""Infer schema structure from a list of documents."""
schema: dict[str, Any] = {}
for doc in docs:
_update_schema_with_doc(schema, doc)
# Convert sets to lists for serialization
for key in schema:
if isinstance(schema[key]["types"], set):
schema[key]["types"] = list(schema[key]["types"])
return schema
def _update_schema_with_doc(schema: dict[str, Any], doc: dict[str, Any], prefix: str = "") -> None:
"""Recursively update schema with document structure."""
for key, value in doc.items():
full_key = f"{prefix}.{key}" if prefix else key
if full_key not in schema:
schema[full_key] = {
"types": set(),
"sample_values": [],
"null_count": 0,
"total_count": 0,
}
schema[full_key]["total_count"] += 1
if value is None:
schema[full_key]["null_count"] += 1
schema[full_key]["types"].add("null")
else:
value_type = type(value).__name__
schema[full_key]["types"].add(value_type)
# Store sample values (limit to 3 unique samples)
if (
len(schema[full_key]["sample_values"]) < 3
and value not in schema[full_key]["sample_values"]
):
schema[full_key]["sample_values"].append(value)
# Handle nested objects
if isinstance(value, dict):
_update_schema_with_doc(schema, value, full_key)
elif isinstance(value, list) and value and isinstance(value[0], dict):
# Handle arrays of objects by sampling the first few
for i, item in enumerate(value[:3]): # Sample first 3 array items
if isinstance(item, dict):
_update_schema_with_doc(schema, item, f"{full_key}[{i}]")
def _serialize_document(doc: dict[str, Any]) -> dict[str, Any]:
"""Convert MongoDB document to JSON-serializable format."""
if isinstance(doc, dict):
result = {}
for key, value in doc.items():
result[key] = _serialize_document(value)
return result
elif isinstance(doc, list):
return [_serialize_document(item) for item in doc]
elif isinstance(doc, ObjectId):
return str(doc)
elif isinstance(doc, datetime):
return doc.isoformat()
else:
return doc

View file

@ -1,190 +0,0 @@
# RUN ME WITH `uv run arcade evals evals --host api.arcade.dev`
import arcade_mongodb
from arcade_core import ToolCatalog
from arcade_evals import (
BinaryCritic,
EvalRubric,
EvalSuite,
ExpectedToolCall,
SimilarityCritic,
tool_eval,
)
from arcade_mongodb.tools.mongodb import (
aggregate_documents,
count_documents,
discover_collections,
discover_databases,
find_documents,
get_collection_schema,
)
# Evaluation rubric
rubric = EvalRubric(
fail_threshold=0.85,
warn_threshold=0.95,
)
catalog = ToolCatalog()
catalog.add_module(arcade_mongodb)
@tool_eval()
def mongodb_eval_suite() -> EvalSuite:
suite = EvalSuite(
name="MongoDB Tools Evaluation",
system_message=(
"You are an AI assistant with access to MongoDB tools. "
"Use them to help the user with their tasks."
),
catalog=catalog,
rubric=rubric,
)
suite.add_case(
name="Discover databases",
user_message="What databases are available in my MongoDB instance?",
expected_tool_calls=[
ExpectedToolCall(func=discover_databases, args={}),
],
rubric=rubric,
)
suite.add_case(
name="Discover collections",
user_message="What collections are in the 'admin' database?",
expected_tool_calls=[
ExpectedToolCall(func=discover_collections, args={"database_name": "admin"}),
],
rubric=rubric,
critics=[
BinaryCritic(critic_field="database_name", weight=1.0),
],
)
suite.add_case(
name="Get collection schema (single tool call)",
user_message="Get the schema of the 'system.users' collection in the 'admin' database.",
expected_tool_calls=[
ExpectedToolCall(
func=get_collection_schema,
args={"database_name": "admin", "collection_name": "system.users"},
),
],
rubric=rubric,
critics=[
BinaryCritic(critic_field="database_name", weight=0.5),
BinaryCritic(critic_field="collection_name", weight=0.5),
],
)
suite.add_case(
name="Find documents (direct call)",
user_message="Find documents in the 'startup_log' collection of the 'local' database, limited to 5 results.",
additional_messages=[
{
"role": "user",
"content": "You can call find_documents directly without discovering collections first for this test.",
}
],
expected_tool_calls=[
ExpectedToolCall(
func=find_documents,
args={
"database_name": "local",
"collection_name": "startup_log",
"limit": 5,
},
),
],
rubric=rubric,
critics=[
BinaryCritic(critic_field="database_name", weight=0.33),
BinaryCritic(critic_field="collection_name", weight=0.33),
BinaryCritic(critic_field="limit", weight=0.34),
],
)
suite.add_case(
name="Count documents",
user_message="Count all documents in the 'startup_log' collection of the 'local' database.",
additional_messages=[
{
"role": "user",
"content": "You can call count_documents directly without discovering collections first for this test.",
}
],
expected_tool_calls=[
ExpectedToolCall(
func=count_documents,
args={
"database_name": "local",
"collection_name": "startup_log",
},
),
],
rubric=rubric,
critics=[
BinaryCritic(critic_field="database_name", weight=0.5),
BinaryCritic(critic_field="collection_name", weight=0.5),
],
)
suite.add_case(
name="Count documents with filter",
user_message="Count documents in the 'startup_log' collection of the 'local' database where the level is 'INFO'.",
additional_messages=[
{
"role": "user",
"content": "You can call count_documents directly without discovering collections first for this test.",
}
],
expected_tool_calls=[
ExpectedToolCall(
func=count_documents,
args={
"database_name": "local",
"collection_name": "startup_log",
"filter_dict": '{"level": "INFO"}',
},
),
],
rubric=rubric,
critics=[
BinaryCritic(critic_field="database_name", weight=0.25),
BinaryCritic(critic_field="collection_name", weight=0.25),
SimilarityCritic(critic_field="filter_dict", weight=0.5),
],
)
suite.add_case(
name="Aggregate documents",
user_message="Group documents in the 'startup_log' collection of the 'local' database by level and count them.",
additional_messages=[
{
"role": "user",
"content": "You can call aggregate_documents directly without discovering collections first for this test.",
}
],
expected_tool_calls=[
ExpectedToolCall(
func=aggregate_documents,
args={
"database_name": "local",
"collection_name": "startup_log",
"pipeline": [
'{"$group": {"_id": "$level", "count": {"$sum": 1}}}',
],
},
),
],
rubric=rubric,
critics=[
BinaryCritic(critic_field="database_name", weight=0.2),
BinaryCritic(critic_field="collection_name", weight=0.2),
SimilarityCritic(critic_field="pipeline", weight=0.6),
],
)
return suite

View file

@ -1,62 +0,0 @@
[build-system]
requires = [ "hatchling",]
build-backend = "hatchling.build"
[project]
name = "arcade_mongodb"
version = "0.3.0"
description = "Tools to query and explore a MongoDB database"
requires-python = ">=3.10"
dependencies = [
"arcade-mcp-server>=1.17.0,<2.0.0",
"pymongo>=4.10.1",
"pydantic>=2.11.7",
"motor>=3.6.0",
]
[[project.authors]]
name = "evantahler"
email = "support@arcade.dev"
[project.optional-dependencies]
dev = [
"arcade-mcp[all]>=1.2.0,<2.0.0",
"pytest>=8.3.0,<8.4.0",
"pytest-cov>=4.0.0,<4.1.0",
"pytest-mock>=3.11.1,<3.12.0",
"pytest-asyncio>=0.24.0,<0.25.0",
"mypy>=1.5.1,<1.6.0",
"pre-commit>=3.4.0,<3.5.0",
"tox>=4.11.1,<4.12.0",
"ruff>=0.7.4,<0.8.0",
]
[project.scripts]
arcade-mongodb = "arcade_mongodb.__main__:main"
arcade_mongodb = "arcade_mongodb.__main__:main"
# Use local path sources for arcade libs when working locally
[tool.uv.sources]
arcade-mcp = { path = "../../", editable = true }
arcade-mcp-server = { path = "../../libs/arcade-mcp-server/", editable = true }
[tool.mypy]
files = [ "arcade_mongodb/**/*.py",]
python_version = "3.10"
disallow_untyped_defs = "True"
disallow_any_unimported = "True"
no_implicit_optional = "True"
check_untyped_defs = "True"
warn_return_any = "True"
warn_unused_ignores = "True"
show_error_codes = "True"
ignore_missing_imports = "True"
[tool.pytest.ini_options]
testpaths = [ "tests",]
asyncio_default_fixture_loop_scope = "function"
[tool.coverage.report]
skip_empty = true
[tool.hatch.build.targets.wheel]
packages = [ "arcade_mongodb",]

View file

@ -1,45 +0,0 @@
import os
import shutil
import subprocess
from os import environ
import pytest_asyncio
from arcade_mongodb.database_engine import DatabaseEngine
TEST_MONGODB_CONNECTION_STRING = (
environ.get("TEST_MONGODB_CONNECTION_STRING") or "mongodb://localhost:27017"
)
@pytest_asyncio.fixture(autouse=True)
async def restore_database():
"""Restore the database from the dump before each test."""
dump_file = f"{os.path.dirname(__file__)}/dump.js"
# Execute the MongoDB dump script to restore test data
mongosh_path = shutil.which("mongosh")
if not mongosh_path:
raise RuntimeError("mongosh executable not found in PATH")
result = subprocess.run(
[mongosh_path, TEST_MONGODB_CONNECTION_STRING, dump_file],
check=True,
capture_output=True,
text=True,
)
if result.returncode != 0:
print(f"Error loading test data: {result.stderr}")
raise RuntimeError(f"Failed to load test data: {result.stderr}")
yield # This allows tests to run
# Optional cleanup could go here if needed
@pytest_asyncio.fixture(autouse=True)
async def cleanup_engines():
"""Clean up database engines after each test to prevent connection leaks."""
yield
await DatabaseEngine.cleanup()

View file

@ -1,378 +0,0 @@
// MongoDB test data dump - equivalent to PostgreSQL dump.sql
// This script sets up test data for the MongoDB toolkit
// Switch to test database
use('test_database');
// Clear existing data
db.users.drop();
db.messages.drop();
// Create users collection with data
db.users.insertMany([
{
_id: 1,
name: 'Alice',
email: 'alice@example.com',
password_hash: '$argon2id$v=19$m=65536,t=2,p=1$tMg1Rd3IEDnp3iFKrqsF4Dsbw6/Cbf6seRB/H5bhaPg$zZj5yn4x3D3O3mDHcW2aczQNiYfAs3cw21XMEIgkF0E',
created_at: new Date('2024-09-01T20:49:38.759Z'),
updated_at: new Date('2024-09-02T03:49:39.927Z'),
status: 'active'
},
{
_id: 2,
name: 'Bob',
email: 'bob@example.com',
password_hash: '$argon2id$v=19$m=65536,t=2,p=1$CvOMK1WUd99R7kYXpiBPNYw4OQP53pYIgeMnwz92mrE$HPthId4phMoPT1TWuCRHHCr9BSQA8XoUkQuB1HZsqTY',
created_at: new Date('2024-09-02T17:49:23.377Z'),
updated_at: new Date('2024-09-02T17:49:23.377Z'),
status: 'active'
},
{
_id: 3,
name: 'Charlie',
email: 'charlie@example.com',
password_hash: '$argon2id$v=19$m=65536,t=2,p=1$paCAAD1HVZkncP/WvecuUO6zFXp2/8BISpgr5rXRxps$M5kBFc9JHHGNw9SXnPu2ggpJY0mFFCska7TXMrllndo',
created_at: new Date('2024-09-03T10:30:15.123Z'),
updated_at: new Date('2024-09-03T10:30:15.123Z'),
status: 'active'
},
{
_id: 4,
name: 'Diana',
email: 'diana@example.com',
password_hash: '$argon2id$v=19$m=65536,t=2,p=1$xyz123ABC456DEF789GHI$SampleHashForDiana123',
created_at: new Date('2024-09-04T14:20:30.654Z'),
updated_at: new Date('2024-09-04T14:20:30.654Z'),
status: 'active'
},
{
_id: 5,
name: 'Evan',
email: 'evan@example.com',
password_hash: '$argon2id$v=19$m=65536,t=2,p=1$evanHash123$EvanPasswordHash456',
created_at: new Date('2024-09-05T09:15:45.987Z'),
updated_at: new Date('2024-09-05T09:15:45.987Z'),
status: 'active'
},
{
_id: 6,
name: 'Fiona',
email: 'fiona@example.com',
password_hash: '$argon2id$v=19$m=65536,t=2,p=1$fionaHash456$FionaPasswordHash789',
created_at: new Date('2024-09-06T16:45:12.345Z'),
updated_at: new Date('2024-09-06T16:45:12.345Z'),
status: 'active'
},
{
_id: 7,
name: 'George',
email: 'george@example.com',
password_hash: '$argon2id$v=19$m=65536,t=2,p=1$georgeHash789$GeorgePasswordHash012',
created_at: new Date('2024-09-07T11:30:25.876Z'),
updated_at: new Date('2024-09-07T11:30:25.876Z'),
status: 'active'
},
{
_id: 8,
name: 'Helen',
email: 'helen@example.com',
password_hash: '$argon2id$v=19$m=65536,t=2,p=1$helenHash012$HelenPasswordHash345',
created_at: new Date('2024-09-08T13:25:40.234Z'),
updated_at: new Date('2024-09-08T13:25:40.234Z'),
status: 'active'
},
{
_id: 9,
name: 'Ian',
email: 'ian@example.com',
password_hash: '$argon2id$v=19$m=65536,t=2,p=1$ianHash345$IanPasswordHash678',
created_at: new Date('2024-09-09T08:40:55.765Z'),
updated_at: new Date('2024-09-09T08:40:55.765Z'),
status: 'active'
},
{
_id: 10,
name: 'Julia',
email: 'julia@example.com',
password_hash: '$argon2id$v=19$m=65536,t=2,p=1$juliaHash678$JuliaPasswordHash901',
created_at: new Date('2024-09-10T15:55:18.123Z'),
updated_at: new Date('2024-09-10T15:55:18.123Z'),
status: 'active'
}
]);
// Create messages collection with data
db.messages.insertMany([
// User 1 (Alice) - 3 messages
{
_id: 1,
body: 'Hello everyone!',
user_id: 1,
created_at: new Date('2025-01-10T10:00:00.000Z'),
updated_at: new Date('2025-01-10T10:00:00.000Z')
},
{
_id: 2,
body: 'How is everyone doing today?',
user_id: 1,
created_at: new Date('2025-01-10T11:30:00.000Z'),
updated_at: new Date('2025-01-10T11:30:00.000Z')
},
{
_id: 3,
body: 'Great to see you all here!',
user_id: 1,
created_at: new Date('2025-01-10T14:15:00.000Z'),
updated_at: new Date('2025-01-10T14:15:00.000Z')
},
// User 2 (Bob) - 2 messages
{
_id: 4,
body: 'Hi Alice! Doing well, thanks for asking.',
user_id: 2,
created_at: new Date('2025-01-10T11:35:00.000Z'),
updated_at: new Date('2025-01-10T11:35:00.000Z')
},
{
_id: 5,
body: 'Anyone up for a game later?',
user_id: 2,
created_at: new Date('2025-01-10T16:20:00.000Z'),
updated_at: new Date('2025-01-10T16:20:00.000Z')
},
// User 3 (Charlie) - 3 messages
{
_id: 6,
body: 'Count me in for the game!',
user_id: 3,
created_at: new Date('2025-01-10T16:25:00.000Z'),
updated_at: new Date('2025-01-10T16:25:00.000Z')
},
{
_id: 7,
body: 'What time works for everyone?',
user_id: 3,
created_at: new Date('2025-01-10T16:30:00.000Z'),
updated_at: new Date('2025-01-10T16:30:00.000Z')
},
{
_id: 8,
body: 'I can play around 8 PM',
user_id: 3,
created_at: new Date('2025-01-10T17:00:00.000Z'),
updated_at: new Date('2025-01-10T17:00:00.000Z')
},
// User 4 (Diana) - 2 messages
{
_id: 9,
body: '8 PM works for me too!',
user_id: 4,
created_at: new Date('2025-01-10T17:05:00.000Z'),
updated_at: new Date('2025-01-10T17:05:00.000Z')
},
{
_id: 10,
body: 'What game should we play?',
user_id: 4,
created_at: new Date('2025-01-10T17:10:00.000Z'),
updated_at: new Date('2025-01-10T17:10:00.000Z')
},
// User 5 (Evan) - 13 messages (including 10 additional ones)
{
_id: 11,
body: 'I suggest we try the new arcade game!',
user_id: 5,
created_at: new Date('2025-01-10T17:15:00.000Z'),
updated_at: new Date('2025-01-10T17:15:00.000Z')
},
{
_id: 12,
body: 'It has great multiplayer features',
user_id: 5,
created_at: new Date('2025-01-10T17:20:00.000Z'),
updated_at: new Date('2025-01-10T17:20:00.000Z')
},
{
_id: 13,
body: 'Perfect timing for a weekend session',
user_id: 5,
created_at: new Date('2025-01-10T18:00:00.000Z'),
updated_at: new Date('2025-01-10T18:00:00.000Z')
},
{
_id: 26,
body: 'Just finished setting up the game server!',
user_id: 5,
created_at: new Date('2025-01-10T20:00:00.000Z'),
updated_at: new Date('2025-01-10T20:00:00.000Z')
},
{
_id: 27,
body: 'Everyone should be able to connect now',
user_id: 5,
created_at: new Date('2025-01-10T20:05:00.000Z'),
updated_at: new Date('2025-01-10T20:05:00.000Z')
},
{
_id: 28,
body: 'I added some custom maps too',
user_id: 5,
created_at: new Date('2025-01-10T20:10:00.000Z'),
updated_at: new Date('2025-01-10T20:10:00.000Z')
},
{
_id: 29,
body: 'The graphics look amazing on this new version',
user_id: 5,
created_at: new Date('2025-01-10T20:15:00.000Z'),
updated_at: new Date('2025-01-10T20:15:00.000Z')
},
{
_id: 30,
body: 'Hope you all enjoy the new features',
user_id: 5,
created_at: new Date('2025-01-10T20:20:00.000Z'),
updated_at: new Date('2025-01-10T20:20:00.000Z')
},
{
_id: 31,
body: 'I also set up a leaderboard system',
user_id: 5,
created_at: new Date('2025-01-10T20:25:00.000Z'),
updated_at: new Date('2025-01-10T20:25:00.000Z')
},
{
_id: 32,
body: 'We can track high scores now',
user_id: 5,
created_at: new Date('2025-01-10T20:30:00.000Z'),
updated_at: new Date('2025-01-10T20:30:00.000Z')
},
{
_id: 33,
body: 'The game supports up to 8 players simultaneously',
user_id: 5,
created_at: new Date('2025-01-10T20:35:00.000Z'),
updated_at: new Date('2025-01-10T20:35:00.000Z')
},
{
_id: 34,
body: 'I tested it earlier and it runs smoothly',
user_id: 5,
created_at: new Date('2025-01-10T20:40:00.000Z'),
updated_at: new Date('2025-01-10T20:40:00.000Z')
},
{
_id: 35,
body: 'Cannot wait to see everyone online tonight!',
user_id: 5,
created_at: new Date('2025-01-10T20:45:00.000Z'),
updated_at: new Date('2025-01-10T20:45:00.000Z')
},
// User 6 (Fiona) - 2 messages
{
_id: 14,
body: 'Sounds like fun! I love arcade games.',
user_id: 6,
created_at: new Date('2025-01-10T18:05:00.000Z'),
updated_at: new Date('2025-01-10T18:05:00.000Z')
},
{
_id: 15,
body: 'Should I bring snacks?',
user_id: 6,
created_at: new Date('2025-01-10T18:10:00.000Z'),
updated_at: new Date('2025-01-10T18:10:00.000Z')
},
// User 7 (George) - 3 messages
{
_id: 16,
body: 'Snacks are always welcome!',
user_id: 7,
created_at: new Date('2025-01-10T18:15:00.000Z'),
updated_at: new Date('2025-01-10T18:15:00.000Z')
},
{
_id: 17,
body: 'I can bring some drinks',
user_id: 7,
created_at: new Date('2025-01-10T18:20:00.000Z'),
updated_at: new Date('2025-01-10T18:20:00.000Z')
},
{
_id: 18,
body: 'This is going to be awesome',
user_id: 7,
created_at: new Date('2025-01-10T19:00:00.000Z'),
updated_at: new Date('2025-01-10T19:00:00.000Z')
},
// User 8 (Helen) - 2 messages
{
_id: 19,
body: 'I agree! Cannot wait for the game night.',
user_id: 8,
created_at: new Date('2025-01-10T19:05:00.000Z'),
updated_at: new Date('2025-01-10T19:05:00.000Z')
},
{
_id: 20,
body: 'Should we set up a Discord call?',
user_id: 8,
created_at: new Date('2025-01-10T19:10:00.000Z'),
updated_at: new Date('2025-01-10T19:10:00.000Z')
},
// User 9 (Ian) - 3 messages
{
_id: 21,
body: 'Discord would be perfect for voice chat',
user_id: 9,
created_at: new Date('2025-01-10T19:15:00.000Z'),
updated_at: new Date('2025-01-10T19:15:00.000Z')
},
{
_id: 22,
body: 'I will create a server for us',
user_id: 9,
created_at: new Date('2025-01-10T19:20:00.000Z'),
updated_at: new Date('2025-01-10T19:20:00.000Z')
},
{
_id: 23,
body: 'Link will be shared in a few minutes',
user_id: 9,
created_at: new Date('2025-01-10T19:25:00.000Z'),
updated_at: new Date('2025-01-10T19:25:00.000Z')
},
// User 10 (Julia) - 2 messages
{
_id: 24,
body: 'Thanks Ian! You are the best.',
user_id: 10,
created_at: new Date('2025-01-10T19:30:00.000Z'),
updated_at: new Date('2025-01-10T19:30:00.000Z')
},
{
_id: 25,
body: 'See you all at 8 PM!',
user_id: 10,
created_at: new Date('2025-01-10T19:35:00.000Z'),
updated_at: new Date('2025-01-10T19:35:00.000Z')
},{
_id: 99,
body: 'You are a mean jerk, you shithead!',
user_id: 10,
created_at: new Date('2025-01-10T19:35:00.000Z'),
updated_at: new Date('2025-01-10T19:35:00.000Z')
}
]);
// Create indexes for better performance (equivalent to PostgreSQL indexes)
db.users.createIndex({ "name": 1 }, { unique: true });
db.users.createIndex({ "email": 1 }, { unique: true });
db.messages.createIndex({ "user_id": 1 });
db.messages.createIndex({ "created_at": 1 });
print("MongoDB test data setup completed successfully!");
print("Users collection: " + db.users.countDocuments());
print("Messages collection: " + db.messages.countDocuments());

View file

@ -1,221 +0,0 @@
from unittest.mock import MagicMock
import pytest
from arcade_core.errors import ToolExecutionError
from arcade_mcp_server import Context
from arcade_mcp_server.exceptions import RetryableToolError
from arcade_mongodb.tools.mongodb import aggregate_documents, count_documents, find_documents
from .conftest import TEST_MONGODB_CONNECTION_STRING
@pytest.fixture
def mock_context():
context = MagicMock(spec=Context)
context.get_secret = MagicMock(return_value=TEST_MONGODB_CONNECTION_STRING)
return context
@pytest.mark.asyncio
async def test_invalid_json_in_filter_dict(mock_context) -> None:
"""Test that invalid JSON in filter_dict returns a reasonable error message."""
with pytest.raises(RetryableToolError) as exc_info:
await find_documents(
mock_context,
database_name="test_database",
collection_name="users",
filter_dict='{"status": "active",}', # Invalid JSON - trailing comma
limit=1,
)
# Check that this is a JSON validation error
error_message = str(exc_info.value)
assert "Invalid JSON in filter_dict" in error_message
# Check that the developer message contains helpful information
assert "filter_dict" in exc_info.value.developer_message
assert "JSON" in exc_info.value.additional_prompt_content
# Check that the original JSON error is in the cause chain
assert exc_info.value.__cause__ is not None
@pytest.mark.asyncio
async def test_invalid_json_in_projection(mock_context) -> None:
"""Test that invalid JSON in projection returns a reasonable error message."""
with pytest.raises(RetryableToolError) as exc_info:
await find_documents(
mock_context,
database_name="test_database",
collection_name="users",
projection='{"name": 1, "email": 1,}', # Invalid JSON - trailing comma
limit=1,
)
# Check that this is a JSON validation error
error_message = str(exc_info.value)
assert "Invalid JSON in projection" in error_message
# Check that the error message is helpful
assert "projection" in exc_info.value.developer_message
assert "JSON" in exc_info.value.additional_prompt_content
# Check that the original JSON error is in the cause chain
assert exc_info.value.__cause__ is not None
@pytest.mark.asyncio
async def test_invalid_json_in_sort(mock_context) -> None:
"""Test that invalid JSON in sort returns a reasonable error message."""
with pytest.raises(RetryableToolError) as exc_info:
await find_documents(
mock_context,
database_name="test_database",
collection_name="users",
sort=['{"field": "name", "direction": 1,}'], # Invalid JSON - trailing comma
limit=1,
)
# Check that this is a JSON validation error
error_message = str(exc_info.value)
assert "Invalid JSON in sort" in error_message
# Check that the error message is helpful
assert "sort" in exc_info.value.developer_message
assert "JSON" in exc_info.value.additional_prompt_content
# Check that the original JSON error is in the cause chain
assert exc_info.value.__cause__ is not None
@pytest.mark.asyncio
async def test_invalid_json_in_count_filter(mock_context) -> None:
"""Test that invalid JSON in count_documents filter returns a reasonable error message."""
with pytest.raises(RetryableToolError) as exc_info:
await count_documents(
mock_context,
database_name="test_database",
collection_name="users",
filter_dict='{"status": "active",}', # Invalid JSON - trailing comma
)
# Check that this is a JSON validation error
error_message = str(exc_info.value)
assert "Invalid JSON in filter_dict" in error_message
# Check that the error message is helpful
assert "filter_dict" in exc_info.value.developer_message
assert "JSON" in exc_info.value.additional_prompt_content
# Check that the original JSON error is in the cause chain
assert exc_info.value.__cause__ is not None
@pytest.mark.asyncio
async def test_invalid_json_in_pipeline(mock_context) -> None:
"""Test that invalid JSON in aggregation pipeline returns a reasonable error message."""
with pytest.raises(RetryableToolError) as exc_info:
await aggregate_documents(
mock_context,
database_name="test_database",
collection_name="users",
pipeline=['{"$match": {"status": "active",}}'], # Invalid JSON - trailing comma
)
# Check that this is a JSON validation error
error_message = str(exc_info.value)
assert "Invalid JSON in pipeline" in error_message
# Check that the error message is helpful
assert "pipeline" in exc_info.value.developer_message
assert "JSON" in exc_info.value.additional_prompt_content
# Check that the original JSON error is in the cause chain
assert exc_info.value.__cause__ is not None
@pytest.mark.asyncio
async def test_malformed_json_string(mock_context) -> None:
"""Test various malformed JSON strings return reasonable error messages."""
test_cases = [
('{"unclosed": "string}', "Unterminated string"),
('{"missing_quotes": value}', "Expecting"),
('{missing_closing_brace: "value"}', "Expecting"),
('[{"array": "with"}, {"missing": }]', "Expecting"),
]
for invalid_json, expected_error_fragment in test_cases:
with pytest.raises(RetryableToolError) as exc_info:
await find_documents(
mock_context,
database_name="test_database",
collection_name="users",
filter_dict=invalid_json,
limit=1,
)
# Check that this is a JSON validation error
error_message = str(exc_info.value)
assert "Invalid JSON in filter_dict" in error_message
# Check that specific error details are included when expected
if expected_error_fragment:
assert (
expected_error_fragment in error_message
or expected_error_fragment in exc_info.value.developer_message
)
# Ensure helpful context is provided
assert "filter_dict" in exc_info.value.developer_message
assert "JSON" in exc_info.value.additional_prompt_content
assert "escaping" in exc_info.value.additional_prompt_content
# Check that the original JSON error is in the cause chain
assert exc_info.value.__cause__ is not None
@pytest.mark.asyncio
async def test_valid_json_does_not_error(mock_context) -> None:
"""Test that valid JSON does not raise JSON parsing errors."""
# This should not raise a JSON parsing error (might raise other errors, but not JSON-related)
try:
result = await find_documents(
mock_context,
database_name="test_database",
collection_name="users",
filter_dict='{"status": "active"}',
projection='{"name": 1, "_id": 0}',
sort=['{"field": "name", "direction": 1}'],
limit=1,
)
# If we get here, JSON parsing succeeded
assert isinstance(result, list)
except (ToolExecutionError, RetryableToolError) as e:
# If we get an error, it should not be about JSON parsing
# Check both the outer error and any nested error
error_message = str(e)
nested_message = str(e.__cause__) if e.__cause__ else ""
assert "Invalid JSON" not in error_message
assert "Invalid JSON" not in nested_message
@pytest.mark.asyncio
async def test_duplicate_keys_are_valid_json(mock_context) -> None:
"""Test that duplicate keys in JSON are valid (Python JSON allows this)."""
# This should NOT raise a JSON parsing error because duplicate keys are valid JSON
try:
result = await find_documents(
mock_context,
database_name="test_database",
collection_name="users",
filter_dict='{"duplicate": "key", "duplicate": "key"}', # Valid JSON - last value wins
limit=1,
)
# If we get here, JSON parsing succeeded (might get empty results, but no JSON error)
assert isinstance(result, list)
except (ToolExecutionError, RetryableToolError) as e:
# If we get an error, it should not be about JSON parsing
error_message = str(e)
nested_message = str(e.__cause__) if e.__cause__ else ""
assert "Invalid JSON" not in error_message
assert "Invalid JSON" not in nested_message

View file

@ -1,292 +0,0 @@
import json
from unittest.mock import MagicMock
import pytest
from arcade_mcp_server import Context
from arcade_mcp_server.exceptions import RetryableToolError
from arcade_mongodb.database_engine import DatabaseEngine
from arcade_mongodb.tools.mongodb import (
# UserStatus,
aggregate_documents,
count_documents,
discover_collections,
discover_databases,
find_documents,
get_collection_schema,
# update_user_status,
)
from .conftest import TEST_MONGODB_CONNECTION_STRING
@pytest.fixture
def mock_context():
context = MagicMock(spec=Context)
context.get_secret = MagicMock(return_value=TEST_MONGODB_CONNECTION_STRING)
return context
@pytest.mark.asyncio
async def test_discover_databases(mock_context) -> None:
databases = await discover_databases(mock_context)
assert isinstance(databases, list)
# Should not include system databases like admin, config, local
for db in databases:
assert db not in ["admin", "config", "local"]
@pytest.mark.asyncio
async def test_discover_collections(mock_context) -> None:
collections = await discover_collections(mock_context, "test_database")
assert "users" in collections
assert "messages" in collections
@pytest.mark.asyncio
async def test_get_collection_schema(mock_context) -> None:
schema_result = await get_collection_schema(
mock_context, "test_database", "users", sample_size=10
)
assert "schema" in schema_result
assert "total_documents_sampled" in schema_result
assert schema_result["total_documents_sampled"] == 10 # We have 10 users
schema = schema_result["schema"]
assert "_id" in schema
assert "name" in schema
assert "email" in schema
assert "password_hash" in schema
assert "status" in schema
assert "created_at" in schema
assert "updated_at" in schema
@pytest.mark.asyncio
async def test_find_documents_basic(mock_context) -> None:
# Find all users
result = await find_documents(
mock_context, database_name="test_database", collection_name="users", limit=10
)
assert len(result) == 10
# Parse JSON strings to check contents
docs = [json.loads(doc_str) for doc_str in result]
assert all("name" in doc for doc in docs)
assert all("email" in doc for doc in docs)
@pytest.mark.asyncio
async def test_find_documents_with_filter(mock_context) -> None:
# Find active users
result = await find_documents(
mock_context,
database_name="test_database",
collection_name="users",
filter_dict='{"status": "active"}',
limit=10,
)
assert len(result) == 10 # All users in dump are active
docs = [json.loads(doc_str) for doc_str in result]
assert all(doc["status"] == "active" for doc in docs)
@pytest.mark.asyncio
async def test_find_documents_with_projection(mock_context) -> None:
# Find users with only name and email
result = await find_documents(
mock_context,
database_name="test_database",
collection_name="users",
projection='{"name": 1, "email": 1, "_id": 0}',
limit=10,
)
assert len(result) == 10
docs = [json.loads(doc_str) for doc_str in result]
for doc in docs:
assert "name" in doc
assert "email" in doc
assert "_id" not in doc
assert "password_hash" not in doc
@pytest.mark.asyncio
async def test_find_documents_with_sort(mock_context) -> None:
# Find users sorted by _id descending
result = await find_documents(
mock_context,
database_name="test_database",
collection_name="users",
sort=['{"field": "_id", "direction": -1}'],
limit=3,
)
assert len(result) == 3
docs = [json.loads(doc_str) for doc_str in result]
ids = [doc["_id"] for doc in docs]
assert ids == [10, 9, 8] # Descending order
@pytest.mark.asyncio
async def test_count_documents(mock_context) -> None:
# Count all users
count = await count_documents(
mock_context, database_name="test_database", collection_name="users"
)
assert count == 10
# Count active users
active_count = await count_documents(
mock_context,
database_name="test_database",
collection_name="users",
filter_dict='{"status": "active"}',
)
assert active_count == 10
@pytest.mark.asyncio
async def test_aggregate_documents(mock_context) -> None:
# Aggregate to count users by status
pipeline = ['{"$group": {"_id": "$status", "count": {"$sum": 1}}}', '{"$sort": {"count": -1}}']
result = await aggregate_documents(
mock_context, database_name="test_database", collection_name="users", pipeline=pipeline
)
assert len(result) == 1 # Only active users
# Should be sorted by count descending
doc = json.loads(result[0])
assert doc["_id"] == "active"
assert doc["count"] == 10
@pytest.mark.asyncio
async def test_find_documents_with_skip_and_limit(mock_context) -> None:
# Test pagination
result1 = await find_documents(
mock_context,
database_name="test_database",
collection_name="users",
sort=['{"field": "name", "direction": 1}'],
limit=2,
skip=0,
)
result2 = await find_documents(
mock_context,
database_name="test_database",
collection_name="users",
sort=['{"field": "name", "direction": 1}'],
limit=2,
skip=2,
)
assert len(result1) == 2
assert len(result2) == 2
docs1 = [json.loads(doc_str) for doc_str in result1]
docs2 = [json.loads(doc_str) for doc_str in result2]
assert docs1[0]["name"] == "Alice"
assert docs1[1]["name"] == "Bob"
assert docs2[0]["name"] == "Charlie"
assert docs2[1]["name"] == "Diana"
@pytest.mark.asyncio
async def test_error_handling_invalid_database(mock_context) -> None:
# Test with non-existent database - should not error but return empty results
collections = await discover_collections(mock_context, "nonexistent_database")
assert collections == []
@pytest.mark.asyncio
async def test_error_handling_invalid_collection(mock_context) -> None:
# Test with non-existent collection
result = await find_documents(
mock_context,
database_name="test_database",
collection_name="nonexistent_collection",
limit=10,
)
assert result == []
@pytest.mark.asyncio
async def test_sanitize_query_params() -> None:
# Test parameter validation
with pytest.raises(RetryableToolError) as e:
DatabaseEngine.sanitize_query_params("", "users", {}, None, None, 10, 0)
assert "Database name is required" in str(e.value)
with pytest.raises(RetryableToolError) as e:
DatabaseEngine.sanitize_query_params("test_db", "", {}, None, None, 10, 0)
assert "Collection name is required" in str(e.value)
with pytest.raises(RetryableToolError) as e:
DatabaseEngine.sanitize_query_params(
"test_db", "users", {}, None, None, 2000, 0
) # Too high limit
assert "Limit is too high" in str(e.value)
# @pytest.mark.asyncio
# async def test_update_user_status_success(mock_context) -> None:
# """Test successful user status update."""
# # First, find a user to update
# users = await find_documents(
# mock_context, database_name="test_database", collection_name="users", limit=1
# )
# assert len(users) > 0
# user_doc = json.loads(users[0])
# user_id = user_doc["_id"]
# # Update user status to inactive
# result = await update_user_status(
# mock_context,
# database_name="test_database",
# collection_name="users",
# user_id=user_id,
# status=UserStatus.INACTIVE,
# )
# assert result["success"] is True
# assert result["user_id"] == user_id
# assert result["new_status"] == "inactive"
# assert result["matched_count"] == 1
# assert result["modified_count"] == 1
# # Verify the update by finding the user again
# # Convert user_id to int since the test data uses integer IDs
# user_id_int = int(user_id)
# updated_users = await find_documents(
# mock_context,
# database_name="test_database",
# collection_name="users",
# filter_dict=f'{{"_id": {user_id_int}}}',
# limit=1,
# )
# assert len(updated_users) == 1
# updated_user = json.loads(updated_users[0])
# assert updated_user["status"] == "inactive"
# @pytest.mark.asyncio
# async def test_update_user_status_user_not_found(mock_context) -> None:
# """Test updating status for non-existent user."""
# result = await update_user_status(
# mock_context,
# database_name="test_database",
# collection_name="users",
# user_id="nonexistent_user_id",
# status=UserStatus.BANNED,
# )
# assert result["success"] is False
# assert "No user found with _id" in result["message"]
# assert result["matched_count"] == 0
# assert result["modified_count"] == 0

View file

@ -1,12 +0,0 @@
#!/bin/bash
# install mongosh to load sample data
sudo apt-get update
sudo apt-get install -y wget gnupg
wget -qO - https://www.mongodb.org/static/pgp/server-6.0.asc | sudo apt-key add -
echo "deb [ arch=amd64,arm64 ] https://repo.mongodb.org/apt/ubuntu jammy/mongodb-org/6.0 multiverse" | sudo tee /etc/apt/sources.list.d/mongodb-org-6.0.list
sudo apt-get update
sudo apt-get install -y mongodb-mongosh
# Run mongodb container
docker run -d --name some-mongodb-server -p 27017:27017 mongo

View file

@ -1,248 +0,0 @@
from unittest.mock import MagicMock
import pytest
from arcade_mcp_server import Context
from arcade_mcp_server.exceptions import RetryableToolError
from arcade_mongodb.tools.mongodb import aggregate_documents, count_documents, find_documents
from .conftest import TEST_MONGODB_CONNECTION_STRING
@pytest.fixture
def mock_context():
context = MagicMock(spec=Context)
context.get_secret = MagicMock(return_value=TEST_MONGODB_CONNECTION_STRING)
return context
@pytest.mark.asyncio
async def test_filter_dict_blocks_set_operation(mock_context) -> None:
"""Test that $set operation in filter_dict is blocked."""
with pytest.raises(RetryableToolError) as exc_info:
await find_documents(
mock_context,
database_name="test_database",
collection_name="users",
filter_dict='{"$set": {"status": "modified"}}', # Write operation
limit=1,
)
error_message = str(exc_info.value)
assert "Write operation '$set' not allowed in filter_dict" in error_message
assert "$set" in exc_info.value.developer_message
assert "Only read operations are allowed" in exc_info.value.developer_message
@pytest.mark.asyncio
async def test_filter_dict_blocks_update_operations(mock_context) -> None:
"""Test that various update operations in filter_dict are blocked."""
update_ops = ["$inc", "$unset", "$push", "$pull", "$rename", "$currentDate"]
for op in update_ops:
with pytest.raises(RetryableToolError) as exc_info:
await find_documents(
mock_context,
database_name="test_database",
collection_name="users",
filter_dict=f'{{"{op}": {{"field": "value"}}}}',
limit=1,
)
error_message = str(exc_info.value)
assert f"Write operation '{op}' not allowed in filter_dict" in error_message
@pytest.mark.asyncio
async def test_projection_blocks_write_operations(mock_context) -> None:
"""Test that write operations in projection are blocked."""
with pytest.raises(RetryableToolError) as exc_info:
await find_documents(
mock_context,
database_name="test_database",
collection_name="users",
projection='{"$set": {"modified": true}, "name": 1}', # Write operation in projection
limit=1,
)
error_message = str(exc_info.value)
assert "Write operation '$set' not allowed in projection" in error_message
@pytest.mark.asyncio
async def test_sort_blocks_write_operations(mock_context) -> None:
"""Test that write operations in sort are blocked."""
with pytest.raises(RetryableToolError) as exc_info:
await find_documents(
mock_context,
database_name="test_database",
collection_name="users",
sort=['{"field": "name", "direction": 1, "$inc": {"counter": 1}}'], # Write op in sort
limit=1,
)
error_message = str(exc_info.value)
assert "Write operation '$inc' not allowed in sort[0]" in error_message
@pytest.mark.asyncio
async def test_count_filter_blocks_write_operations(mock_context) -> None:
"""Test that write operations in count filter are blocked."""
with pytest.raises(RetryableToolError) as exc_info:
await count_documents(
mock_context,
database_name="test_database",
collection_name="users",
filter_dict='{"status": "active", "$unset": {"password": ""}}', # Write operation
)
error_message = str(exc_info.value)
assert "Write operation '$unset' not allowed in filter_dict" in error_message
@pytest.mark.asyncio
async def test_aggregation_pipeline_blocks_out_stage(mock_context) -> None:
"""Test that $out stage in aggregation pipeline is blocked."""
with pytest.raises(RetryableToolError) as exc_info:
await aggregate_documents(
mock_context,
database_name="test_database",
collection_name="users",
pipeline=[
'{"$match": {"status": "active"}}',
'{"$out": "output_collection"}', # Write stage
],
)
error_message = str(exc_info.value)
assert "Write stage '$out' not allowed in pipeline" in error_message
@pytest.mark.asyncio
async def test_aggregation_pipeline_blocks_merge_stage(mock_context) -> None:
"""Test that $merge stage in aggregation pipeline is blocked."""
with pytest.raises(RetryableToolError) as exc_info:
await aggregate_documents(
mock_context,
database_name="test_database",
collection_name="users",
pipeline=[
'{"$match": {"status": "active"}}',
'{"$merge": {"into": "target_collection"}}', # Write stage
],
)
error_message = str(exc_info.value)
assert "Write stage '$merge' not allowed in pipeline" in error_message
@pytest.mark.asyncio
async def test_where_operator_blocked(mock_context) -> None:
"""Test that $where operator is blocked for security reasons."""
with pytest.raises(RetryableToolError) as exc_info:
await find_documents(
mock_context,
database_name="test_database",
collection_name="users",
filter_dict='{"$where": "this.name == \'admin\'"}', # JavaScript execution
limit=1,
)
error_message = str(exc_info.value)
assert "JavaScript execution operator '$where' not allowed in filter_dict" in error_message
assert (
"JavaScript execution is not allowed for security reasons"
in exc_info.value.developer_message
)
@pytest.mark.asyncio
async def test_nested_write_operations_blocked(mock_context) -> None:
"""Test that nested write operations are blocked."""
with pytest.raises(RetryableToolError) as exc_info:
await find_documents(
mock_context,
database_name="test_database",
collection_name="users",
filter_dict='{"status": "active", "nested": {"$set": {"field": "value"}}}', # Nested write op
limit=1,
)
error_message = str(exc_info.value)
assert "Write operation '$set' not allowed in filter_dict" in error_message
assert "nested.$set" in exc_info.value.developer_message # Should show the path
@pytest.mark.asyncio
async def test_valid_read_operations_allowed(mock_context) -> None:
"""Test that valid read operations are allowed."""
# These should not raise write operation errors
try:
# Test query operators
result = await find_documents(
mock_context,
database_name="test_database",
collection_name="users",
filter_dict='{"status": {"$in": ["active", "inactive"]}, "name": {"$regex": "^A"}}',
projection='{"name": 1, "email": 1, "_id": 0}',
sort=['{"field": "name", "direction": 1}'],
limit=1,
)
assert isinstance(result, list)
# Test aggregation pipeline with read-only stages
pipeline_result = await aggregate_documents(
mock_context,
database_name="test_database",
collection_name="users",
pipeline=[
'{"$match": {"status": "active"}}',
'{"$group": {"_id": "$status", "count": {"$sum": 1}}}',
'{"$sort": {"count": -1}}',
],
)
assert isinstance(pipeline_result, list)
except RetryableToolError as e:
# If we get an error, it should not be about write operations
error_message = str(e)
nested_message = str(e.__cause__) if e.__cause__ else ""
assert "Write operation" not in error_message
assert "Write stage" not in error_message
assert "Write operation" not in nested_message
assert "Write stage" not in nested_message
@pytest.mark.asyncio
async def test_array_write_operations_blocked(mock_context) -> None:
"""Test that array write operations are blocked."""
array_write_ops = ["$addToSet", "$pop", "$pull", "$push", "$pullAll"]
for op in array_write_ops:
with pytest.raises(RetryableToolError) as exc_info:
await find_documents(
mock_context,
database_name="test_database",
collection_name="users",
filter_dict=f'{{"{op}": {{"tags": "new_tag"}}}}',
limit=1,
)
error_message = str(exc_info.value)
assert f"Write operation '{op}' not allowed in filter_dict" in error_message
@pytest.mark.asyncio
async def test_aggregation_stage_content_validated(mock_context) -> None:
"""Test that content within aggregation stages is also validated for write operations."""
with pytest.raises(RetryableToolError) as exc_info:
await aggregate_documents(
mock_context,
database_name="test_database",
collection_name="users",
pipeline=[
'{"$match": {"status": "active", "$set": {"modified": true}}}' # Write op inside $match
],
)
error_message = str(exc_info.value)
assert "Write operation '$set' not allowed in pipeline[0].$match" in error_message

View file

@ -1,53 +0,0 @@
.PHONY: help
help:
@echo "🛠️ github Commands:\n"
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
.PHONY: install
install: ## Install the uv environment and install all packages with dependencies
@echo "🚀 Creating virtual environment and installing all packages using uv"
@uv sync --active --all-extras --no-sources
@uv run pre-commit install
@echo "✅ All packages and dependencies installed via uv"
.PHONY: install-local
install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources
@echo "🚀 Creating virtual environment and installing all packages using uv"
@uv sync --active --all-extras
@uv run pre-commit install
@echo "✅ All packages and dependencies installed via uv"
.PHONY: build
build: clean-build ## Build wheel file using poetry
@echo "🚀 Creating wheel file"
uv build
.PHONY: clean-build
clean-build: ## clean build artifacts
@echo "🗑️ Cleaning dist directory"
rm -rf dist
.PHONY: test
test: ## Test the code with pytest
@echo "🚀 Testing code: Running pytest"
@uv run pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml
.PHONY: coverage
coverage: ## Generate coverage report
@echo "coverage report"
coverage report
@echo "Generating coverage report"
coverage html
.PHONY: bump-version
bump-version: ## Bump the version in the pyproject.toml file by a patch version
@echo "🚀 Bumping version in pyproject.toml"
uv version --bump patch
.PHONY: check
check: ## Run code quality tools.
@echo "🚀 Linting code: Running pre-commit"
@uv run pre-commit run -a
@echo "🚀 Static type checking: Running mypy"
@uv run mypy --config-file=pyproject.toml

View file

@ -1,29 +0,0 @@
import sys
from typing import cast
from arcade_mcp_server import MCPApp
from arcade_mcp_server.mcp_app import TransportType
import arcade_postgres
app = MCPApp(
name="PostgreSQL",
instructions=(
"Use this server when you need to interact with PostgreSQL to help users "
"query, explore, and manage their PostgreSQL databases."
),
)
app.add_tools_from_module(arcade_postgres)
def main() -> None:
transport = sys.argv[1] if len(sys.argv) > 1 else "stdio"
host = sys.argv[2] if len(sys.argv) > 2 else "127.0.0.1"
port = int(sys.argv[3]) if len(sys.argv) > 3 else 8000
app.run(transport=cast(TransportType, transport), host=host, port=port)
if __name__ == "__main__":
main()

View file

@ -1,180 +0,0 @@
from typing import Any, ClassVar
from urllib.parse import urlparse
from arcade_mcp_server.exceptions import RetryableToolError
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
MAX_ROWS_RETURNED = 1000
TEST_QUERY = "SELECT 1"
class DatabaseEngine:
_instance: ClassVar[None] = None
_engines: ClassVar[dict[str, AsyncEngine]] = {}
@classmethod
async def get_instance(cls, connection_string: str) -> AsyncEngine:
parsed_url = urlparse(connection_string)
# TODO: something strange with sslmode= and friends
# query_params = parse_qs(parsed_url.query)
# query_params = {
# k: v[0] for k, v in query_params.items()
# } # assume one value allowed for each query param
async_connection_string = f"{parsed_url.scheme.replace('postgresql', 'postgresql+asyncpg')}://{parsed_url.netloc}{parsed_url.path}"
key = f"{async_connection_string}"
if key not in cls._engines:
cls._engines[key] = create_async_engine(async_connection_string)
# try a simple query to see if the connection is valid
try:
async with cls._engines[key].connect() as connection:
await connection.execute(text(TEST_QUERY))
return cls._engines[key]
except Exception:
await cls._engines[key].dispose()
# try again
try:
async with cls._engines[key].connect() as connection:
await connection.execute(text(TEST_QUERY))
return cls._engines[key]
except Exception as e:
raise RetryableToolError(
f"Connection failed: {e}",
developer_message="Connection to postgres failed.",
additional_prompt_content="Check the connection string and try again.",
) from e
@classmethod
async def get_engine(cls, connection_string: str) -> Any:
engine = await cls.get_instance(connection_string)
class ConnectionContextManager:
def __init__(self, engine: AsyncEngine) -> None:
self.engine = engine
async def __aenter__(self) -> AsyncEngine:
return self.engine
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
# Connection cleanup is handled by the async context manager
pass
return ConnectionContextManager(engine)
@classmethod
async def cleanup(cls) -> None:
"""Clean up all cached engines. Call this when shutting down."""
for engine in cls._engines.values():
await engine.dispose()
cls._engines.clear()
@classmethod
def clear_cache(cls) -> None:
"""Clear the engine cache without disposing engines. Use with caution."""
cls._engines.clear()
@classmethod
def sanitize_query( # noqa: C901
cls,
select_clause: str,
from_clause: str,
limit: int,
offset: int,
join_clause: str | None,
where_clause: str | None,
having_clause: str | None,
group_by_clause: str | None,
order_by_clause: str | None,
with_clause: str | None,
) -> tuple[str, dict[str, Any]]:
# Remove the leading keywords from the clauses if they are present
if select_clause.strip().split(" ")[0].upper() == "SELECT":
select_clause = select_clause.strip()[6:]
if from_clause.strip().split(" ")[0].upper() == "FROM":
from_clause = from_clause.strip()[4:]
if join_clause and join_clause.strip().split(" ")[0].upper() == "JOIN":
join_clause = join_clause.strip()[4:]
if where_clause and where_clause.strip().split(" ")[0].upper() == "WHERE":
where_clause = where_clause.strip()[5:]
if group_by_clause and group_by_clause.strip().split(" ")[0].upper() == "GROUP BY":
group_by_clause = group_by_clause.strip()[8:]
if order_by_clause and order_by_clause.strip().split(" ")[0].upper() == "ORDER BY":
order_by_clause = order_by_clause.strip()[8:]
if having_clause and having_clause.strip().split(" ")[0].upper() == "HAVING":
having_clause = having_clause.strip()[6:]
first_select_word = select_clause.strip().split(" ")[0].upper()
if first_select_word in [
"INSERT",
"UPDATE",
"DELETE",
"CREATE",
"ALTER",
"DROP",
"TRUNCATE",
"REINDEX",
"VACUUM",
"ANALYZE",
"COMMENT",
]:
raise RetryableToolError(
"Only SELECT queries are allowed.",
)
if select_clause.strip() == "*":
raise RetryableToolError(
"Do not use * in the select clause. Use a comma separated list of columns you wish to return.",
)
if limit > MAX_ROWS_RETURNED:
raise RetryableToolError(
f"Limit is too high. Maximum is {MAX_ROWS_RETURNED}.",
)
if offset < 0:
raise RetryableToolError(
"Offset must be greater than or equal to 0.",
developer_message="Offset must be greater than or equal to 0.",
)
if limit <= 0:
raise RetryableToolError(
"Limit must be greater than 0.",
developer_message="Limit must be greater than 0.",
)
# Build query with identifiers directly interpolated, but use parameters for values
parts = []
if with_clause:
parts.append(f"WITH {with_clause}")
parts.append(f"SELECT {select_clause} FROM {from_clause}") # noqa: S608
if join_clause:
parts.append(f"JOIN {join_clause}")
if where_clause:
parts.append(f"WHERE {where_clause}")
if group_by_clause:
parts.append(f"GROUP BY {group_by_clause}")
if having_clause:
parts.append(f"HAVING {having_clause}")
if order_by_clause:
parts.append(f"ORDER BY {order_by_clause}")
parts.append("LIMIT :limit OFFSET :offset")
query = " ".join(parts)
# Only use parameters for values, not identifiers
parameters = {
"limit": limit,
"offset": offset,
}
return query, parameters

View file

@ -1,300 +0,0 @@
from typing import Annotated, Any
from arcade_mcp_server import Context, tool
from arcade_mcp_server.exceptions import RetryableToolError
from arcade_mcp_server.metadata import Behavior, Operation, ToolMetadata
from sqlalchemy import inspect, text
from sqlalchemy.ext.asyncio import AsyncEngine
from ..database_engine import MAX_ROWS_RETURNED, DatabaseEngine
@tool(
requires_secrets=["POSTGRES_DATABASE_CONNECTION_STRING"],
metadata=ToolMetadata(
behavior=Behavior(
operations=[Operation.READ],
read_only=True,
destructive=False,
idempotent=True,
open_world=True,
),
),
)
async def discover_schemas(
context: Context,
) -> list[str]:
"""Discover all the schemas in the postgres database."""
async with await DatabaseEngine.get_engine(
context.get_secret("POSTGRES_DATABASE_CONNECTION_STRING")
) as engine:
schemas = await _get_schemas(engine)
return schemas
@tool(
requires_secrets=["POSTGRES_DATABASE_CONNECTION_STRING"],
metadata=ToolMetadata(
behavior=Behavior(
operations=[Operation.READ],
read_only=True,
destructive=False,
idempotent=True,
open_world=True,
),
),
)
async def discover_tables(
context: Context,
schema_name: Annotated[
str, "The database schema to discover tables in (default value: 'public')"
] = "public",
) -> list[str]:
"""Discover all the tables in the postgres database when the list of tables is not known.
ALWAYS use this tool before any other tool that requires a table name.
"""
async with await DatabaseEngine.get_engine(
context.get_secret("POSTGRES_DATABASE_CONNECTION_STRING")
) as engine:
tables = await _get_tables(engine, schema_name)
return tables
@tool(
requires_secrets=["POSTGRES_DATABASE_CONNECTION_STRING"],
metadata=ToolMetadata(
behavior=Behavior(
operations=[Operation.READ],
read_only=True,
destructive=False,
idempotent=True,
open_world=True,
),
),
)
async def get_table_schema(
context: Context,
schema_name: Annotated[str, "The database schema to get the table schema of"],
table_name: Annotated[str, "The table to get the schema of"],
) -> list[str]:
"""
Get the schema/structure of a postgres table in the postgres database when the schema is not known, and the name of the table is provided.
This tool should ALWAYS be used before executing any query. All tables in the query must be discovered first using the <DiscoverTables> tool.
"""
async with await DatabaseEngine.get_engine(
context.get_secret("POSTGRES_DATABASE_CONNECTION_STRING")
) as engine:
return await _get_table_schema(engine, schema_name, table_name)
@tool(
requires_secrets=["POSTGRES_DATABASE_CONNECTION_STRING"],
metadata=ToolMetadata(
behavior=Behavior(
operations=[Operation.READ],
read_only=True,
destructive=False,
idempotent=True,
open_world=True,
),
),
)
async def execute_select_query(
context: Context,
select_clause: Annotated[
str,
"This is the part of the SQL query that comes after the SELECT keyword wish a comma separated list of columns you wish to return. Do not include the SELECT keyword.",
],
from_clause: Annotated[
str,
"This is the part of the SQL query that comes after the FROM keyword. Do not include the FROM keyword.",
],
limit: Annotated[
int,
"The maximum number of rows to return. This is the LIMIT clause of the query. Default: 100.",
] = 100,
offset: Annotated[
int, "The number of rows to skip. This is the OFFSET clause of the query. Default: 0."
] = 0,
join_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the JOIN keyword. Do not include the JOIN keyword. If no join is needed, leave this blank.",
] = None,
where_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the WHERE keyword. Do not include the WHERE keyword. If no where clause is needed, leave this blank.",
] = None,
having_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the HAVING keyword. Do not include the HAVING keyword. If no having clause is needed, leave this blank.",
] = None,
group_by_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the GROUP BY keyword. Do not include the GROUP BY keyword. If no group by clause is needed, leave this blank.",
] = None,
order_by_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the ORDER BY keyword. Do not include the ORDER BY keyword. If no order by clause is needed, leave this blank.",
] = None,
with_clause: Annotated[
str | None,
"This is the part of the SQL query that comes after the WITH keyword when basing the query on a virtual table. If no WITH clause is needed, leave this blank.",
] = None,
) -> list[str]:
"""
You have a connection to a postgres database.
Execute a SELECT query and return the results against the postgres database. No other queries (INSERT, UPDATE, DELETE, etc.) are allowed.
ONLY use this tool if you have already loaded the schema of the tables you need to query. Use the <GetTableSchema> tool to load the schema if not already known.
The final query will be constructed as follows:
SELECT {select_query_part} FROM {from_clause} JOIN {join_clause} WHERE {where_clause} HAVING {having_clause} ORDER BY {order_by_clause} LIMIT {limit} OFFSET {offset}
When running queries, follow these rules which will help avoid errors:
* Never "select *" from a table. Always select the columns you need.
* Always order your results by the most important columns first. If you aren't sure, order by the primary key.
* Always use case-insensitive queries to match strings in the query.
* Always trim strings in the query.
* Prefer LIKE queries over direct string matches or regex queries.
* Only join on columns that are indexed or the primary key. Do not join on arbitrary columns.
"""
async with await DatabaseEngine.get_engine(
context.get_secret("POSTGRES_DATABASE_CONNECTION_STRING")
) as engine:
try:
return await _execute_query(
engine,
select_clause=select_clause,
from_clause=from_clause,
limit=limit,
offset=offset,
join_clause=join_clause,
where_clause=where_clause,
having_clause=having_clause,
group_by_clause=group_by_clause,
order_by_clause=order_by_clause,
with_clause=with_clause,
)
except Exception as e:
raise RetryableToolError(
f"Query failed: {e}",
developer_message=f"Query failed with parameters: select_clause={select_clause}, from_clause={from_clause}, limit={limit}, offset={offset}, join_clause={join_clause}, where_clause={where_clause}, having_clause={having_clause}, order_by_clause={order_by_clause}, with_clause={with_clause}.",
additional_prompt_content="Load the database schema <GetTableSchema> or use the <DiscoverTables> tool to discover the tables and try again.",
retry_after_ms=10,
) from e
async def _get_schemas(engine: AsyncEngine) -> list[str]:
"""Get all the schemas in the database"""
async with engine.connect() as conn:
def get_schema_names(sync_conn: Any) -> list[str]:
return list(inspect(sync_conn).get_schema_names())
schemas: list[str] = await conn.run_sync(get_schema_names)
schemas = [schema for schema in schemas if schema != "information_schema"]
return schemas
async def _get_tables(engine: AsyncEngine, schema_name: str) -> list[str]:
"""Get all the tables in the database"""
async with engine.connect() as conn:
def get_schema_names(sync_conn: Any) -> list[str]:
return list(inspect(sync_conn).get_schema_names())
schemas: list[str] = await conn.run_sync(get_schema_names)
tables = []
for schema in schemas:
if schema == schema_name:
def get_table_names(sync_conn: Any, s: str = schema) -> list[str]:
return list(inspect(sync_conn).get_table_names(schema=s))
these_tables = await conn.run_sync(get_table_names)
tables.extend(these_tables)
tables.sort()
return tables
async def _get_table_schema(engine: AsyncEngine, schema_name: str, table_name: str) -> list[str]:
"""Get the schema of a table"""
async with engine.connect() as connection:
def get_columns(sync_conn: Any, t: str = table_name, s: str = schema_name) -> list[Any]:
return list(inspect(sync_conn).get_columns(t, s))
columns_table = await connection.run_sync(get_columns)
# Get primary key information
pk_constraint = await connection.run_sync(
lambda sync_conn: inspect(sync_conn).get_pk_constraint(table_name, schema_name)
)
primary_keys = set(pk_constraint.get("constrained_columns", []))
# Get index information
indexes = await connection.run_sync(
lambda sync_conn: inspect(sync_conn).get_indexes(table_name, schema_name)
)
indexed_columns = set()
for index in indexes:
indexed_columns.update(index.get("column_names", []))
results = []
for column in columns_table:
column_name = column["name"]
column_type = column["type"].python_type.__name__
# Build column description
description = f"{column_name}: {column_type}"
# Add primary key indicator
if column_name in primary_keys:
description += " (PRIMARY KEY)"
# Add index indicator
if column_name in indexed_columns:
description += " (INDEXED)"
results.append(description)
return results[:MAX_ROWS_RETURNED]
async def _execute_query(
engine: AsyncEngine,
select_clause: str,
from_clause: str,
limit: int,
offset: int,
join_clause: str | None,
where_clause: str | None,
having_clause: str | None,
group_by_clause: str | None,
order_by_clause: str | None,
with_clause: str | None,
) -> list[str]:
"""Execute a query and return the results."""
async with engine.connect() as connection:
query, parameters = DatabaseEngine.sanitize_query(
select_clause=select_clause,
from_clause=from_clause,
limit=limit,
offset=offset,
join_clause=join_clause,
where_clause=where_clause,
having_clause=having_clause,
group_by_clause=group_by_clause,
order_by_clause=order_by_clause,
with_clause=with_clause,
)
print(f"Query: {query}")
print(f"Parameters: {parameters}")
result = await connection.execute(text(query), parameters)
rows = result.fetchall()
results = [str(row) for row in rows]
return results[:MAX_ROWS_RETURNED]

View file

@ -1,94 +0,0 @@
import arcade_postgres
from arcade_core import ToolCatalog
from arcade_evals import (
BinaryCritic,
EvalRubric,
EvalSuite,
ExpectedToolCall,
SimilarityCritic,
tool_eval,
)
from arcade_postgres.tools.postgres import (
discover_tables,
execute_query,
get_table_schema,
)
# Evaluation rubric
rubric = EvalRubric(
fail_threshold=0.85,
warn_threshold=0.95,
)
catalog = ToolCatalog()
catalog.add_module(arcade_postgres)
@tool_eval()
def sql_eval_suite() -> EvalSuite:
suite = EvalSuite(
name="sql Tools Evaluation",
system_message=(
"You are an AI assistant with access to sql tools. "
"Use them to help the user with their tasks."
),
catalog=catalog,
rubric=rubric,
)
suite.add_case(
name="Get user by id (schema known)",
user_message="Tell me the name and email of user #1 in my database. The table 'users' has the following schema: id: int, name: str, email: str, password_hash: str, created_at: datetime, updated_at: datetime",
expected_tool_calls=[
ExpectedToolCall(
func=execute_query, args={"query": "SELECT name, email FROM users WHERE id = 1"}
)
],
rubric=rubric,
critics=[SimilarityCritic(critic_field="query", weight=1.0)],
)
suite.add_case(
name="Discover tables",
user_message="What tables are in my database?",
expected_tool_calls=[
ExpectedToolCall(func=discover_tables, args={}),
],
rubric=rubric,
)
suite.add_case(
name="Get table schema (schema provided)",
user_message="What columns are in the table 'public.users' in my database?",
expected_tool_calls=[
ExpectedToolCall(
func=get_table_schema, args={"schema_name": "public", "table_name": "users"}
),
],
rubric=rubric,
critics=[
BinaryCritic(critic_field="schema_name", weight=0.5),
BinaryCritic(critic_field="table_name", weight=0.5),
],
)
suite.add_case(
name="Get table schema (schema not provided)",
user_message="What columns are in the table 'users' in my database?",
additional_messages=[
{"role": "user", "content": "When not provided, the schema is 'public'."}
],
expected_tool_calls=[
ExpectedToolCall(
func=get_table_schema, args={"schema_name": "public", "table_name": "users"}
),
],
rubric=rubric,
critics=[
BinaryCritic(critic_field="schema_name", weight=0.5),
BinaryCritic(critic_field="table_name", weight=0.5),
],
)
return suite

View file

@ -1,65 +0,0 @@
[build-system]
requires = [ "hatchling",]
build-backend = "hatchling.build"
[project]
name = "arcade_postgres"
version = "0.5.0"
description = "Tools to query and explore a postgres database"
requires-python = ">=3.10"
dependencies = [
"arcade-mcp-server>=1.17.0,<2.0.0",
"psycopg2-binary>=2.9.10",
"pydantic>=2.11.7",
"sqlalchemy>=2.0.41",
"psycopg2-binary>=2.9.10",
"asyncpg>=0.30.0",
"greenlet>=3.2.3",
]
[[project.authors]]
name = "evantahler"
email = "support@arcade.dev"
[project.optional-dependencies]
dev = [
"arcade-mcp[all]>=1.2.0,<2.0.0",
"pytest>=8.3.0,<8.4.0",
"pytest-cov>=4.0.0,<4.1.0",
"pytest-mock>=3.11.1,<3.12.0",
"pytest-asyncio>=0.24.0,<0.25.0",
"mypy>=1.5.1,<1.6.0",
"pre-commit>=3.4.0,<3.5.0",
"tox>=4.11.1,<4.12.0",
"ruff>=0.7.4,<0.8.0",
]
[project.scripts]
arcade-postgres = "arcade_postgres.__main__:main"
arcade_postgres = "arcade_postgres.__main__:main"
# Use local path sources for arcade libs when working locally
[tool.uv.sources]
arcade-mcp = { path = "../../", editable = true }
arcade-mcp-server = { path = "../../libs/arcade-mcp-server/", editable = true }
[tool.mypy]
files = [ "arcade_postgres/**/*.py",]
python_version = "3.10"
disallow_untyped_defs = "True"
disallow_any_unimported = "True"
no_implicit_optional = "True"
check_untyped_defs = "True"
warn_return_any = "True"
warn_unused_ignores = "True"
show_error_codes = "True"
ignore_missing_imports = "True"
[tool.pytest.ini_options]
testpaths = [ "tests",]
asyncio_default_fixture_loop_scope = "function"
[tool.coverage.report]
skip_empty = true
[tool.hatch.build.targets.wheel]
packages = [ "arcade_postgres",]

View file

@ -1,399 +0,0 @@
DROP TABLE IF EXISTS "public"."messages";
-- This script only contains the table creation statements and does not fully represent the table in the database. Do not use it as a backup.
-- Sequence and defined type
CREATE SEQUENCE IF NOT EXISTS messages_id_seq;
-- Table Definition
CREATE TABLE "public"."messages" (
"id" int4 NOT NULL DEFAULT nextval('messages_id_seq'::regclass),
"body" text NOT NULL,
"user_id" int4 NOT NULL,
"created_at" timestamp NOT NULL DEFAULT now(),
"updated_at" timestamp NOT NULL DEFAULT now(),
PRIMARY KEY ("id")
);
DROP TABLE IF EXISTS "public"."users";
-- This script only contains the table creation statements and does not fully represent the table in the database. Do not use it as a backup.
-- Sequence and defined type
CREATE SEQUENCE IF NOT EXISTS users_id_seq;
-- Table Definition
CREATE TABLE "public"."users" (
"id" int4 NOT NULL DEFAULT nextval('users_id_seq'::regclass),
"name" varchar(256) NOT NULL,
"email" text NOT NULL,
"password_hash" text NOT NULL,
"created_at" timestamp NOT NULL DEFAULT now(),
"updated_at" timestamp NOT NULL DEFAULT now(),
"status" varchar,
PRIMARY KEY ("id")
);
INSERT INTO "public"."messages" (
"id",
"body",
"user_id",
"created_at",
"updated_at"
)
VALUES -- User 1 (Alice) - 3 messages
(
1,
'Hello everyone!',
1,
'2025-01-10 10:00:00.000000',
'2025-01-10 10:00:00.000000'
),
(
2,
'How is everyone doing today?',
1,
'2025-01-10 11:30:00.000000',
'2025-01-10 11:30:00.000000'
),
(
3,
'Great to see you all here!',
1,
'2025-01-10 14:15:00.000000',
'2025-01-10 14:15:00.000000'
),
-- User 2 (Bob) - 2 messages
(
4,
'Hi Alice! Doing well, thanks for asking.',
2,
'2025-01-10 11:35:00.000000',
'2025-01-10 11:35:00.000000'
),
(
5,
'Anyone up for a game later?',
2,
'2025-01-10 16:20:00.000000',
'2025-01-10 16:20:00.000000'
),
-- User 3 (Charlie) - 3 messages
(
6,
'Count me in for the game!',
3,
'2025-01-10 16:25:00.000000',
'2025-01-10 16:25:00.000000'
),
(
7,
'What time works for everyone?',
3,
'2025-01-10 16:30:00.000000',
'2025-01-10 16:30:00.000000'
),
(
8,
'I can play around 8 PM',
3,
'2025-01-10 17:00:00.000000',
'2025-01-10 17:00:00.000000'
),
-- User 4 (Diana) - 2 messages
(
9,
'8 PM works for me too!',
4,
'2025-01-10 17:05:00.000000',
'2025-01-10 17:05:00.000000'
),
(
10,
'What game should we play?',
4,
'2025-01-10 17:10:00.000000',
'2025-01-10 17:10:00.000000'
),
-- User 5 (Evan) - 3 messages
(
11,
'I suggest we try the new arcade game!',
5,
'2025-01-10 17:15:00.000000',
'2025-01-10 17:15:00.000000'
),
(
12,
'It has great multiplayer features',
5,
'2025-01-10 17:20:00.000000',
'2025-01-10 17:20:00.000000'
),
(
13,
'Perfect timing for a weekend session',
5,
'2025-01-10 18:00:00.000000',
'2025-01-10 18:00:00.000000'
),
-- User 6 (Fiona) - 2 messages
(
14,
'Sounds like fun! I love arcade games.',
6,
'2025-01-10 18:05:00.000000',
'2025-01-10 18:05:00.000000'
),
(
15,
'Should I bring snacks?',
6,
'2025-01-10 18:10:00.000000',
'2025-01-10 18:10:00.000000'
),
-- User 7 (George) - 3 messages
(
16,
'Snacks are always welcome!',
7,
'2025-01-10 18:15:00.000000',
'2025-01-10 18:15:00.000000'
),
(
17,
'I can bring some drinks',
7,
'2025-01-10 18:20:00.000000',
'2025-01-10 18:20:00.000000'
),
(
18,
'This is going to be awesome',
7,
'2025-01-10 19:00:00.000000',
'2025-01-10 19:00:00.000000'
),
-- User 8 (Helen) - 2 messages
(
19,
'I agree! Cannot wait for the game night.',
8,
'2025-01-10 19:05:00.000000',
'2025-01-10 19:05:00.000000'
),
(
20,
'Should we set up a Discord call?',
8,
'2025-01-10 19:10:00.000000',
'2025-01-10 19:10:00.000000'
),
-- User 9 (Ian) - 3 messages
(
21,
'Discord would be perfect for voice chat',
9,
'2025-01-10 19:15:00.000000',
'2025-01-10 19:15:00.000000'
),
(
22,
'I will create a server for us',
9,
'2025-01-10 19:20:00.000000',
'2025-01-10 19:20:00.000000'
),
(
23,
'Link will be shared in a few minutes',
9,
'2025-01-10 19:25:00.000000',
'2025-01-10 19:25:00.000000'
),
-- User 10 (Julia) - 2 messages
(
24,
'Thanks Ian! You are the best.',
10,
'2025-01-10 19:30:00.000000',
'2025-01-10 19:30:00.000000'
),
(
25,
'See you all at 8 PM!',
10,
'2025-01-10 19:35:00.000000',
'2025-01-10 19:35:00.000000'
),
-- Additional messages for Evan (user_id 5) - 10 more messages
(
26,
'Just finished setting up the game server!',
5,
'2025-01-10 20:00:00.000000',
'2025-01-10 20:00:00.000000'
),
(
27,
'Everyone should be able to connect now',
5,
'2025-01-10 20:05:00.000000',
'2025-01-10 20:05:00.000000'
),
(
28,
'I added some custom maps too',
5,
'2025-01-10 20:10:00.000000',
'2025-01-10 20:10:00.000000'
),
(
29,
'The graphics look amazing on this new version',
5,
'2025-01-10 20:15:00.000000',
'2025-01-10 20:15:00.000000'
),
(
30,
'Hope you all enjoy the new features',
5,
'2025-01-10 20:20:00.000000',
'2025-01-10 20:20:00.000000'
),
(
31,
'I also set up a leaderboard system',
5,
'2025-01-10 20:25:00.000000',
'2025-01-10 20:25:00.000000'
),
(
32,
'We can track high scores now',
5,
'2025-01-10 20:30:00.000000',
'2025-01-10 20:30:00.000000'
),
(
33,
'The game supports up to 8 players simultaneously',
5,
'2025-01-10 20:35:00.000000',
'2025-01-10 20:35:00.000000'
),
(
34,
'I tested it earlier and it runs smoothly',
5,
'2025-01-10 20:40:00.000000',
'2025-01-10 20:40:00.000000'
),
(
35,
'Cannot wait to see everyone online tonight!',
5,
'2025-01-10 20:45:00.000000',
'2025-01-10 20:45:00.000000'
);
INSERT INTO "public"."users" (
"id",
"name",
"email",
"password_hash",
"created_at",
"updated_at",
"status"
)
VALUES (
1,
'Alice',
'alice@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$tMg1Rd3IEDnp3iFKrqsF4Dsbw6/Cbf6seRB/H5bhaPg$zZj5yn4x3D3O3mDHcW2aczQNiYfAs3cw21XMEIgkF0E',
'2024-09-01 20:49:38.759432',
'2024-09-02 03:49:39.927',
'active'
),
(
2,
'Bob',
'bob@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$CvOMK1WUd99R7kYXpiBPNYw4OQP53pYIgeMnwz92mrE$HPthId4phMoPT1TWuCRHHCr9BSQA8XoUkQuB1HZsqTY',
'2024-09-02 17:49:23.377425',
'2024-09-02 17:49:23.377425',
'active'
),
(
3,
'Charlie',
'charlie@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$paCAAD1HVZkncP/WvecuUO6zFXp2/8BISpgr5rXRxps$M5kBFc9JHHGNw9SXnPu2ggpJY0mFFCska7TXMrllndo',
'2024-09-03 10:30:15.123456',
'2024-09-03 10:30:15.123456',
'active'
),
(
4,
'Diana',
'diana@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$xyz123ABC456DEF789GHI$SampleHashForDiana123',
'2024-09-04 14:20:30.654321',
'2024-09-04 14:20:30.654321',
'active'
),
(
5,
'Evan',
'evan@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$evanHash123$EvanPasswordHash456',
'2024-09-05 09:15:45.987654',
'2024-09-05 09:15:45.987654',
'active'
),
(
6,
'Fiona',
'fiona@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$fionaHash456$FionaPasswordHash789',
'2024-09-06 16:45:12.345678',
'2024-09-06 16:45:12.345678',
'active'
),
(
7,
'George',
'george@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$georgeHash789$GeorgePasswordHash012',
'2024-09-07 11:30:25.876543',
'2024-09-07 11:30:25.876543',
'active'
),
(
8,
'Helen',
'helen@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$helenHash012$HelenPasswordHash345',
'2024-09-08 13:25:40.234567',
'2024-09-08 13:25:40.234567',
'active'
),
(
9,
'Ian',
'ian@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$ianHash345$IanPasswordHash678',
'2024-09-09 08:40:55.765432',
'2024-09-09 08:40:55.765432',
'active'
),
(
10,
'Julia',
'julia@example.com',
'$argon2id$v=19$m=65536,t=2,p=1$juliaHash678$JuliaPasswordHash901',
'2024-09-10 15:55:18.123456',
'2024-09-10 15:55:18.123456',
'active'
);
ALTER TABLE "public"."messages"
ADD FOREIGN KEY ("user_id") REFERENCES "public"."users"("id");
-- set pk to 11
ALTER SEQUENCE users_id_seq RESTART WITH 11;
-- Indices
CREATE UNIQUE INDEX name_idx ON public.users USING btree (name);
CREATE UNIQUE INDEX email_idx ON public.users USING btree (email);
DROP INDEX IF EXISTS users_email_unique;
CREATE UNIQUE INDEX users_email_unique ON public.users USING btree (email);

View file

@ -1,188 +0,0 @@
import os
from os import environ
from unittest.mock import MagicMock
import pytest
import pytest_asyncio
from arcade_mcp_server import Context
from arcade_mcp_server.exceptions import RetryableToolError
from arcade_postgres.tools.postgres import (
DatabaseEngine,
discover_schemas,
discover_tables,
execute_select_query,
get_table_schema,
)
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine
POSTGRES_DATABASE_CONNECTION_STRING = (
environ.get("TEST_POSTGRES_DATABASE_CONNECTION_STRING")
or "postgresql://postgres@localhost:5432/postgres"
)
@pytest.fixture
def mock_context():
context = MagicMock(spec=Context)
context.get_secret = MagicMock(return_value=POSTGRES_DATABASE_CONNECTION_STRING)
return context
# before the tests, restore the database from the dump
@pytest_asyncio.fixture(autouse=True)
async def restore_database():
with open(f"{os.path.dirname(__file__)}/dump.sql") as f:
engine = create_async_engine(
POSTGRES_DATABASE_CONNECTION_STRING.replace("postgresql", "postgresql+asyncpg").split(
"?"
)[0]
)
async with engine.connect() as c:
queries = f.read().split(";")
await c.execute(text("BEGIN"))
for query in queries:
if query.strip():
await c.execute(text(query))
await c.commit()
await engine.dispose()
@pytest_asyncio.fixture(autouse=True)
async def cleanup_engines():
"""Clean up database engines after each test to prevent connection leaks."""
yield
# Clean up all cached engines after each test
await DatabaseEngine.cleanup()
@pytest.mark.asyncio
async def test_discover_schemas(mock_context) -> None:
assert await discover_schemas(mock_context) == ["public"]
@pytest.mark.asyncio
async def test_discover_tables(mock_context) -> None:
assert await discover_tables(mock_context) == ["messages", "users"]
@pytest.mark.asyncio
async def test_get_table_schema(mock_context) -> None:
assert await get_table_schema(mock_context, "public", "users") == [
"id: int (PRIMARY KEY)",
"name: str (INDEXED)",
"email: str (INDEXED)",
"password_hash: str",
"created_at: datetime",
"updated_at: datetime",
"status: str",
]
assert await get_table_schema(mock_context, "public", "messages") == [
"id: int (PRIMARY KEY)",
"body: str",
"user_id: int",
"created_at: datetime",
"updated_at: datetime",
]
@pytest.mark.asyncio
async def test_execute_select_query(mock_context) -> None:
assert await execute_select_query(
mock_context,
select_clause="id, name, email",
from_clause="users",
where_clause="id = 1",
) == [
"(1, 'Alice', 'alice@example.com')",
]
assert await execute_select_query(
mock_context,
select_clause="id, name, email",
from_clause="users",
order_by_clause="id",
limit=1,
offset=1,
) == [
"(2, 'Bob', 'bob@example.com')",
]
@pytest.mark.asyncio
async def test_execute_select_query_with_keywords(mock_context) -> None:
assert await execute_select_query(
mock_context,
select_clause="SELECT id, name, email",
from_clause="FROM users",
limit=1,
) == [
"(1, 'Alice', 'alice@example.com')",
]
@pytest.mark.asyncio
async def test_execute_select_query_with_join(mock_context) -> None:
assert await execute_select_query(
mock_context,
select_clause="u.id, u.name, u.email, m.id, m.body",
from_clause="users u",
join_clause="messages m ON u.id = m.user_id",
limit=1,
) == [
"(1, 'Alice', 'alice@example.com', 1, 'Hello everyone!')",
]
@pytest.mark.asyncio
async def test_execute_select_query_with_group_by(mock_context) -> None:
assert await execute_select_query(
mock_context,
select_clause="u.name, COUNT(m.id) AS message_count",
from_clause="messages m",
join_clause="users u ON m.user_id = u.id",
group_by_clause="u.name",
order_by_clause="message_count DESC",
limit=2,
) == [
"('Evan', 13)",
"('Alice', 3)",
]
@pytest.mark.asyncio
async def test_execute_select_query_with_no_results(mock_context) -> None:
# does not raise an error
assert (
await execute_select_query(
mock_context,
select_clause="id, name, email",
from_clause="users",
where_clause="id = 9999999999",
)
== []
)
@pytest.mark.asyncio
async def test_execute_select_query_with_problem(mock_context) -> None:
# 'foo' is not a valid id
with pytest.raises(RetryableToolError) as e:
await execute_select_query(
mock_context,
select_clause="*",
from_clause="users",
where_clause="id = 'foo'",
)
assert "Do not use * in the select clause" in str(e.value)
@pytest.mark.asyncio
async def test_execute_select_query_rejects_non_select(mock_context) -> None:
with pytest.raises(RetryableToolError) as e:
await execute_select_query(
mock_context,
select_clause="INSERT INTO users (name, email, password_hash) VALUES ('Luigi', 'luigi@example.com', 'password')",
from_clause="users",
)
assert "Only SELECT queries are allowed" in str(e.value)

View file

@ -1,15 +0,0 @@
#!/bin/bash
# Run PostgreSQL container
docker run -d --name some-postgres-server -e POSTGRES_HOST_AUTH_METHOD=trust -p 5432:5432 postgres:latest
# Wait for PostgreSQL to be ready
echo "Waiting for PostgreSQL to be ready..."
for i in {1..30}; do
if docker exec some-postgres-server pg_isready -U postgres > /dev/null 2>&1; then
echo "PostgreSQL is ready!"
break
fi
echo "Waiting... ($i/30)"
sleep 1
done

View file

@ -1,18 +0,0 @@
files: ^.*/zendesk/.*
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: "v4.4.0"
hooks:
- id: check-case-conflict
- id: check-merge-conflict
- id: check-toml
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.7
hooks:
- id: ruff
args: [--fix]
- id: ruff-format

View file

@ -1,46 +0,0 @@
target-version = "py310"
line-length = 100
fix = true
[lint]
select = [
# flake8-2020
"YTT",
# flake8-bandit
"S",
# flake8-bugbear
"B",
# flake8-builtins
"A",
# flake8-comprehensions
"C4",
# flake8-debugger
"T10",
# flake8-simplify
"SIM",
# isort
"I",
# mccabe
"C90",
# pycodestyle
"E", "W",
# pyflakes
"F",
# pygrep-hooks
"PGH",
# pyupgrade
"UP",
# ruff
"RUF",
# tryceratops
"TRY",
]
ignore = ["C901"]
[lint.per-file-ignores]
"**/tests/*" = ["S101"]
[format]
preview = true
skip-magic-trailing-comma = false

View file

@ -1,55 +0,0 @@
.PHONY: help
help:
@echo "🛠️ github Commands:\n"
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
.PHONY: install
install: ## Install the uv environment and install all packages with dependencies
@echo "🚀 Creating virtual environment and installing all packages using uv"
@uv sync --active --all-extras --no-sources
@if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi
@echo "✅ All packages and dependencies installed via uv"
.PHONY: install-local
install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources
@echo "🚀 Creating virtual environment and installing all packages using uv"
@uv sync --active --all-extras
@if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi
@echo "✅ All packages and dependencies installed via uv"
.PHONY: build
build: clean-build ## Build wheel file using poetry
@echo "🚀 Creating wheel file"
uv build
.PHONY: clean-build
clean-build: ## clean build artifacts
@echo "🗑️ Cleaning dist directory"
rm -rf dist
.PHONY: test
test: ## Test the code with pytest
@echo "🚀 Testing code: Running pytest"
@uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml
.PHONY: coverage
coverage: ## Generate coverage report
@echo "coverage report"
@uv run --no-sources coverage report
@echo "Generating coverage report"
@uv run --no-sources coverage html
.PHONY: bump-version
bump-version: ## Bump the version in the pyproject.toml file by a patch version
@echo "🚀 Bumping version in pyproject.toml"
uv version --no-sources --bump patch
.PHONY: check
check: ## Run code quality tools.
@if [ -f .pre-commit-config.yaml ]; then\
echo "🚀 Linting code: Running pre-commit";\
uv run --no-sources pre-commit run -a;\
fi
@echo "🚀 Static type checking: Running mypy"
@uv run --no-sources mypy --config-file=pyproject.toml

View file

@ -1,15 +0,0 @@
from arcade_zendesk.tools import (
add_ticket_comment,
get_ticket_comments,
list_tickets,
mark_ticket_solved,
search_articles,
)
__all__ = [
"add_ticket_comment",
"get_ticket_comments",
"list_tickets",
"mark_ticket_solved",
"search_articles",
]

Some files were not shown because too many files have changed in this diff Show more