Remove toolkits (#784)
This commit is contained in:
parent
bcee0f556f
commit
830480de83
117 changed files with 5 additions and 12030 deletions
36
.github/actions/setup-uv-env/action.yml
vendored
36
.github/actions/setup-uv-env/action.yml
vendored
|
|
@ -6,22 +6,6 @@ inputs:
|
|||
required: false
|
||||
description: "The python version to use"
|
||||
default: "3.11"
|
||||
is-toolkit:
|
||||
required: false
|
||||
description: "Whether this is a toolkit package"
|
||||
default: "false"
|
||||
is-contrib:
|
||||
required: false
|
||||
description: "Whether this is a contrib package"
|
||||
default: "false"
|
||||
is-lib:
|
||||
required: false
|
||||
description: "Whether this is a library package"
|
||||
default: "false"
|
||||
working-directory:
|
||||
required: false
|
||||
description: "Working directory for the installation (used for toolkits)"
|
||||
default: "."
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
|
|
@ -29,26 +13,8 @@ runs:
|
|||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
python-version: ${{ inputs.python-version }}
|
||||
|
||||
- name: Install toolkit dependencies
|
||||
if: inputs.is-toolkit == 'true'
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
run: |
|
||||
echo "Installing dependencies for ${{ inputs.working-directory }}"
|
||||
make install-local
|
||||
shell: bash
|
||||
|
||||
- name: Install contrib dependencies
|
||||
if: inputs.is-contrib == 'true'
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
run: |
|
||||
echo "Installing dependencies for ${{ inputs.working-directory }}"
|
||||
make install
|
||||
shell: bash
|
||||
|
||||
- name: Install libs dependencies
|
||||
if: inputs.is-toolkit != 'true'
|
||||
- name: Install dependencies
|
||||
run: uv sync --extra all --extra dev
|
||||
shell: bash
|
||||
|
|
|
|||
|
|
@ -83,14 +83,8 @@ jobs:
|
|||
uses: ./.github/actions/setup-uv-env
|
||||
with:
|
||||
python-version: "3.10"
|
||||
is-toolkit: ${{ startsWith(matrix.package, 'toolkits/') }}
|
||||
is-contrib: ${{ startsWith(matrix.package, 'contrib/') }}
|
||||
is-lib: ${{ startsWith(matrix.package, 'libs/') }}
|
||||
working-directory: ${{ matrix.package }}
|
||||
|
||||
- name: Run tests
|
||||
# Skip tests for toolkits - tests are run on every PR commit for toolkits
|
||||
if: ${{ !startsWith(matrix.package, 'toolkits/') }}
|
||||
working-directory: ${{ matrix.package }}
|
||||
run: |
|
||||
# Run tests if they exist
|
||||
|
|
|
|||
132
.github/workflows/test-toolkits.yml
vendored
132
.github/workflows/test-toolkits.yml
vendored
|
|
@ -1,132 +0,0 @@
|
|||
name: Test Toolkits
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
|
||||
env:
|
||||
ARCADE_USAGE_TRACKING: "0"
|
||||
|
||||
jobs:
|
||||
setup:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
toolkits_with_gha_secrets: ${{ steps.load_toolkits.outputs.toolkits_with_gha_secrets }}
|
||||
toolkits_without_gha_secrets: ${{ steps.load_toolkits.outputs.toolkits_without_gha_secrets }}
|
||||
steps:
|
||||
- name: Check out
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: determine toolkits with and without GHA secrets
|
||||
id: load_toolkits
|
||||
run: |
|
||||
# Find all directories in toolkits/ that have a pyproject.toml
|
||||
TOOLKITS=$(find toolkits -maxdepth 1 -type d -not -name "toolkits" -exec test -f {}/pyproject.toml \; -exec basename {} \; | jq -R -s -c 'split("\n")[:-1]')
|
||||
TOOLKITS_WITH_GHA_SECRETS='["postgres", "clickhouse", "mongodb"]'
|
||||
TOOLKITS_WITHOUT_GHA_SECRETS=$(echo "$TOOLKITS" | jq -c --argjson with "$TOOLKITS_WITH_GHA_SECRETS" '[.[] | select(. as $t | $with | index($t) | not)]')
|
||||
echo "Found toolkits: $TOOLKITS"
|
||||
echo "Found toolkits without GHA secrets: $TOOLKITS_WITHOUT_GHA_SECRETS"
|
||||
echo "Found toolkits with GHA secrets: $TOOLKITS_WITH_GHA_SECRETS"
|
||||
echo "toolkits_without_gha_secrets=$TOOLKITS_WITHOUT_GHA_SECRETS" >> $GITHUB_OUTPUT
|
||||
echo "toolkits_with_gha_secrets=$TOOLKITS_WITH_GHA_SECRETS" >> $GITHUB_OUTPUT
|
||||
|
||||
test-toolkits:
|
||||
needs: setup
|
||||
name: test-toolkits (${{ matrix.toolkit }}, ${{ matrix.os }})
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||
toolkit: ${{ fromJson(needs.setup.outputs.toolkits_without_gha_secrets) }}
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up the environment
|
||||
uses: ./.github/actions/setup-uv-env
|
||||
with:
|
||||
is-toolkit: "true"
|
||||
working-directory: toolkits/${{ matrix.toolkit }}
|
||||
|
||||
- name: Install toolkit dependencies
|
||||
working-directory: toolkits/${{ matrix.toolkit }}
|
||||
shell: bash
|
||||
run: uv pip install -e ".[dev]"
|
||||
|
||||
- name: Check toolkit
|
||||
working-directory: toolkits/${{ matrix.toolkit }}
|
||||
shell: bash
|
||||
run: |
|
||||
uv run --active pre-commit run -a
|
||||
uv run --active mypy --config-file=pyproject.toml
|
||||
|
||||
- name: Test stand-alone toolkits (no secrets)
|
||||
working-directory: toolkits/${{ matrix.toolkit }}
|
||||
shell: bash
|
||||
run: |
|
||||
# Run pytest and capture exit code
|
||||
uv run --active pytest -W ignore -v --cov=arcade_${{ matrix.toolkit }} --cov-report=xml || EXIT_CODE=$?
|
||||
|
||||
if [ "${EXIT_CODE:-0}" -eq 5 ]; then
|
||||
echo "No tests found for toolkit ${{ matrix.toolkit }}, skipping..."
|
||||
exit 0
|
||||
elif [ "${EXIT_CODE:-0}" -ne 0 ]; then
|
||||
exit ${EXIT_CODE}
|
||||
fi
|
||||
|
||||
test-toolkits-with-gha-secrets:
|
||||
needs: setup
|
||||
# Linux-only: these toolkits bootstrap local DBs via docker/apt in tests/test_setup.sh.
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
toolkit: ${{ fromJson(needs.setup.outputs.toolkits_with_gha_secrets) }}
|
||||
fail-fast: true
|
||||
steps:
|
||||
- name: Check out
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up the environment
|
||||
uses: ./.github/actions/setup-uv-env
|
||||
with:
|
||||
is-toolkit: "true"
|
||||
working-directory: toolkits/${{ matrix.toolkit }}
|
||||
|
||||
- name: Install toolkit dependencies
|
||||
working-directory: toolkits/${{ matrix.toolkit }}
|
||||
run: uv pip install -e ".[dev]"
|
||||
|
||||
- name: Check toolkit
|
||||
working-directory: toolkits/${{ matrix.toolkit }}
|
||||
run: |
|
||||
uv run --active pre-commit run -a
|
||||
uv run --active mypy --config-file=pyproject.toml
|
||||
|
||||
- name: Test stand-alone toolkits (with secrets)
|
||||
if: |
|
||||
!github.event.pull_request.head.repo.fork
|
||||
working-directory: toolkits/${{ matrix.toolkit }}
|
||||
env:
|
||||
TEST_POSTGRES_DATABASE_CONNECTION_STRING: ${{ secrets.TEST_POSTGRES_DATABASE_CONNECTION_STRING }} # TODO: dynamically only load the `TEST_${{ matrix.toolkit }}_DATABASE_CONNECTION_STRING secret`
|
||||
TEST_CLICKHOUSE_DATABASE_CONNECTION_STRING: ${{ secrets.TEST_CLICKHOUSE_DATABASE_CONNECTION_STRING }}
|
||||
TEST_MONGODB_CONNECTION_STRING: ${{ secrets.TEST_MONGODB_CONNECTION_STRING }}
|
||||
run: |
|
||||
# If there's a custom test_setup.sh file, run it
|
||||
if [ -f tests/test_setup.sh ]; then
|
||||
echo "Running custom test setup for ${{ matrix.toolkit }}..."
|
||||
./tests/test_setup.sh
|
||||
fi
|
||||
|
||||
# Run pytest and capture exit code
|
||||
uv run --active pytest -W ignore -v --cov=arcade_${{ matrix.toolkit }} --cov-report=xml || EXIT_CODE=$?
|
||||
|
||||
if [ "${EXIT_CODE:-0}" -eq 5 ]; then
|
||||
echo "No tests found for toolkit ${{ matrix.toolkit }}, skipping..."
|
||||
exit 0
|
||||
elif [ "${EXIT_CODE:-0}" -ne 0 ]; then
|
||||
exit ${EXIT_CODE}
|
||||
fi
|
||||
|
|
@ -4,7 +4,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
|||
|
||||
## What This Is
|
||||
|
||||
Arcade MCP is a Python tool-calling platform for building MCP (Model Context Protocol) servers. It's a monorepo containing 5 interdependent libraries, 30+ prebuilt toolkit integrations, and a CLI.
|
||||
Arcade MCP is a Python tool-calling platform for building MCP (Model Context Protocol) servers. It's a monorepo containing 5 interdependent libraries and a CLI.
|
||||
|
||||
## Commands
|
||||
|
||||
|
|
@ -15,7 +15,6 @@ Arcade MCP is a Python tool-calling platform for building MCP (Model Context Pro
|
|||
| Run a single test | `uv run pytest libs/tests/core/test_toolkit.py::TestClass::test_method` |
|
||||
| Lint + type check | `make check` (pre-commit + mypy) |
|
||||
| Build all wheels | `make build` |
|
||||
| Run toolkit tests | `make test-toolkits` |
|
||||
|
||||
Package manager is **uv** — always use `uv run` to execute Python commands, never bare `pip` or `python`. Python 3.10+. Build system is Hatchling.
|
||||
|
||||
|
|
|
|||
117
Makefile
117
Makefile
|
|
@ -6,37 +6,6 @@ install: ## Install the uv environment and all packages with dependencies
|
|||
@uv run pre-commit install
|
||||
@echo "✅ All packages and dependencies installed via uv workspace"
|
||||
|
||||
.PHONY: install-toolkits
|
||||
install-toolkits: ## Install dependencies for all toolkits
|
||||
@echo "🚀 Installing dependencies for all toolkits"
|
||||
@failed=0; \
|
||||
successful=0; \
|
||||
for dir in toolkits/*/ ; do \
|
||||
if [ -d "$$dir" ] && [ -f "$$dir/pyproject.toml" ]; then \
|
||||
echo "📦 Installing dependencies for $$dir"; \
|
||||
if (cd $$dir && uv pip install -e ".[dev]"); then \
|
||||
successful=$$((successful + 1)); \
|
||||
else \
|
||||
echo "❌ Failed to install dependencies for $$dir"; \
|
||||
failed=$$((failed + 1)); \
|
||||
fi; \
|
||||
else \
|
||||
echo "⚠️ Skipping $$dir (no pyproject.toml found)"; \
|
||||
fi; \
|
||||
done; \
|
||||
echo ""; \
|
||||
echo "📊 Installation Summary:"; \
|
||||
echo " ✅ Successful: $$successful toolkits"; \
|
||||
echo " ❌ Failed: $$failed toolkits"; \
|
||||
if [ $$failed -gt 0 ]; then \
|
||||
echo ""; \
|
||||
echo "⚠️ Some toolkit installations failed. Check the output above for details."; \
|
||||
exit 1; \
|
||||
else \
|
||||
echo ""; \
|
||||
echo "🎉 All toolkit dependencies installed successfully!"; \
|
||||
fi
|
||||
|
||||
.PHONY: check
|
||||
check: ## Run code quality tools.
|
||||
@echo "🚀 Linting code: Running pre-commit"
|
||||
|
|
@ -56,18 +25,6 @@ check-libs: ## Run code quality tools for each lib package
|
|||
(cd $$lib && uv run mypy . || true); \
|
||||
done
|
||||
|
||||
.PHONY: check-toolkits
|
||||
check-toolkits: ## Run code quality tools for each toolkit that has a Makefile
|
||||
@echo "🚀 Running 'make check' in each toolkit with a Makefile"
|
||||
@for dir in toolkits/*/ ; do \
|
||||
if [ -f "$$dir/Makefile" ]; then \
|
||||
echo "🛠️ Checking toolkit $$dir"; \
|
||||
(cd "$$dir" && uv run pre-commit run -a && uv run mypy --config-file=pyproject.toml); \
|
||||
else \
|
||||
echo "🛠️ Skipping toolkit $$dir (no Makefile found)"; \
|
||||
fi; \
|
||||
done
|
||||
|
||||
.PHONY: test
|
||||
test: install ## Test the code with pytest
|
||||
@echo "🚀 Testing libs: Running pytest"
|
||||
|
|
@ -81,15 +38,6 @@ test-libs: ## Test each lib package individually
|
|||
(cd $$lib && uv run pytest -W ignore -v || true); \
|
||||
done
|
||||
|
||||
.PHONY: test-toolkits
|
||||
test-toolkits: ## Iterate over all toolkits and run pytest on each one
|
||||
@echo "🚀 Testing code in toolkits: Running pytest"
|
||||
@for dir in toolkits/*/ ; do \
|
||||
toolkit_name=$$(basename "$$dir"); \
|
||||
echo "🧪 Testing $$toolkit_name toolkit"; \
|
||||
(cd $$dir && uv run pytest -W ignore -v --cov=arcade_$$toolkit_name --cov-report=xml || exit 1); \
|
||||
done
|
||||
|
||||
.PHONY: coverage
|
||||
coverage: ## Generate coverage report
|
||||
@echo "coverage report"
|
||||
|
|
@ -107,38 +55,6 @@ build: clean-build ## Build wheel files using uv
|
|||
fi; \
|
||||
done
|
||||
|
||||
.PHONY: build-toolkits
|
||||
build-toolkits: ## Build wheel files for all toolkits
|
||||
@echo "🚀 Creating wheel files for all toolkits"
|
||||
@failed=0; \
|
||||
successful=0; \
|
||||
for dir in toolkits/*/ ; do \
|
||||
if [ -d "$$dir" ] && [ -f "$$dir/pyproject.toml" ]; then \
|
||||
toolkit_name=$$(basename "$$dir"); \
|
||||
echo "🛠️ Building toolkit $$toolkit_name"; \
|
||||
if (cd $$dir && uv build); then \
|
||||
successful=$$((successful + 1)); \
|
||||
else \
|
||||
echo "❌ Failed to build toolkit $$toolkit_name"; \
|
||||
failed=$$((failed + 1)); \
|
||||
fi; \
|
||||
else \
|
||||
echo "⚠️ Skipping $$dir (no pyproject.toml found)"; \
|
||||
fi; \
|
||||
done; \
|
||||
echo ""; \
|
||||
echo "📊 Build Summary:"; \
|
||||
echo " ✅ Successful: $$successful toolkits"; \
|
||||
echo " ❌ Failed: $$failed toolkits"; \
|
||||
if [ $$failed -gt 0 ]; then \
|
||||
echo ""; \
|
||||
echo "⚠️ Some toolkit builds failed. Check the output above for details."; \
|
||||
exit 1; \
|
||||
else \
|
||||
echo ""; \
|
||||
echo "🎉 All toolkit wheels built successfully!"; \
|
||||
fi
|
||||
|
||||
.PHONY: clean-build
|
||||
clean-build: ## clean build artifacts
|
||||
@echo "🗑️ Cleaning build artifacts"
|
||||
|
|
@ -161,30 +77,19 @@ build-and-publish: build publish ## Build and publish.
|
|||
|
||||
.PHONY: docker
|
||||
docker: ## Build and run the Docker container
|
||||
@echo "🚀 Building lib packages and toolkit wheels..."
|
||||
@echo "🚀 Building lib packages..."
|
||||
@make full-dist
|
||||
@echo "🚀 Building Docker image"
|
||||
@cd docker && make docker-build
|
||||
@cd docker && make docker-run
|
||||
|
||||
.PHONY: docker-base
|
||||
docker-base: ## Build and run the Docker container
|
||||
@echo "🚀 Building lib packages and toolkit wheels..."
|
||||
@make full-dist
|
||||
@echo "🚀 Building Docker image"
|
||||
@cd docker && INSTALL_TOOLKITS=false make docker-build
|
||||
@cd docker && INSTALL_TOOLKITS=false make docker-run
|
||||
|
||||
.PHONY: publish-ghcr
|
||||
publish-ghcr: ## Publish to the GHCR
|
||||
# Publish the base image - ghcr.io/arcadeai/worker-base
|
||||
@cd docker && INSTALL_TOOLKITS=false make publish-ghcr
|
||||
# Publish the image with toolkits - ghcr.io/arcadeai/worker
|
||||
@cd docker && INSTALL_TOOLKITS=true make publish-ghcr
|
||||
@cd docker && make publish-ghcr
|
||||
|
||||
.PHONY: full-dist
|
||||
full-dist: clean-dist ## Build all projects and copy wheels to ./dist
|
||||
@echo "🛠️ Building a full distribution with lib packages and toolkits"
|
||||
@echo "🛠️ Building a full distribution with lib packages"
|
||||
|
||||
@echo "🛠️ Building all lib packages and copying wheels to ./dist"
|
||||
@mkdir -p dist
|
||||
|
|
@ -198,16 +103,6 @@ full-dist: clean-dist ## Build all projects and copy wheels to ./dist
|
|||
@uv build
|
||||
@rm -f dist/*.tar.gz
|
||||
|
||||
@echo "🛠️ Building all toolkit packages and copying wheels to ./dist"
|
||||
@for dir in toolkits/*/ ; do \
|
||||
if [ -d "$$dir" ] && [ -f "$$dir/pyproject.toml" ]; then \
|
||||
toolkit_name=$$(basename "$$dir"); \
|
||||
echo "🛠️ Building toolkit $$toolkit_name wheel..."; \
|
||||
(cd $$dir && uv build); \
|
||||
cp $$dir/dist/*.whl dist/; \
|
||||
fi; \
|
||||
done
|
||||
|
||||
.PHONY: clean-dist
|
||||
clean-dist: ## Clean all built distributions
|
||||
@echo "🗑️ Cleaning dist directory"
|
||||
|
|
@ -216,12 +111,6 @@ clean-dist: ## Clean all built distributions
|
|||
@for lib in libs/arcade*/ ; do \
|
||||
rm -rf "$$lib"/dist; \
|
||||
done
|
||||
@echo "🗑️ Cleaning toolkits/*/dist directory"
|
||||
@for toolkit_dir in toolkits/*; do \
|
||||
if [ -d "$$toolkit_dir" ]; then \
|
||||
rm -rf "$$toolkit_dir"/dist; \
|
||||
fi; \
|
||||
done
|
||||
|
||||
.PHONY: setup
|
||||
setup: ## Run uv environment setup script
|
||||
|
|
|
|||
|
|
@ -1,18 +0,0 @@
|
|||
files: ^arcade_brightdata/.*
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: "v4.4.0"
|
||||
hooks:
|
||||
- id: check-case-conflict
|
||||
- id: check-merge-conflict
|
||||
- id: check-toml
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.6.7
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
target-version = "py310"
|
||||
line-length = 100
|
||||
fix = true
|
||||
|
||||
[lint]
|
||||
select = [
|
||||
# flake8-2020
|
||||
"YTT",
|
||||
# flake8-bandit
|
||||
"S",
|
||||
# flake8-bugbear
|
||||
"B",
|
||||
# flake8-builtins
|
||||
"A",
|
||||
# flake8-comprehensions
|
||||
"C4",
|
||||
# flake8-debugger
|
||||
"T10",
|
||||
# flake8-simplify
|
||||
"SIM",
|
||||
# isort
|
||||
"I",
|
||||
# mccabe
|
||||
"C90",
|
||||
# pycodestyle
|
||||
"E", "W",
|
||||
# pyflakes
|
||||
"F",
|
||||
# pygrep-hooks
|
||||
"PGH",
|
||||
# pyupgrade
|
||||
"UP",
|
||||
# ruff
|
||||
"RUF",
|
||||
# tryceratops
|
||||
"TRY",
|
||||
]
|
||||
|
||||
[lint.per-file-ignores]
|
||||
"**/tests/*" = ["S101"]
|
||||
|
||||
[format]
|
||||
preview = true
|
||||
skip-magic-trailing-comma = false
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2025, Arcade AI
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
|
@ -1,55 +0,0 @@
|
|||
.PHONY: help
|
||||
|
||||
help:
|
||||
@echo "🛠️ github Commands:\n"
|
||||
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
||||
|
||||
.PHONY: install
|
||||
install: ## Install the uv environment and install all packages with dependencies
|
||||
@echo "🚀 Creating virtual environment and installing all packages using uv"
|
||||
@uv sync --active --all-extras --no-sources
|
||||
@if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi
|
||||
@echo "✅ All packages and dependencies installed via uv"
|
||||
|
||||
.PHONY: install-local
|
||||
install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources
|
||||
@echo "🚀 Creating virtual environment and installing all packages using uv"
|
||||
@uv sync --active --all-extras
|
||||
@if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi
|
||||
@echo "✅ All packages and dependencies installed via uv"
|
||||
|
||||
.PHONY: build
|
||||
build: clean-build ## Build wheel file using poetry
|
||||
@echo "🚀 Creating wheel file"
|
||||
uv build
|
||||
|
||||
.PHONY: clean-build
|
||||
clean-build: ## clean build artifacts
|
||||
@echo "🗑️ Cleaning dist directory"
|
||||
rm -rf dist
|
||||
|
||||
.PHONY: test
|
||||
test: ## Test the code with pytest
|
||||
@echo "🚀 Testing code: Running pytest"
|
||||
@uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml
|
||||
|
||||
.PHONY: coverage
|
||||
coverage: ## Generate coverage report
|
||||
@echo "coverage report"
|
||||
@uv run --no-sources coverage report
|
||||
@echo "Generating coverage report"
|
||||
@uv run --no-sources coverage html
|
||||
|
||||
.PHONY: bump-version
|
||||
bump-version: ## Bump the version in the pyproject.toml file by a patch version
|
||||
@echo "🚀 Bumping version in pyproject.toml"
|
||||
uv version --no-sources --bump patch
|
||||
|
||||
.PHONY: check
|
||||
check: ## Run code quality tools.
|
||||
@if [ -f .pre-commit-config.yaml ]; then\
|
||||
echo "🚀 Linting code: Running pre-commit";\
|
||||
uv run --no-sources pre-commit run -a;\
|
||||
fi
|
||||
@echo "🚀 Static type checking: Running mypy"
|
||||
@uv run --no-sources mypy --config-file=pyproject.toml
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
from arcade_brightdata.tools import scrape_as_markdown, search_engine, web_data_feed
|
||||
|
||||
__all__ = ["scrape_as_markdown", "search_engine", "web_data_feed"]
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
import sys
|
||||
from typing import cast
|
||||
|
||||
from arcade_mcp_server import MCPApp
|
||||
from arcade_mcp_server.mcp_app import TransportType
|
||||
|
||||
import arcade_brightdata
|
||||
|
||||
app = MCPApp(
|
||||
name="BrightData",
|
||||
instructions=(
|
||||
"Use this server when you need to interact with Bright Data to help users "
|
||||
"scrape web pages, search the web, and extract structured data from websites."
|
||||
),
|
||||
)
|
||||
|
||||
app.add_tools_from_module(arcade_brightdata)
|
||||
|
||||
|
||||
def main() -> None:
    """Start the BrightData MCP app from optional positional command-line arguments.

    Arguments are read from ``sys.argv`` in this order: transport, host, port.
    Any argument that is not supplied keeps its default: ``"stdio"``,
    ``"127.0.0.1"`` and ``8000`` respectively. The port argument is converted
    with ``int()``, so a non-numeric value raises ``ValueError``.
    """
    cli_args = sys.argv[1:]
    supplied = len(cli_args)

    # Start from the defaults and override only what was actually passed.
    transport, host, port = "stdio", "127.0.0.1", 8000
    if supplied >= 1:
        transport = cli_args[0]
    if supplied >= 2:
        host = cli_args[1]
    if supplied >= 3:
        port = int(cli_args[2])

    # cast() is for the type checker only; the string is passed through unchanged.
    app.run(transport=cast(TransportType, transport), host=host, port=port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,63 +0,0 @@
|
|||
import json
|
||||
from typing import ClassVar
|
||||
from urllib.parse import quote
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class BrightDataClient:
    """Small HTTP client for the Bright Data request endpoint.

    Instances are cached on the class, keyed by API key, so repeated calls to
    ``create_client`` with the same key hand back the same object.
    """

    # Shared cache of clients: API key -> client instance.
    _clients: ClassVar[dict[str, "BrightDataClient"]] = {}

    def __init__(self, api_key: str, zone: str = "web_unlocker1") -> None:
        """
        Store the credentials and prepare the request headers.

        Args:
            api_key (str): Bright Data API token, sent as a Bearer token
            zone (str): Bright Data zone name
        """
        self.api_key = api_key
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        self.zone = zone
        self.endpoint = "https://api.brightdata.com/request"

    @classmethod
    def create_client(cls, api_key: str, zone: str = "web_unlocker1") -> "BrightDataClient":
        """Return the cached client for ``api_key``, creating it on first use.

        The zone is not part of the cache key; it is overwritten on the
        returned client on every call.
        """
        cached = cls._clients.get(api_key)
        if cached is None:
            cached = cls(api_key, zone)
            cls._clients[api_key] = cached

        # The caller picks the zone per request, so always apply the latest one.
        cached.zone = zone
        return cached

    @classmethod
    def clear_cache(cls) -> None:
        """Drop every cached client instance."""
        cls._clients.clear()

    def make_request(self, payload: dict) -> str:
        """
        POST ``payload`` as JSON to the Bright Data endpoint.

        Args:
            payload (dict): Request payload, serialized with ``json.dumps``

        Returns:
            str: Response body text

        Raises:
            Whatever ``response.raise_for_status()`` raises for an error status.
        """
        body = json.dumps(payload)
        response = requests.post(self.endpoint, headers=self.headers, data=body, timeout=30)
        response.raise_for_status()

        text: str = response.text
        return text

    @staticmethod
    def encode_query(query: str) -> str:
        """Percent-encode a search query (``/`` is left as-is, the ``quote`` default)."""
        return quote(query, safe="/")
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
from arcade_brightdata.tools.bright_data_tools import (
|
||||
scrape_as_markdown,
|
||||
search_engine,
|
||||
web_data_feed,
|
||||
)
|
||||
|
||||
__all__ = ["scrape_as_markdown", "search_engine", "web_data_feed"]
|
||||
|
|
@ -1,361 +0,0 @@
|
|||
import json
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, cast
|
||||
|
||||
import requests
|
||||
from arcade_mcp_server import Context, tool
|
||||
from arcade_mcp_server.exceptions import RetryableToolError
|
||||
from arcade_mcp_server.metadata import (
|
||||
Behavior,
|
||||
Classification,
|
||||
Operation,
|
||||
ServiceDomain,
|
||||
ToolMetadata,
|
||||
)
|
||||
|
||||
from arcade_brightdata.bright_data_client import BrightDataClient
|
||||
|
||||
|
||||
class DeviceType(str, Enum):
    """Device profiles a caller may request for a search."""

    # Generic mobile profile.
    MOBILE = "mobile"

    # Apple devices.
    IOS = "ios"
    IPHONE = "iphone"
    IPAD = "ipad"

    # Android devices.
    ANDROID = "android"
    ANDROID_TABLET = "android_tablet"
|
||||
|
||||
|
||||
class SearchEngine(str, Enum):
    """Search engines a query can be sent to."""

    GOOGLE = "google"
    BING = "bing"
    YANDEX = "yandex"
|
||||
|
||||
|
||||
class SearchType(str, Enum):
    """Specialised result categories a search can be narrowed to."""

    IMAGES = "images"
    SHOPPING = "shopping"
    NEWS = "news"
    JOBS = "jobs"
|
||||
|
||||
|
||||
class SourceType(str, Enum):
    """Identifiers of the structured data sources a caller can choose from."""

    # Amazon
    AMAZON_PRODUCT = "amazon_product"
    AMAZON_PRODUCT_REVIEWS = "amazon_product_reviews"

    # LinkedIn
    LINKEDIN_PERSON_PROFILE = "linkedin_person_profile"
    LINKEDIN_COMPANY_PROFILE = "linkedin_company_profile"

    # ZoomInfo
    ZOOMINFO_COMPANY_PROFILE = "zoominfo_company_profile"

    # Instagram
    INSTAGRAM_PROFILES = "instagram_profiles"
    INSTAGRAM_POSTS = "instagram_posts"
    INSTAGRAM_REELS = "instagram_reels"
    INSTAGRAM_COMMENTS = "instagram_comments"

    # Facebook
    FACEBOOK_POSTS = "facebook_posts"
    FACEBOOK_MARKETPLACE_LISTINGS = "facebook_marketplace_listings"
    FACEBOOK_COMPANY_REVIEWS = "facebook_company_reviews"

    # Other sources
    X_POSTS = "x_posts"
    ZILLOW_PROPERTIES_LISTING = "zillow_properties_listing"
    BOOKING_HOTEL_LISTINGS = "booking_hotel_listings"
    YOUTUBE_VIDEOS = "youtube_videos"
|
||||
|
||||
|
||||
@tool(
    requires_secrets=["BRIGHTDATA_API_KEY", "BRIGHTDATA_ZONE"],
    metadata=ToolMetadata(
        classification=Classification(
            service_domains=[ServiceDomain.WEB_SCRAPING],
        ),
        behavior=Behavior(
            operations=[Operation.READ],
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=True,
        ),
    ),
)
def scrape_as_markdown(
    context: Context,
    url: Annotated[str, "URL to scrape"],
) -> Annotated[str, "Scraped webpage content as Markdown"]:
    """
    Scrape a webpage and return content in Markdown format using Bright Data.

    Examples:
        scrape_as_markdown("https://example.com") -> "# Example Page\n\nContent..."
        scrape_as_markdown("https://news.ycombinator.com") -> "# Hacker News\n..."
    """
    # Both secrets are declared in requires_secrets on the decorator above.
    token = context.get_secret("BRIGHTDATA_API_KEY")
    zone_name = context.get_secret("BRIGHTDATA_ZONE")
    bright_data = BrightDataClient.create_client(api_key=token, zone=zone_name)

    # Ask for the raw response body, rendered as Markdown.
    request_body = dict(url=url, zone=zone_name, format="raw", data_format="markdown")
    return bright_data.make_request(request_body)
|
||||
|
||||
|
||||
@tool(
    requires_secrets=["BRIGHTDATA_API_KEY", "BRIGHTDATA_ZONE"],
    metadata=ToolMetadata(
        classification=Classification(
            service_domains=[ServiceDomain.WEB_SCRAPING],
        ),
        behavior=Behavior(
            operations=[Operation.READ],
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=True,
        ),
    ),
)
def search_engine(  # noqa: C901
    context: Context,
    query: Annotated[str, "Search query"],
    engine: Annotated[SearchEngine, "Search engine to use"] = SearchEngine.GOOGLE,
    language: Annotated[str | None, "Two-letter language code"] = None,
    country_code: Annotated[str | None, "Two-letter country code"] = None,
    search_type: Annotated[SearchType | None, "Type of search"] = None,
    start: Annotated[int | None, "Results pagination offset"] = None,
    num_results: Annotated[int, "Number of results to return. The default is 10"] = 10,
    location: Annotated[str | None, "Location for search results"] = None,
    device: Annotated[DeviceType | None, "Device type"] = None,
    return_json: Annotated[bool, "Return JSON instead of Markdown"] = False,
) -> Annotated[str, "Search results as Markdown or JSON"]:
    """
    Search using Google, Bing, or Yandex with advanced parameters using Bright Data.

    Examples:
        search_engine("climate change") -> "# Search Results\n\n## Climate Change - Wikipedia\n..."
        search_engine("Python tutorials", engine="bing", num_results=5) -> "# Bing Results\n..."
        search_engine("cats", search_type="images", country_code="us") -> "# Image Results\n..."
    """
    token = context.get_secret("BRIGHTDATA_API_KEY")
    zone_name = context.get_secret("BRIGHTDATA_ZONE")
    bright_data = BrightDataClient.create_client(api_key=token, zone=zone_name)

    encoded_query = BrightDataClient.encode_query(query)

    # Base search URL per engine; the encoded query is the only parameter here.
    engine_urls = {
        SearchEngine.GOOGLE: f"https://www.google.com/search?q={encoded_query}",
        SearchEngine.BING: f"https://www.bing.com/search?q={encoded_query}",
        SearchEngine.YANDEX: f"https://yandex.com/search/?text={encoded_query}",
    }
    search_url = engine_urls[engine]

    # The extra query parameters are only added for Google; for the other
    # engines the URL stays as the bare query URL built above.
    if engine == SearchEngine.GOOGLE:
        extras = []

        if language:
            extras.append(f"hl={language}")
        if country_code:
            extras.append(f"gl={country_code}")

        # Jobs searches use their own parameter; other types map to a tbm code.
        if search_type == SearchType.JOBS:
            extras.append("ibp=htl;jobs")
        elif search_type:
            tbm_codes = {
                SearchType.IMAGES: "isch",
                SearchType.SHOPPING: "shop",
                SearchType.NEWS: "nws",
            }
            tbm_value = tbm_codes.get(search_type, search_type)
            extras.append(f"tbm={tbm_value}")

        # A start offset of 0 is still emitted; only None is skipped.
        if start is not None:
            extras.append(f"start={start}")
        if num_results:
            extras.append(f"num={num_results}")
        if location:
            extras.append(f"uule={BrightDataClient.encode_query(location)}")

        if device:
            # Any device without a dedicated code falls back to "1".
            mobile_codes = {
                "ios": "ios",
                "iphone": "ios",
                "ipad": "ios_tablet",
                "android": "android",
                "android_tablet": "android_tablet",
            }
            device_value = mobile_codes.get(device.value, "1")
            extras.append(f"brd_mobile={device_value}")

        if return_json:
            extras.append("brd_json=1")

        # With no extras this leaves the URL untouched.
        search_url = "&".join([search_url, *extras])

    request_body = {
        "url": search_url,
        "zone": zone_name,
        "format": "raw",
        "data_format": "raw" if return_json else "markdown",
    }

    return bright_data.make_request(request_body)
|
||||
|
||||
|
||||
@tool(
    requires_secrets=["BRIGHTDATA_API_KEY"],
    metadata=ToolMetadata(
        classification=Classification(
            service_domains=[ServiceDomain.WEB_SCRAPING],
        ),
        behavior=Behavior(
            operations=[Operation.READ],
            read_only=True,
            destructive=False,
            idempotent=False,
            open_world=True,
        ),
    ),
)
def web_data_feed(
    context: Context,
    source_type: Annotated[SourceType, "Type of data source"],
    url: Annotated[str, "URL of the web resource to extract data from"],
    num_of_reviews: Annotated[
        int | None,
        (
            "Number of reviews to retrieve. Only applicable for "
            "facebook_company_reviews. Default is None"
        ),
    ] = None,
    timeout: Annotated[int, "Maximum time in seconds to wait for data retrieval"] = 600,
    polling_interval: Annotated[int, "Time in seconds between polling attempts"] = 1,
) -> Annotated[str, "Structured data from the requested source as JSON"]:
    """
    Extract structured data from various websites like LinkedIn, Amazon, Instagram, etc.
    NEVER MAKE UP LINKS - IF LINKS ARE NEEDED, EXECUTE search_engine FIRST.
    Supported source types:
    - amazon_product, amazon_product_reviews
    - linkedin_person_profile, linkedin_company_profile
    - zoominfo_company_profile
    - instagram_profiles, instagram_posts, instagram_reels, instagram_comments
    - facebook_posts, facebook_marketplace_listings, facebook_company_reviews
    - x_posts
    - zillow_properties_listing
    - booking_hotel_listings
    - youtube_videos

    Examples:
        web_data_feed("amazon_product", "https://amazon.com/dp/B08N5WRWNW")
        -> "{\"title\": \"Product Name\", ...}"
        web_data_feed("linkedin_person_profile", "https://linkedin.com/in/johndoe")
        -> "{\"name\": \"John Doe\", ...}"
        web_data_feed(
            "facebook_company_reviews", "https://facebook.com/company", num_of_reviews=50
        ) -> "[{\"review\": \"...\", ...}]"
    """
    # Reject an inconsistent argument combination before touching secrets or
    # creating a client, so an invalid call has no side effects.
    if num_of_reviews is not None and source_type != SourceType.FACEBOOK_COMPANY_REVIEWS:
        msg = (
            "num_of_reviews parameter is only applicable for facebook_company_reviews, "
            f"not for {source_type.value}"
        )
        prompt = (
            "The num_of_reviews parameter should only be used with "
            "facebook_company_reviews source type."
        )
        # Retryable: the caller can fix the arguments and call the tool again.
        raise RetryableToolError(msg, additional_prompt_content=prompt)

    api_key = context.get_secret("BRIGHTDATA_API_KEY")
    client = BrightDataClient.create_client(api_key=api_key)

    data = _extract_structured_data(
        client=client,
        source_type=source_type,
        url=url,
        num_of_reviews=num_of_reviews,
        timeout=timeout,
        polling_interval=polling_interval,
    )
    # The tool contract is a JSON string, so serialize whatever the snapshot returned.
    return json.dumps(data, indent=2)
|
||||
|
||||
|
||||
def _extract_structured_data(
    client: BrightDataClient,
    source_type: SourceType,
    url: str,
    num_of_reviews: int | None = None,
    timeout: int = 600,
    polling_interval: int = 1,
) -> dict[str, Any]:
    """
    Trigger a Bright Data dataset collection for ``url`` and poll for its snapshot.

    Args:
        client: Client whose ``headers`` are sent with every API request.
        source_type: Selects which dataset is triggered.
        url: Resource to collect; sent as the ``url`` field of the trigger input.
        num_of_reviews: Forwarded (as a string) only for
            ``SourceType.FACEBOOK_COMPANY_REVIEWS``; ignored otherwise.
        timeout: Wall-clock budget, in seconds, for the polling phase.
        polling_interval: Seconds to sleep between snapshot polls.

    Returns:
        The decoded snapshot JSON once its status is no longer "running" or
        "building".

    Raises:
        KeyError: If ``source_type`` has no dataset mapping.
        RetryableToolError: If the trigger response carries no snapshot ID.
        TimeoutError: If the snapshot is not ready within ``timeout`` seconds.
    """
    datasets = {
        SourceType.AMAZON_PRODUCT: "gd_l7q7dkf244hwjntr0",
        SourceType.AMAZON_PRODUCT_REVIEWS: "gd_le8e811kzy4ggddlq",
        SourceType.LINKEDIN_PERSON_PROFILE: "gd_l1viktl72bvl7bjuj0",
        SourceType.LINKEDIN_COMPANY_PROFILE: "gd_l1vikfnt1wgvvqz95w",
        SourceType.ZOOMINFO_COMPANY_PROFILE: "gd_m0ci4a4ivx3j5l6nx",
        SourceType.INSTAGRAM_PROFILES: "gd_l1vikfch901nx3by4",
        SourceType.INSTAGRAM_POSTS: "gd_lk5ns7kz21pck8jpis",
        SourceType.INSTAGRAM_REELS: "gd_lyclm20il4r5helnj",
        SourceType.INSTAGRAM_COMMENTS: "gd_ltppn085pokosxh13",
        SourceType.FACEBOOK_POSTS: "gd_lyclm1571iy3mv57zw",
        SourceType.FACEBOOK_MARKETPLACE_LISTINGS: "gd_lvt9iwuh6fbcwmx1a",
        SourceType.FACEBOOK_COMPANY_REVIEWS: "gd_m0dtqpiu1mbcyc2g86",
        SourceType.X_POSTS: "gd_lwxkxvnf1cynvib9co",
        SourceType.ZILLOW_PROPERTIES_LISTING: "gd_lfqkr8wm13ixtbd8f5",
        SourceType.BOOKING_HOTEL_LISTINGS: "gd_m5mbdl081229ln6t4a",
        # NOTE(review): identical to the BOOKING_HOTEL_LISTINGS ID above, which
        # looks like a copy-paste slip -- confirm the YouTube videos dataset ID
        # against the Bright Data dataset catalogue before relying on it.
        SourceType.YOUTUBE_VIDEOS: "gd_m5mbdl081229ln6t4a",
    }

    dataset_id = datasets[source_type]

    request_data = {"url": url}
    if source_type == SourceType.FACEBOOK_COMPANY_REVIEWS and num_of_reviews is not None:
        request_data["num_of_reviews"] = str(num_of_reviews)

    trigger_response = requests.post(
        "https://api.brightdata.com/datasets/v3/trigger",
        params={"dataset_id": dataset_id, "include_errors": "true"},
        headers=client.headers,
        json=[request_data],
        timeout=30,
    )

    trigger_data = trigger_response.json()
    if not trigger_data.get("snapshot_id"):
        msg = "No snapshot ID returned from trigger request"
        prompt = "Invalid input provided, use search_engine to get the relevant data first"
        raise RetryableToolError(msg, additional_prompt_content=prompt)

    snapshot_id = trigger_data["snapshot_id"]

    # Budget the polling phase by wall-clock time rather than by number of
    # attempts: counting attempts made the real wait scale with
    # polling_interval and request latency, contradicting the "seconds" in
    # both the parameter description and the timeout message.
    deadline = time.monotonic() + timeout
    last_error: Exception | None = None

    while time.monotonic() < deadline:
        try:
            snapshot_response = requests.get(
                f"https://api.brightdata.com/datasets/v3/snapshot/{snapshot_id}",
                params={"format": "json"},
                headers=client.headers,
                timeout=30,
            )
            # The cast only satisfies the declared return type; the payload is
            # checked with isinstance below because it is not always a dict.
            snapshot_data = cast(dict[str, Any], snapshot_response.json())
        except (requests.RequestException, ValueError) as exc:
            # Transport failures and undecodable bodies are treated as
            # transient and retried; anything else is a programming error and
            # is allowed to propagate instead of being silently retried.
            last_error = exc
        else:
            still_pending = isinstance(snapshot_data, dict) and snapshot_data.get("status") in (
                "running",
                "building",
            )
            if not still_pending:
                return snapshot_data

        time.sleep(polling_interval)

    msg = f"Timeout after {timeout} seconds waiting for {source_type.value} data"
    # Chain the most recent transient failure (if any) to aid debugging.
    raise TimeoutError(msg) from last_error
|
||||
|
|
@ -1,62 +0,0 @@
|
|||
[build-system]
|
||||
requires = [ "hatchling",]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "arcade_brightdata"
|
||||
version = "0.4.0"
|
||||
description = "Search, Crawl and Scrape any site, at scale, without getting blocked"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"arcade-mcp-server>=1.17.0,<2.0.0",
|
||||
"requests>=2.32.5",
|
||||
]
|
||||
[[project.authors]]
|
||||
name = "meirk-brd"
|
||||
email = "meirk@brightdata.com"
|
||||
|
||||
[project.scripts]
|
||||
arcade-brightdata = "arcade_brightdata.__main__:main"
|
||||
arcade_brightdata = "arcade_brightdata.__main__:main"
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"arcade-mcp[all]>=1.2.0,<2.0.0",
|
||||
"pytest>=8.3.0,<8.4.0",
|
||||
"pytest-cov>=4.0.0,<4.1.0",
|
||||
"pytest-mock>=3.11.1,<3.12.0",
|
||||
"pytest-asyncio>=0.24.0,<0.25.0",
|
||||
"mypy>=1.5.1,<1.6.0",
|
||||
"pre-commit>=3.4.0,<3.5.0",
|
||||
"tox>=4.11.1,<4.12.0",
|
||||
"ruff>=0.7.4,<0.8.0",
|
||||
"types-requests>=2.32.0",
|
||||
]
|
||||
# Tell Arcade.dev that this package is a toolkit
|
||||
[project.entry-points.arcade_toolkits]
|
||||
toolkit_name = "arcade_brightdata"
|
||||
|
||||
[tool.mypy]
|
||||
files = [ "arcade_brightdata/**/*.py",]
|
||||
python_version = "3.10"
|
||||
disallow_untyped_defs = "True"
|
||||
disallow_any_unimported = "True"
|
||||
no_implicit_optional = "True"
|
||||
check_untyped_defs = "True"
|
||||
warn_return_any = "True"
|
||||
warn_unused_ignores = "True"
|
||||
show_error_codes = "True"
|
||||
ignore_missing_imports = "True"
|
||||
|
||||
[tool.uv.sources]
|
||||
arcade-mcp = { path = "../../", editable = true }
|
||||
arcade-mcp-server = { path = "../../libs/arcade-mcp-server/", editable = true }
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = [ "tests",]
|
||||
|
||||
[tool.coverage.report]
|
||||
skip_empty = true
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = [ "arcade_brightdata",]
|
||||
|
|
@ -1,418 +0,0 @@
|
|||
from os import environ
|
||||
from unittest.mock import MagicMock as _MagicMock
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from arcade_mcp_server import Context
|
||||
from arcade_mcp_server.exceptions import ToolExecutionError
|
||||
|
||||
from arcade_brightdata.bright_data_client import BrightDataClient
|
||||
from arcade_brightdata.tools.bright_data_tools import (
|
||||
DeviceType,
|
||||
SourceType,
|
||||
scrape_as_markdown,
|
||||
search_engine,
|
||||
web_data_feed,
|
||||
)
|
||||
|
||||
BRIGHTDATA_API_KEY = environ.get("TEST_BRIGHTDATA_API_KEY") or "api-key"
|
||||
BRIGHTDATA_ZONE = environ.get("TEST_BRIGHTDATA_ZONE") or "unblocker"
|
||||
|
||||
|
||||
@pytest.fixture
def mock_context():
    """Provide a ``Context`` mock whose ``get_secret`` serves the Bright Data test secrets."""
    ctx = _MagicMock(spec=Context)

    def _lookup(key):
        # Unknown secret names raise KeyError, exactly like a plain dict lookup.
        return {
            "BRIGHTDATA_API_KEY": BRIGHTDATA_API_KEY,
            "BRIGHTDATA_ZONE": BRIGHTDATA_ZONE,
        }[key]

    ctx.get_secret = _MagicMock(side_effect=_lookup)
    return ctx
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def cleanup_engines():
    """Clean up bright data clients after each test to prevent connection leaks."""
    # autouse=True applies this fixture to every test; the statement after
    # ``yield`` is the teardown step.
    yield
    # Empty the BrightDataClient cache once the test has finished.
    BrightDataClient.clear_cache()
|
||||
|
||||
|
||||
class TestBrightDataClient:
    """Unit tests for BrightDataClient creation, caching, query encoding and requests."""

    def test_get_instance_creates_new_client(self):
        # Different (api_key, zone) pairs must yield distinct clients.
        first = BrightDataClient.create_client("test_key_1", "zone1")
        second = BrightDataClient.create_client("test_key_2", "zone2")

        assert first != second
        assert first.api_key == "test_key_1"
        assert first.zone == "zone1"
        assert second.api_key == "test_key_2"
        assert second.zone == "zone2"

    def test_get_instance_returns_cached_client(self):
        # The same (api_key, zone) pair must be served from the cache.
        original = BrightDataClient.create_client("test_key", "zone1")
        repeat = BrightDataClient.create_client("test_key", "zone1")

        assert original is repeat

    def test_clear_cache(self):
        # After clearing the cache, the same arguments produce a new object.
        before = BrightDataClient.create_client("test_key", "zone1")
        BrightDataClient.clear_cache()
        after = BrightDataClient.create_client("test_key", "zone1")

        assert before is not after

    def test_encode_query(self):
        encoded = BrightDataClient.encode_query("hello world test")
        assert encoded == "hello%20world%20test"

    @patch("requests.post")
    def test_make_request_success(self, post_mock):
        post_mock.return_value = Mock(status_code=200, text="Success response")

        client = BrightDataClient("test_key", "test_zone")
        body = client.make_request({"url": "https://example.com"})

        assert body == "Success response"
        post_mock.assert_called_once()

    @patch("requests.post")
    def test_make_request_failure(self, post_mock):
        failed = Mock(status_code=400, text="Bad Request")
        failed.raise_for_status.side_effect = requests.exceptions.HTTPError(
            "400 Client Error"
        )
        post_mock.return_value = failed

        client = BrightDataClient("test_key", "test_zone")

        # The HTTP error raised by raise_for_status must reach the caller.
        with pytest.raises(requests.exceptions.HTTPError):
            client.make_request({"url": "https://example.com"})
|
||||
|
||||
|
||||
class TestScrapeAsMarkdown:
    """Tests for the scrape_as_markdown tool with BrightDataClient patched out."""

    @patch("arcade_brightdata.tools.bright_data_tools.BrightDataClient")
    def test_scrape_as_markdown_success(self, client_cls, mock_context):
        markdown = "# Test Page\n\nContent here"
        fake_client = Mock()
        fake_client.make_request.return_value = markdown
        client_cls.create_client.return_value = fake_client

        output = scrape_as_markdown(mock_context, "https://example.com")

        assert output == "# Test Page\n\nContent here"
        # The client must be built from the secrets exposed by the context.
        client_cls.create_client.assert_called_once_with(
            api_key=BRIGHTDATA_API_KEY, zone=BRIGHTDATA_ZONE
        )
        expected_payload = {
            "url": "https://example.com",
            "zone": BRIGHTDATA_ZONE,
            "format": "raw",
            "data_format": "markdown",
        }
        fake_client.make_request.assert_called_once_with(expected_payload)
|
||||
|
||||
|
||||
class TestSearchEngine:
    """Tests for the search_engine tool: URL building, payload shape and validation."""

    @staticmethod
    def _stub_client(client_cls, reply):
        """Make the patched client class hand out a client whose request returns ``reply``."""
        fake_client = Mock()
        fake_client.make_request.return_value = reply
        client_cls.create_client.return_value = fake_client
        return fake_client

    @patch("arcade_brightdata.tools.bright_data_tools.BrightDataClient")
    def test_search_engine_google_basic(self, client_cls, mock_context):
        self._stub_client(client_cls, "# Search Results\n\nResult 1\nResult 2")
        client_cls.encode_query.return_value = "test%20query"

        output = search_engine(mock_context, "test query")

        assert output == "# Search Results\n\nResult 1\nResult 2"
        client_cls.create_client.assert_called_once_with(
            api_key=BRIGHTDATA_API_KEY, zone=BRIGHTDATA_ZONE
        )

    @patch("arcade_brightdata.tools.bright_data_tools.BrightDataClient")
    def test_search_engine_bing(self, client_cls, mock_context):
        fake_client = self._stub_client(client_cls, "# Bing Results")
        client_cls.encode_query.return_value = "test%20query"

        output = search_engine(mock_context, "test query", engine="bing")

        assert output == "# Bing Results"
        fake_client.make_request.assert_called_once_with({
            "url": "https://www.bing.com/search?q=test%20query",
            "zone": BRIGHTDATA_ZONE,
            "format": "raw",
            "data_format": "markdown",
        })

    @patch("arcade_brightdata.tools.bright_data_tools.BrightDataClient")
    def test_search_engine_google_with_parameters(self, client_cls, mock_context):
        fake_client = self._stub_client(client_cls, "# Google Results with params")
        client_cls.encode_query.side_effect = lambda x: x.replace(" ", "%20")

        output = search_engine(
            mock_context,
            "test query",
            language="en",
            country_code="us",
            search_type="images",
            start=10,
            num_results=20,
            location="New York",
            device=DeviceType.MOBILE,
            return_json=True,
        )

        assert output == "# Google Results with params"
        sent_payload = fake_client.make_request.call_args[0][0]
        sent_url = sent_payload["url"]

        # Every optional argument must surface as its own query parameter.
        assert "hl=en" in sent_url
        assert "gl=us" in sent_url
        assert "tbm=isch" in sent_url
        assert "start=10" in sent_url
        assert "num=20" in sent_url
        assert "brd_mobile=1" in sent_url
        assert "brd_json=1" in sent_url
        assert sent_payload["data_format"] == "raw"

    def test_search_engine_invalid_engine(self, mock_context):
        with pytest.raises(ToolExecutionError):
            search_engine(mock_context, "test query", engine="invalid_engine")

    @patch("arcade_brightdata.tools.bright_data_tools.BrightDataClient")
    def test_search_engine_google_jobs(self, client_cls, mock_context):
        fake_client = self._stub_client(client_cls, "# Job Results")
        client_cls.encode_query.return_value = "python%20developer"

        output = search_engine(mock_context, "python developer", search_type="jobs")

        assert output == "# Job Results"
        sent_payload = fake_client.make_request.call_args[0][0]
        assert "ibp=htl;jobs" in sent_payload["url"]
|
||||
|
||||
|
||||
class TestWebDataFeed:
    """Tests for the web_data_feed tool with the extraction helper patched out."""

    @patch("arcade_brightdata.tools.bright_data_tools._extract_structured_data")
    @patch("arcade_brightdata.tools.bright_data_tools.BrightDataClient")
    def test_web_data_feed_success(self, client_cls, extract_mock, mock_context):
        fake_client = Mock()
        client_cls.create_client.return_value = fake_client
        extract_mock.return_value = {"title": "Test Product", "price": "$19.99"}

        output = web_data_feed(mock_context, "amazon_product", "https://amazon.com/dp/B08N5WRWNW")

        # The tool serializes the helper's result with a two-space indent.
        assert output == '{\n  "title": "Test Product",\n  "price": "$19.99"\n}'

        client_cls.create_client.assert_called_once_with(api_key=BRIGHTDATA_API_KEY)
        # Omitted optional arguments must be forwarded with their defaults.
        extract_mock.assert_called_once_with(
            client=fake_client,
            source_type=SourceType.AMAZON_PRODUCT,
            url="https://amazon.com/dp/B08N5WRWNW",
            num_of_reviews=None,
            timeout=600,
            polling_interval=1,
        )

    @patch("arcade_brightdata.tools.bright_data_tools._extract_structured_data")
    @patch("arcade_brightdata.tools.bright_data_tools.BrightDataClient")
    def test_web_data_feed_with_reviews(self, client_cls, extract_mock, mock_context):
        fake_client = Mock()
        client_cls.create_client.return_value = fake_client
        extract_mock.return_value = [{"review": "Great product!", "rating": 5}]

        output = web_data_feed(
            mock_context,
            "facebook_company_reviews",
            "https://facebook.com/company",
            num_of_reviews=50,
            timeout=300,
            polling_interval=2,
        )

        assert output == '[\n  {\n    "review": "Great product!",\n    "rating": 5\n  }\n]'

        # Explicit arguments must reach the helper unchanged.
        extract_mock.assert_called_once_with(
            client=fake_client,
            source_type=SourceType.FACEBOOK_COMPANY_REVIEWS,
            url="https://facebook.com/company",
            num_of_reviews=50,
            timeout=300,
            polling_interval=2,
        )
|
||||
|
||||
|
||||
class TestExtractStructuredData:
    """Tests for _extract_structured_data: trigger, polling, validation and timeout."""

    @patch("requests.get")
    @patch("requests.post")
    def test_extract_structured_data_success(self, mock_post, mock_get):
        from arcade_brightdata.tools.bright_data_tools import _extract_structured_data

        client = BrightDataClient("test_key", "test_zone")

        mock_trigger_response = Mock()
        mock_trigger_response.json.return_value = {"snapshot_id": "snap_123"}
        mock_post.return_value = mock_trigger_response

        mock_snapshot_response = Mock()
        mock_snapshot_response.json.return_value = {"data": "extracted_data"}
        mock_get.return_value = mock_snapshot_response

        result = _extract_structured_data(
            client=client,
            source_type=SourceType.AMAZON_PRODUCT,
            url="https://amazon.com/dp/TEST",
            timeout=10,
            polling_interval=0.1,
        )

        assert result == {"data": "extracted_data"}

        mock_post.assert_called_once()
        trigger_call = mock_post.call_args
        assert "gd_l7q7dkf244hwjntr0" in str(trigger_call)  # Amazon product dataset ID

        mock_get.assert_called_once()
        snapshot_call = mock_get.call_args
        assert "snap_123" in str(snapshot_call)

    @patch("requests.get")
    @patch("requests.post")
    @patch("time.sleep")
    def test_extract_structured_data_with_polling(self, mock_sleep, mock_post, mock_get):
        from arcade_brightdata.tools.bright_data_tools import _extract_structured_data

        client = BrightDataClient("test_key", "test_zone")

        mock_trigger_response = Mock()
        mock_trigger_response.json.return_value = {"snapshot_id": "snap_123"}
        mock_post.return_value = mock_trigger_response

        running_response = Mock()
        running_response.json.return_value = {"status": "running"}

        complete_response = Mock()
        complete_response.json.return_value = {"data": "final_data"}

        # First poll reports "running", second poll delivers the data.
        mock_get.side_effect = [running_response, complete_response]

        result = _extract_structured_data(
            client=client,
            source_type=SourceType.LINKEDIN_PERSON_PROFILE,
            url="https://linkedin.com/in/test",
            timeout=10,
            polling_interval=0.1,
        )

        assert result == {"data": "final_data"}
        assert mock_get.call_count == 2
        # time.sleep is patched so the test does not really wait; exactly one
        # pause of the requested interval separates the two polls.
        mock_sleep.assert_called_once_with(0.1)

    @patch("requests.post")
    def test_extract_structured_data_invalid_source_type(self, mock_post):
        from arcade_brightdata.tools.bright_data_tools import _extract_structured_data

        client = BrightDataClient("test_key", "test_zone")

        # Create a mock SourceType that doesn't exist in the datasets dict
        class InvalidSourceType:
            value = "invalid_source"

        with pytest.raises(KeyError):
            _extract_structured_data(
                client=client, source_type=InvalidSourceType(), url="https://example.com"
            )

        # The dataset lookup fails before any request is sent.
        mock_post.assert_not_called()

    @patch("requests.get")
    @patch("requests.post")
    def test_extract_structured_data_no_snapshot_id(self, mock_post, mock_get):
        from arcade_mcp_server.exceptions import RetryableToolError

        from arcade_brightdata.tools.bright_data_tools import _extract_structured_data

        client = BrightDataClient("test_key", "test_zone")

        # Mock trigger response without snapshot_id
        mock_trigger_response = Mock()
        mock_trigger_response.json.return_value = {}
        mock_post.return_value = mock_trigger_response

        # Expect the specific retryable error rather than any Exception, so an
        # unrelated failure cannot make this test pass by accident.
        with pytest.raises(RetryableToolError) as exc_info:
            _extract_structured_data(
                client=client,
                source_type=SourceType.AMAZON_PRODUCT,
                url="https://amazon.com/dp/TEST",
            )

        assert "No snapshot ID returned from trigger request" in str(exc_info.value)
        # Without a snapshot ID there is nothing to poll.
        mock_get.assert_not_called()

    @patch("requests.get")
    @patch("requests.post")
    @patch("time.sleep")
    def test_extract_structured_data_timeout(self, mock_sleep, mock_post, mock_get):
        from arcade_brightdata.tools.bright_data_tools import _extract_structured_data

        client = BrightDataClient("test_key", "test_zone")

        # Mock trigger response
        mock_trigger_response = Mock()
        mock_trigger_response.json.return_value = {"snapshot_id": "snap_123"}
        mock_post.return_value = mock_trigger_response

        # Mock snapshot response that always returns running
        mock_snapshot_response = Mock()
        mock_snapshot_response.json.return_value = {"status": "running"}
        mock_get.return_value = mock_snapshot_response

        with pytest.raises(TimeoutError) as exc_info:
            _extract_structured_data(
                client=client,
                source_type=SourceType.AMAZON_PRODUCT,
                url="https://amazon.com/dp/TEST",
                timeout=2,
                polling_interval=0.1,
            )

        assert "Timeout after 2 seconds waiting for amazon_product data" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestIntegration:
    """Integration tests that test the full flow without mocking internal components."""

    @patch("requests.post")
    def test_scrape_as_markdown_integration(self, post_mock, mock_context):
        page = "# Integration Test\n\nThis is a test page"
        post_mock.return_value = Mock(status_code=200, text=page)

        output = scrape_as_markdown(mock_context, "https://example.com")

        assert output == "# Integration Test\n\nThis is a test page"

        # Verify the request was made correctly
        request_call = post_mock.call_args
        sent_headers = request_call[1]["headers"]
        assert sent_headers["Authorization"] == f"Bearer {BRIGHTDATA_API_KEY}"
        assert "https://api.brightdata.com/request" in str(request_call)

    @patch("requests.post")
    def test_search_engine_integration(self, post_mock, mock_context):
        listing = "# Search Results\n\n1. First result\n2. Second result"
        post_mock.return_value = Mock(status_code=200, text=listing)

        output = search_engine(mock_context, "test query", engine="google")

        assert output == "# Search Results\n\n1. First result\n2. Second result"

        # The request body is inspected as serialized text.
        sent_body = post_mock.call_args[1]["data"]
        assert '"url": "https://www.google.com/search?q=test%20query' in sent_body
        assert '"data_format": "markdown"' in sent_body
|
||||
|
|
@ -1,53 +0,0 @@
|
|||
.PHONY: help
|
||||
|
||||
help:
|
||||
@echo "🛠️ clickhouse Commands:\n"
|
||||
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
||||
|
||||
.PHONY: install
|
||||
install: ## Install the uv environment and install all packages with dependencies
|
||||
@echo "🚀 Creating virtual environment and installing all packages using uv"
|
||||
@uv sync --active --all-extras --no-sources
|
||||
@uv run pre-commit install
|
||||
@echo "✅ All packages and dependencies installed via uv"
|
||||
|
||||
.PHONY: install-local
|
||||
install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources
|
||||
@echo "🚀 Creating virtual environment and installing all packages using uv"
|
||||
@uv sync --active --all-extras
|
||||
@uv run pre-commit install
|
||||
@echo "✅ All packages and dependencies installed via uv"
|
||||
|
||||
.PHONY: build
|
||||
build: clean-build ## Build wheel file using uv
|
||||
@echo "🚀 Creating wheel file"
|
||||
uv build
|
||||
|
||||
.PHONY: clean-build
|
||||
clean-build: ## clean build artifacts
|
||||
@echo "🗑️ Cleaning dist directory"
|
||||
rm -rf dist
|
||||
|
||||
.PHONY: test
|
||||
test: ## Test the code with pytest
|
||||
@echo "🚀 Testing code: Running pytest"
|
||||
@uv run pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml
|
||||
|
||||
.PHONY: coverage
|
||||
coverage: ## Generate coverage report
|
||||
@echo "coverage report"
|
||||
coverage report
|
||||
@echo "Generating coverage report"
|
||||
coverage html
|
||||
|
||||
.PHONY: bump-version
|
||||
bump-version: ## Bump the version in the pyproject.toml file by a patch version
|
||||
@echo "🚀 Bumping version in pyproject.toml"
|
||||
uv version --bump patch
|
||||
|
||||
.PHONY: check
|
||||
check: ## Run code quality tools.
|
||||
@echo "🚀 Linting code: Running pre-commit"
|
||||
@uv run pre-commit run -a
|
||||
@echo "🚀 Static type checking: Running mypy"
|
||||
@uv run mypy --config-file=pyproject.toml
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
import sys
|
||||
from typing import cast
|
||||
|
||||
from arcade_mcp_server import MCPApp
|
||||
from arcade_mcp_server.mcp_app import TransportType
|
||||
|
||||
import arcade_clickhouse
|
||||
|
||||
# MCP application object for the ClickHouse toolkit; ``main`` runs it.
app = MCPApp(
    name="ClickHouse",
    instructions=(
        "Use this server when you need to interact with ClickHouse to help users "
        "query, explore, and manage their ClickHouse databases."
    ),
)

# Register the tools found in the arcade_clickhouse package on the app.
app.add_tools_from_module(arcade_clickhouse)
|
||||
|
||||
|
||||
def main() -> None:
    """
    Run the ClickHouse MCP app.

    Optional positional command-line arguments, in order: transport
    (default "stdio"), host (default "127.0.0.1") and port (default 8000).
    """
    cli_args = sys.argv[1:]
    arg_count = len(cli_args)

    transport = cli_args[0] if arg_count > 0 else "stdio"
    host = cli_args[1] if arg_count > 1 else "127.0.0.1"
    # A non-numeric port raises ValueError from int().
    port = int(cli_args[2]) if arg_count > 2 else 8000

    app.run(transport=cast(TransportType, transport), host=host, port=port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,209 +0,0 @@
|
|||
import contextlib
|
||||
from typing import Any, ClassVar
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import clickhouse_connect
|
||||
from arcade_mcp_server.exceptions import RetryableToolError
|
||||
|
||||
MAX_ROWS_RETURNED = 1000
|
||||
TEST_QUERY = "SELECT 1"
|
||||
|
||||
|
||||
class DatabaseEngine:
|
||||
_instance: ClassVar[None] = None
|
||||
_clients: ClassVar[dict[str, Any]] = {}
|
||||
|
||||
@classmethod
async def get_instance(cls, connection_string: str) -> Any:
    """
    Return a cached ClickHouse client for ``connection_string``, creating it on first use.

    The URL supplies host, port, database and optional credentials. The
    ``clickhouse+native`` scheme is mapped onto the HTTP interface (native
    port 9000 becomes 8123), ``clickhouse+https`` uses HTTPS (default port
    8443), and anything else uses HTTP (default port 8123).

    Args:
        connection_string: ClickHouse connection URL.

    Returns:
        The client produced by ``clickhouse_connect.get_client``.

    Raises:
        RetryableToolError: If the client cannot be created or the test
            query fails.
    """
    import hashlib

    parsed_url = urlparse(connection_string)

    # Extract connection parameters from the URL
    host = parsed_url.hostname or "localhost"
    port = parsed_url.port
    database = parsed_url.path.lstrip("/") or "default"
    # NOTE(review): urlparse does not percent-decode userinfo; credentials
    # containing escaped characters are passed on still encoded -- confirm
    # whether callers expect them to be unquoted.
    username = parsed_url.username
    password = parsed_url.password

    # Handle different ClickHouse protocols
    # clickhouse-connect only supports HTTP and HTTPS interfaces
    if parsed_url.scheme in ["clickhouse+native"]:
        # Convert native protocol to HTTP for clickhouse-connect compatibility
        # Convert native port 9000 to HTTP port 8123
        port = 8123 if port == 9000 else port or 8123
        interface = "http"
    elif parsed_url.scheme in ["clickhouse+https"]:
        # For HTTPS protocol
        port = port or 8443
        interface = "https"
    else:
        # For HTTP or unspecified, use port 8123 by default
        port = port or 8123
        interface = "http"

    key = f"{interface}://{host}:{port}/{database}"
    if username or password:
        # Credentials are part of the client's identity: without them in the
        # key, a client authenticated for one user would be handed to any
        # caller targeting the same host/database with different (or wrong)
        # credentials. Only a digest is stored so the password never appears
        # in the cache key.
        digest = hashlib.sha256(repr((username, password)).encode()).hexdigest()[:16]
        key = f"{key}#{digest}"

    if key not in cls._clients:
        client = None
        try:
            # Create ClickHouse client
            client_args: dict[str, Any] = {
                "host": host,
                "port": port,
                "database": database,
                "interface": interface,
            }

            if username:
                client_args["username"] = username
            if password:
                client_args["password"] = password

            client = clickhouse_connect.get_client(**client_args)

            # Test the connection before the client is published in the cache
            client.command(TEST_QUERY)

        except Exception as e:
            # A client that failed its connection test is closed rather than
            # just dropped, so it does not leak its underlying resources.
            if client is not None:
                with contextlib.suppress(Exception):
                    client.close()
            raise RetryableToolError(
                f"Connection failed: {e}",
                developer_message="Connection to ClickHouse failed.",
                additional_prompt_content="Check the connection string and try again.",
            ) from e

        cls._clients[key] = client

    return cls._clients[key]
|
||||
|
||||
@classmethod
async def get_engine(cls, connection_string: str) -> Any:
    """
    Return an async context manager that yields the client for ``connection_string``.

    The client comes from ``get_instance``; leaving the ``async with`` block
    does not close it.
    """

    class ConnectionContextManager:
        def __init__(self, client: Any) -> None:
            self.client = client

        async def __aenter__(self) -> Any:
            return self.client

        async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
            # Connection cleanup is handled by clickhouse-connect
            return None

    shared_client = await cls.get_instance(connection_string)
    return ConnectionContextManager(shared_client)
|
||||
|
||||
@classmethod
async def cleanup(cls) -> None:
    """Clean up all cached clients. Call this when shutting down."""
    cache = cls._clients
    for cached_client in cache.values():
        # Best effort: a client that fails to close must not stop the others.
        with contextlib.suppress(Exception):
            cached_client.close()
    # Empty the shared dict in place.
    cache.clear()
|
||||
|
||||
@classmethod
def clear_cache(cls) -> None:
    """Clear the client cache without disposing clients. Use with caution."""
    # Clients are only forgotten here, not closed; ``cleanup`` closes them.
    cache = cls._clients
    cache.clear()
|
||||
|
||||
@classmethod
def sanitize_query(  # noqa: C901
    cls,
    select_clause: str,
    from_clause: str,
    limit: int,
    offset: int,
    join_clause: str | None,
    where_clause: str | None,
    having_clause: str | None,
    group_by_clause: str | None,
    order_by_clause: str | None,
    with_clause: str | None,
) -> tuple[str, dict[str, Any]]:
    """Validate the clause fragments and assemble a paginated SELECT statement.

    A clause that starts with the keyword this method adds for it (``SELECT``,
    ``FROM``, ``JOIN``, ``WHERE``, ``GROUP BY``, ``HAVING``, ``ORDER BY``,
    ``WITH``) has that keyword removed first, matched case-insensitively and
    followed by any whitespace, so the keyword is not emitted twice.

    Args:
        select_clause: Column list; must not be a bare ``*``.
        from_clause: Table expression placed after ``FROM``.
        limit: Row limit; must be > 0 and <= ``MAX_ROWS_RETURNED``.
        offset: Rows to skip; must be >= 0.
        join_clause: Optional text placed after ``JOIN``.
        where_clause: Optional text placed after ``WHERE``.
        having_clause: Optional text placed after ``HAVING``.
        group_by_clause: Optional text placed after ``GROUP BY``.
        order_by_clause: Optional text placed after ``ORDER BY``.
        with_clause: Optional text placed after ``WITH`` at the start of the query.

    Returns:
        ``(query, parameters)``. ``query`` ends with ``LIMIT :limit OFFSET :offset``
        and ``parameters`` maps ``"limit"`` and ``"offset"`` to the given values;
        the placeholders are not substituted here.

    Raises:
        RetryableToolError: If the select clause starts with a non-SELECT
            statement keyword, is a bare ``*``, or limit/offset are out of range.
    """
    import re

    def drop_leading_keyword(clause: str | None, keyword: str) -> str | None:
        # ``keyword`` is a regex fragment so two-word keywords can allow any
        # whitespace between the words. A keyword only counts when followed by
        # whitespace or end of input, so "selected_at" is left untouched.
        if not clause:
            return clause
        found = re.match(rf"\s*{keyword}(?=\s|$)", clause, flags=re.IGNORECASE)
        return clause[found.end() :].strip() if found else clause

    # Remove the leading keywords from the clauses if they are present
    select_clause = drop_leading_keyword(select_clause, "SELECT") or ""
    from_clause = drop_leading_keyword(from_clause, "FROM") or ""
    join_clause = drop_leading_keyword(join_clause, "JOIN")
    where_clause = drop_leading_keyword(where_clause, "WHERE")
    group_by_clause = drop_leading_keyword(group_by_clause, r"GROUP\s+BY")
    order_by_clause = drop_leading_keyword(order_by_clause, r"ORDER\s+BY")
    having_clause = drop_leading_keyword(having_clause, "HAVING")
    with_clause = drop_leading_keyword(with_clause, "WITH")

    # Reject anything that starts like a write/DDL/admin statement.
    select_words = select_clause.split()
    first_select_word = select_words[0].upper() if select_words else ""
    if first_select_word in (
        "INSERT",
        "UPDATE",
        "DELETE",
        "CREATE",
        "ALTER",
        "DROP",
        "TRUNCATE",
        "REINDEX",
        "VACUUM",
        "ANALYZE",
        "COMMENT",
        "OPTIMIZE",  # ClickHouse-specific
        "SYSTEM",  # ClickHouse-specific
    ):
        raise RetryableToolError(
            "Only SELECT queries are allowed.",
        )

    if select_clause.strip() == "*":
        raise RetryableToolError(
            "Do not use * in the select clause. Use a comma separated list of columns you wish to return.",
        )

    if limit > MAX_ROWS_RETURNED:
        raise RetryableToolError(
            f"Limit is too high. Maximum is {MAX_ROWS_RETURNED}.",
        )

    if offset < 0:
        raise RetryableToolError(
            "Offset must be greater than or equal to 0.",
            developer_message="Offset must be greater than or equal to 0.",
        )

    if limit <= 0:
        raise RetryableToolError(
            "Limit must be greater than 0.",
            developer_message="Limit must be greater than 0.",
        )

    # Build query with identifiers directly interpolated, but use parameters for values
    parts = []
    if with_clause:
        parts.append(f"WITH {with_clause}")
    parts.append(f"SELECT {select_clause} FROM {from_clause}")  # noqa: S608
    if join_clause:
        parts.append(f"JOIN {join_clause}")
    if where_clause:
        parts.append(f"WHERE {where_clause}")
    if group_by_clause:
        parts.append(f"GROUP BY {group_by_clause}")
    if having_clause:
        parts.append(f"HAVING {having_clause}")
    if order_by_clause:
        parts.append(f"ORDER BY {order_by_clause}")
    parts.append("LIMIT :limit OFFSET :offset")
    query = " ".join(parts)

    # Only use parameters for values, not identifiers
    parameters = {
        "limit": limit,
        "offset": offset,
    }

    return query, parameters
|
||||
|
|
@ -1,347 +0,0 @@
|
|||
from typing import Annotated, Any
|
||||
|
||||
from arcade_mcp_server import Context, tool
|
||||
from arcade_mcp_server.exceptions import RetryableToolError
|
||||
from arcade_mcp_server.metadata import Behavior, Operation, ToolMetadata
|
||||
|
||||
from ..database_engine import MAX_ROWS_RETURNED, DatabaseEngine
|
||||
|
||||
|
||||
# Declared as a read-only, non-destructive, idempotent tool. The connection-string
# secret is listed as required even though this function never reads it.
@tool(
    requires_secrets=["CLICKHOUSE_DATABASE_CONNECTION_STRING"],
    metadata=ToolMetadata(
        behavior=Behavior(
            operations=[Operation.READ],
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=True,
        ),
    ),
)
async def discover_schemas(
    context: Context,
) -> list[str]:
    """Discover all the schemas in the ClickHouse database.

    Note: ClickHouse doesn't have schemas like PostgreSQL, so this returns a default schema name.
    """
    # `context` is unused: no connection is opened and the result is a constant.
    # ClickHouse doesn't have schemas like PostgreSQL, but we return a default for compatibility
    return ["default"]
|
||||
|
||||
|
||||
@tool(
    requires_secrets=["CLICKHOUSE_DATABASE_CONNECTION_STRING"],
    metadata=ToolMetadata(
        behavior=Behavior(
            operations=[Operation.READ],
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=True,
        ),
    ),
)
async def discover_databases(
    context: Context,
) -> list[str]:
    """Discover all the databases in the ClickHouse database."""
    # Resolve the secret first, then borrow the client for the lookup.
    connection_string = context.get_secret("CLICKHOUSE_DATABASE_CONNECTION_STRING")
    engine = await DatabaseEngine.get_engine(connection_string)
    async with engine as client:
        return await _get_databases(client)
|
||||
|
||||
|
||||
@tool(
    requires_secrets=["CLICKHOUSE_DATABASE_CONNECTION_STRING"],
    metadata=ToolMetadata(
        behavior=Behavior(
            operations=[Operation.READ],
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=True,
        ),
    ),
)
async def discover_tables(
    context: Context,
) -> list[str]:
    """Discover all the tables in the ClickHouse database when the list of tables is not known.

    ALWAYS use this tool before any other tool that requires a table name.
    """
    connection_string = context.get_secret("CLICKHOUSE_DATABASE_CONNECTION_STRING")
    engine = await DatabaseEngine.get_engine(connection_string)
    async with engine as client:
        # Only the "default" database is listed.
        return await _get_tables(client, "default")
|
||||
|
||||
|
||||
@tool(
    requires_secrets=["CLICKHOUSE_DATABASE_CONNECTION_STRING"],
    metadata=ToolMetadata(
        behavior=Behavior(
            operations=[Operation.READ],
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=True,
        ),
    ),
)
async def get_table_schema(
    context: Context,
    schema_name: Annotated[str, "The schema to get the table schema of"],
    table_name: Annotated[str, "The table to get the schema of"],
) -> list[str]:
    """
    Get the schema/structure of a ClickHouse table in the ClickHouse database when the schema is not known, and the name of the table is provided.

    This tool should ALWAYS be used before executing any query. All tables in the query must be discovered first using the <DiscoverTables> tool.
    """
    # Returns one description string per column, as built by `_get_table_schema`.
    async with await DatabaseEngine.get_engine(
        context.get_secret("CLICKHOUSE_DATABASE_CONNECTION_STRING")
    ) as client:
        # NOTE(review): `schema_name` is accepted but never used; the lookup always
        # targets the "default" database. Confirm whether that is intended.
        return await _get_table_schema(client, "default", table_name)
|
||||
|
||||
|
||||
@tool(
    requires_secrets=["CLICKHOUSE_DATABASE_CONNECTION_STRING"],
    metadata=ToolMetadata(
        behavior=Behavior(
            operations=[Operation.READ],
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=True,
        ),
    ),
)
async def execute_select_query(
    context: Context,
    select_clause: Annotated[
        str,
        "This is the part of the SQL query that comes after the SELECT keyword with a comma separated list of columns you wish to return. Do not include the SELECT keyword.",
    ],
    from_clause: Annotated[
        str,
        "This is the part of the SQL query that comes after the FROM keyword. Do not include the FROM keyword.",
    ],
    limit: Annotated[
        int,
        "The maximum number of rows to return. This is the LIMIT clause of the query. Default: 100.",
    ] = 100,
    offset: Annotated[
        int, "The number of rows to skip. This is the OFFSET clause of the query. Default: 0."
    ] = 0,
    join_clause: Annotated[
        str | None,
        "This is the part of the SQL query that comes after the JOIN keyword. Do not include the JOIN keyword. If no join is needed, leave this blank.",
    ] = None,
    where_clause: Annotated[
        str | None,
        "This is the part of the SQL query that comes after the WHERE keyword. Do not include the WHERE keyword. If no where clause is needed, leave this blank.",
    ] = None,
    having_clause: Annotated[
        str | None,
        "This is the part of the SQL query that comes after the HAVING keyword. Do not include the HAVING keyword. If no having clause is needed, leave this blank.",
    ] = None,
    group_by_clause: Annotated[
        str | None,
        "This is the part of the SQL query that comes after the GROUP BY keyword. Do not include the GROUP BY keyword. If no group by clause is needed, leave this blank.",
    ] = None,
    order_by_clause: Annotated[
        str | None,
        "This is the part of the SQL query that comes after the ORDER BY keyword. Do not include the ORDER BY keyword. If no order by clause is needed, leave this blank.",
    ] = None,
    with_clause: Annotated[
        str | None,
        "This is the part of the SQL query that comes after the WITH keyword when basing the query on a virtual table. If no WITH clause is needed, leave this blank.",
    ] = None,
) -> list[str]:
    """
    You have a connection to a ClickHouse database.
    Execute a SELECT query and return the results against the ClickHouse database. No other queries (INSERT, UPDATE, DELETE, etc.) are allowed.

    ONLY use this tool if you have already loaded the schema of the tables you need to query. Use the <GetTableSchema> tool to load the schema if not already known.

    The final query will be constructed as follows:
    WITH {with_clause} SELECT {select_clause} FROM {from_clause} JOIN {join_clause} WHERE {where_clause} GROUP BY {group_by_clause} HAVING {having_clause} ORDER BY {order_by_clause} LIMIT {limit} OFFSET {offset}

    When running queries, follow these rules which will help avoid errors:
    * Never "select *" from a table. Always select the columns you need.
    * Always order your results by the most important columns first. If you aren't sure, order by the primary key.
    * Always use case-insensitive queries to match strings in the query.
    * Always trim strings in the query.
    * Prefer LIKE queries over direct string matches or regex queries.
    * Only join on columns that are indexed or the primary key. Do not join on arbitrary columns.
    * ClickHouse is case-sensitive, so be careful with table and column names.
    """
    # Returns one string per result row (see `_execute_query`). Any failure,
    # including validation errors raised while building the query, is re-raised
    # as a RetryableToolError carrying every clause in the developer message.
    async with await DatabaseEngine.get_engine(
        context.get_secret("CLICKHOUSE_DATABASE_CONNECTION_STRING")
    ) as client:
        try:
            return await _execute_query(
                client,
                select_clause=select_clause,
                from_clause=from_clause,
                limit=limit,
                offset=offset,
                join_clause=join_clause,
                where_clause=where_clause,
                having_clause=having_clause,
                group_by_clause=group_by_clause,
                order_by_clause=order_by_clause,
                with_clause=with_clause,
            )
        except Exception as e:
            raise RetryableToolError(
                f"Query failed: {e}",
                developer_message=f"Query failed with parameters: select_clause={select_clause}, from_clause={from_clause}, limit={limit}, offset={offset}, join_clause={join_clause}, where_clause={where_clause}, having_clause={having_clause}, group_by_clause={group_by_clause}, order_by_clause={order_by_clause}, with_clause={with_clause}.",
                additional_prompt_content="Load the database schema <GetTableSchema> or use the <DiscoverTables> tool to discover the tables and try again.",
                retry_after_ms=10,
            ) from e
|
||||
|
||||
|
||||
async def _get_databases(client: Any) -> list[str]:
|
||||
"""Get all the databases in ClickHouse"""
|
||||
# ClickHouse uses SHOW DATABASES instead of information_schema
|
||||
result = client.query("SHOW DATABASES")
|
||||
databases = [row[0] for row in result.result_rows]
|
||||
|
||||
# Filter out system databases
|
||||
system_databases = {
|
||||
"system",
|
||||
"information_schema",
|
||||
"INFORMATION_SCHEMA",
|
||||
"default",
|
||||
"temporary_tables",
|
||||
"temporary_tables_metadata",
|
||||
}
|
||||
databases = [db for db in databases if db not in system_databases]
|
||||
databases.sort()
|
||||
|
||||
return databases
|
||||
|
||||
|
||||
async def _get_tables(client: Any, database_name: str) -> list[str]:
|
||||
"""Get all the tables in the specified ClickHouse database"""
|
||||
# ClickHouse uses SHOW TABLES FROM database_name
|
||||
result = client.query(f"SHOW TABLES FROM {database_name}")
|
||||
tables = [row[0] for row in result.result_rows]
|
||||
tables.sort()
|
||||
|
||||
return tables
|
||||
|
||||
|
||||
async def _get_table_schema(client: Any, database_name: str, table_name: str) -> list[str]:
    """Describe each column of a table as ``name: type`` plus optional annotations.

    Columns named in the table's PRIMARY KEY / ORDER BY clause get a
    ``(PRIMARY KEY)`` marker; a default expression and a comment are appended
    when present. At most ``MAX_ROWS_RETURNED`` descriptions are returned.
    """
    # ClickHouse uses DESCRIBE TABLE database_name.table_name
    # Row layout: name, type, default_type, default_expression, comment,
    # codec_expression, ttl_expression
    described = client.query(f"DESCRIBE TABLE {database_name}.{table_name}").result_rows

    # ClickHouse has no PostgreSQL-style primary keys; the sorting/primary key is
    # part of the table engine definition, so it is read from the CREATE statement.
    # This lookup is best effort: any failure just means no column gets the marker.
    key_columns: set[str] = set()
    try:
        ddl_rows = client.query(f"SHOW CREATE TABLE {database_name}.{table_name}").result_rows
        if ddl_rows:
            key_columns = _extract_primary_keys_from_create_statement(ddl_rows[0][0])
    except Exception:
        key_columns = set()

    descriptions = []
    for row in described:
        name, type_name = row[0], row[1]
        pieces = [f"{name}: {type_name}"]
        if name in key_columns:
            pieces.append(" (PRIMARY KEY)")
        if len(row) > 3 and row[3]:  # default_expression
            pieces.append(f" DEFAULT {row[3]}")
        if len(row) > 4 and row[4]:  # comment
            pieces.append(f" COMMENT '{row[4]}'")
        descriptions.append("".join(pieces))

    return descriptions[:MAX_ROWS_RETURNED]
|
||||
|
||||
|
||||
def _extract_primary_keys_from_create_statement(create_statement: str) -> set[str]:
|
||||
"""Extract primary key columns from ClickHouse CREATE TABLE statement"""
|
||||
primary_keys = set()
|
||||
|
||||
# Look for PRIMARY KEY clause
|
||||
import re
|
||||
|
||||
pk_match = re.search(r"PRIMARY KEY\s*\(([^)]+)\)", create_statement, re.IGNORECASE)
|
||||
if pk_match:
|
||||
pk_columns = pk_match.group(1).split(",")
|
||||
for col in pk_columns:
|
||||
primary_keys.add(col.strip().strip("`"))
|
||||
|
||||
# Look for ORDER BY clause (which can also indicate primary key)
|
||||
order_match = re.search(r"ORDER BY\s*\(([^)]+)\)", create_statement, re.IGNORECASE)
|
||||
if order_match:
|
||||
order_columns = order_match.group(1).split(",")
|
||||
for col in order_columns:
|
||||
primary_keys.add(col.strip().strip("`"))
|
||||
|
||||
return primary_keys
|
||||
|
||||
|
||||
async def _execute_query(
    client: Any,
    select_clause: str,
    from_clause: str,
    limit: int,
    offset: int,
    join_clause: str | None,
    where_clause: str | None,
    having_clause: str | None,
    group_by_clause: str | None,
    order_by_clause: str | None,
    with_clause: str | None,
) -> list[str]:
    """Build the SELECT via ``DatabaseEngine.sanitize_query``, run it, and stringify the rows.

    Args:
        client: Object exposing ``query(sql)`` whose result has ``result_rows``.
        select_clause, from_clause, limit, offset, join_clause, where_clause,
        having_clause, group_by_clause, order_by_clause, with_clause: Passed
            through unchanged to ``DatabaseEngine.sanitize_query``.

    Returns:
        ``str(row)`` for each result row, capped at ``MAX_ROWS_RETURNED`` entries.

    Raises:
        Whatever ``DatabaseEngine.sanitize_query`` or ``client.query`` raise.
    """
    import logging

    query, parameters = DatabaseEngine.sanitize_query(
        select_clause=select_clause,
        from_clause=from_clause,
        limit=limit,
        offset=offset,
        join_clause=join_clause,
        where_clause=where_clause,
        having_clause=having_clause,
        group_by_clause=group_by_clause,
        order_by_clause=order_by_clause,
        with_clause=with_clause,
    )

    # Log at debug level instead of printing: library code should not write to stdout.
    logger = logging.getLogger(__name__)
    logger.debug("Query: %s", query)
    logger.debug("Parameters: %s", parameters)

    # For clickhouse-connect, we need to substitute parameters manually
    # since it doesn't use SQLAlchemy-style parameter binding.
    # The placeholders are appended at the very end of the query, so only the LAST
    # occurrence of each is replaced; identical text inside a caller-supplied
    # clause (e.g. a string literal containing ":limit") is left untouched.
    formatted_query = query
    for param_name, param_value in parameters.items():
        head, placeholder, tail = formatted_query.rpartition(f":{param_name}")
        if placeholder:
            formatted_query = f"{head}{param_value}{tail}"

    result = client.query(formatted_query)
    return [str(row) for row in result.result_rows][:MAX_ROWS_RETURNED]
|
||||
|
|
@ -1,66 +0,0 @@
|
|||
[build-system]
|
||||
requires = [ "hatchling",]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "arcade_clickhouse"
|
||||
version = "0.3.0"
|
||||
description = "Tools to query and explore a ClickHouse database"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"arcade-mcp-server>=1.17.0,<2.0.0",
|
||||
"clickhouse-connect>=0.7.0",
|
||||
"pydantic>=2.11.7",
|
||||
"sqlalchemy>=2.0.41",
|
||||
"clickhouse-sqlalchemy>=0.2.0",
|
||||
"greenlet>=3.2.3",
|
||||
"aiochsa>=0.1.0",
|
||||
"setuptools>=80.9.0",
|
||||
]
|
||||
[[project.authors]]
|
||||
name = "evantahler"
|
||||
email = "support@arcade.dev"
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"arcade-mcp[all]>=1.2.0,<2.0.0",
|
||||
"pytest>=8.3.0,<8.4.0",
|
||||
"pytest-cov>=4.0.0,<4.1.0",
|
||||
"pytest-mock>=3.11.1,<3.12.0",
|
||||
"pytest-asyncio>=0.24.0,<0.25.0",
|
||||
"mypy>=1.5.1,<1.6.0",
|
||||
"pre-commit>=3.4.0,<3.5.0",
|
||||
"tox>=4.11.1,<4.12.0",
|
||||
"ruff>=0.7.4,<0.8.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
arcade-clickhouse = "arcade_clickhouse.__main__:main"
|
||||
arcade_clickhouse = "arcade_clickhouse.__main__:main"
|
||||
|
||||
# Use local path sources for arcade libs when working locally
|
||||
[tool.uv.sources]
|
||||
arcade-mcp = { path = "../../", editable = true }
|
||||
arcade-mcp-server = { path = "../../libs/arcade-mcp-server/", editable = true }
|
||||
|
||||
[tool.mypy]
|
||||
files = [ "arcade_clickhouse/**/*.py",]
|
||||
python_version = "3.10"
|
||||
disallow_untyped_defs = "True"
|
||||
disallow_any_unimported = "True"
|
||||
no_implicit_optional = "True"
|
||||
check_untyped_defs = "True"
|
||||
warn_return_any = "True"
|
||||
warn_unused_ignores = "True"
|
||||
show_error_codes = "True"
|
||||
ignore_missing_imports = "True"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = [ "tests",]
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
|
||||
[tool.coverage.report]
|
||||
skip_empty = true
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = [ "arcade_clickhouse",]
|
||||
|
|
@ -1,369 +0,0 @@
|
|||
-- ClickHouse test database setup
|
||||
-- This file contains sample data for testing the ClickHouse toolkit
|
||||
-- Create users table
|
||||
CREATE TABLE IF NOT EXISTS default.users (
|
||||
id UInt32,
|
||||
name String,
|
||||
email String,
|
||||
password_hash String,
|
||||
created_at DateTime,
|
||||
updated_at DateTime,
|
||||
status String
|
||||
) ENGINE = MergeTree()
|
||||
ORDER BY (id, created_at);
|
||||
-- Create messages table
|
||||
CREATE TABLE IF NOT EXISTS default.messages (
|
||||
id UInt32,
|
||||
body String,
|
||||
user_id UInt32,
|
||||
created_at DateTime,
|
||||
updated_at DateTime
|
||||
) ENGINE = MergeTree()
|
||||
ORDER BY (id, created_at);
|
||||
-- Insert sample data into users table
|
||||
INSERT INTO default.users (
|
||||
id,
|
||||
name,
|
||||
email,
|
||||
password_hash,
|
||||
created_at,
|
||||
updated_at,
|
||||
status
|
||||
)
|
||||
VALUES (
|
||||
1,
|
||||
'Alice',
|
||||
'alice@example.com',
|
||||
'$argon2id$v=19$m=65536,t=2,p=1$tMg1Rd3IEDnp3iFKrqsF4Dsbw6/Cbf6seRB/H5bhaPg$zZj5yn4x3D3O3mDHcW2aczQNiYfAs3cw21XMEIgkF0E',
|
||||
'2024-09-01 20:49:38',
|
||||
'2024-09-02 03:49:39',
|
||||
'active'
|
||||
),
|
||||
(
|
||||
2,
|
||||
'Bob',
|
||||
'bob@example.com',
|
||||
'$argon2id$v=19$m=65536,t=2,p=1$CvOMK1WUd99R7kYXpiBPNYw4OQP53pYIgeMnwz92mrE$HPthId4phMoPT1TWuCRHHCr9BSQA8XoUkQuB1HZsqTY',
|
||||
'2024-09-02 17:49:23',
|
||||
'2024-09-02 17:49:23',
|
||||
'active'
|
||||
),
|
||||
(
|
||||
3,
|
||||
'Charlie',
|
||||
'charlie@example.com',
|
||||
'$argon2id$v=19$m=65536,t=2,p=1$paCAAD1HVZkncP/WvecuUO6zFXp2/8BISpgr5rXRxps$M5kBFc9JHHGNw9SXnPu2ggpJY0mFFCska7TXMrllndo',
|
||||
'2024-09-03 10:30:15',
|
||||
'2024-09-03 10:30:15',
|
||||
'active'
|
||||
),
|
||||
(
|
||||
4,
|
||||
'Diana',
|
||||
'diana@example.com',
|
||||
'$argon2id$v=19$m=65536,t=2,p=1$xyz123ABC456DEF789GHI$SampleHashForDiana123',
|
||||
'2024-09-04 14:20:30',
|
||||
'2024-09-04 14:20:30',
|
||||
'active'
|
||||
),
|
||||
(
|
||||
5,
|
||||
'Evan',
|
||||
'evan@example.com',
|
||||
'$argon2id$v=19$m=65536,t=2,p=1$evanHash123$EvanPasswordHash456',
|
||||
'2024-09-05 09:15:45',
|
||||
'2024-09-05 09:15:45',
|
||||
'active'
|
||||
),
|
||||
(
|
||||
6,
|
||||
'Fiona',
|
||||
'fiona@example.com',
|
||||
'$argon2id$v=19$m=65536,t=2,p=1$fionaHash456$FionaPasswordHash789',
|
||||
'2024-09-06 16:45:12',
|
||||
'2024-09-06 16:45:12',
|
||||
'active'
|
||||
),
|
||||
(
|
||||
7,
|
||||
'George',
|
||||
'george@example.com',
|
||||
'$argon2id$v=19$m=65536,t=2,p=1$georgeHash789$GeorgePasswordHash012',
|
||||
'2024-09-07 11:30:25',
|
||||
'2024-09-07 11:30:25',
|
||||
'active'
|
||||
),
|
||||
(
|
||||
8,
|
||||
'Helen',
|
||||
'helen@example.com',
|
||||
'$argon2id$v=19$m=65536,t=2,p=1$helenHash012$HelenPasswordHash345',
|
||||
'2024-09-08 13:25:40',
|
||||
'2024-09-08 13:25:40',
|
||||
'active'
|
||||
),
|
||||
(
|
||||
9,
|
||||
'Ian',
|
||||
'ian@example.com',
|
||||
'$argon2id$v=19$m=65536,t=2,p=1$ianHash345$IanPasswordHash678',
|
||||
'2024-09-09 08:40:55',
|
||||
'2024-09-09 08:40:55',
|
||||
'active'
|
||||
),
|
||||
(
|
||||
10,
|
||||
'Julia',
|
||||
'julia@example.com',
|
||||
'$argon2id$v=19$m=65536,t=2,p=1$juliaHash678$JuliaPasswordHash901',
|
||||
'2024-09-10 15:55:18',
|
||||
'2024-09-10 15:55:18',
|
||||
'active'
|
||||
);
|
||||
-- Insert sample data into messages table
|
||||
INSERT INTO default.messages (id, body, user_id, created_at, updated_at)
|
||||
VALUES (
|
||||
1,
|
||||
'Hello everyone!',
|
||||
1,
|
||||
'2025-01-10 10:00:00',
|
||||
'2025-01-10 10:00:00'
|
||||
),
|
||||
(
|
||||
2,
|
||||
'How is everyone doing today?',
|
||||
1,
|
||||
'2025-01-10 11:30:00',
|
||||
'2025-01-10 11:30:00'
|
||||
),
|
||||
(
|
||||
3,
|
||||
'Great to see you all here!',
|
||||
1,
|
||||
'2025-01-10 14:15:00',
|
||||
'2025-01-10 14:15:00'
|
||||
),
|
||||
(
|
||||
4,
|
||||
'Hi Alice! Doing well, thanks for asking.',
|
||||
2,
|
||||
'2025-01-10 11:35:00',
|
||||
'2025-01-10 11:35:00'
|
||||
),
|
||||
(
|
||||
5,
|
||||
'Anyone up for a game later?',
|
||||
2,
|
||||
'2025-01-10 16:20:00',
|
||||
'2025-01-10 16:20:00'
|
||||
),
|
||||
(
|
||||
6,
|
||||
'Count me in for the game!',
|
||||
3,
|
||||
'2025-01-10 16:25:00',
|
||||
'2025-01-10 16:25:00'
|
||||
),
|
||||
(
|
||||
7,
|
||||
'What time works for everyone?',
|
||||
3,
|
||||
'2025-01-10 16:30:00',
|
||||
'2025-01-10 16:30:00'
|
||||
),
|
||||
(
|
||||
8,
|
||||
'I can play around 8 PM',
|
||||
3,
|
||||
'2025-01-10 17:00:00',
|
||||
'2025-01-10 17:00:00'
|
||||
),
|
||||
(
|
||||
9,
|
||||
'8 PM works for me too!',
|
||||
4,
|
||||
'2025-01-10 17:05:00',
|
||||
'2025-01-10 17:05:00'
|
||||
),
|
||||
(
|
||||
10,
|
||||
'What game should we play?',
|
||||
4,
|
||||
'2025-01-10 17:10:00',
|
||||
'2025-01-10 17:10:00'
|
||||
),
|
||||
(
|
||||
11,
|
||||
'I suggest we try the new arcade game!',
|
||||
5,
|
||||
'2025-01-10 17:15:00',
|
||||
'2025-01-10 17:15:00'
|
||||
),
|
||||
(
|
||||
12,
|
||||
'It has great multiplayer features',
|
||||
5,
|
||||
'2025-01-10 17:20:00',
|
||||
'2025-01-10 17:20:00'
|
||||
),
|
||||
(
|
||||
13,
|
||||
'Perfect timing for a weekend session',
|
||||
5,
|
||||
'2025-01-10 18:00:00',
|
||||
'2025-01-10 18:00:00'
|
||||
),
|
||||
(
|
||||
26,
|
||||
'Just finished setting up the game server!',
|
||||
5,
|
||||
'2025-01-10 20:00:00',
|
||||
'2025-01-10 20:00:00'
|
||||
),
|
||||
(
|
||||
27,
|
||||
'Everyone should be able to connect now',
|
||||
5,
|
||||
'2025-01-10 20:05:00',
|
||||
'2025-01-10 20:05:00'
|
||||
),
|
||||
(
|
||||
28,
|
||||
'I added some custom maps too',
|
||||
5,
|
||||
'2025-01-10 20:10:00',
|
||||
'2025-01-10 20:10:00'
|
||||
),
|
||||
(
|
||||
29,
|
||||
'The graphics look amazing on this new version',
|
||||
5,
|
||||
'2025-01-10 20:15:00',
|
||||
'2025-01-10 20:15:00'
|
||||
),
|
||||
(
|
||||
30,
|
||||
'Hope you all enjoy the new features',
|
||||
5,
|
||||
'2025-01-10 20:20:00',
|
||||
'2025-01-10 20:20:00'
|
||||
),
|
||||
(
|
||||
31,
|
||||
'I also set up a leaderboard system',
|
||||
5,
|
||||
'2025-01-10 20:25:00',
|
||||
'2025-01-10 20:25:00'
|
||||
),
|
||||
(
|
||||
32,
|
||||
'We can track high scores now',
|
||||
5,
|
||||
'2025-01-10 20:30:00',
|
||||
'2025-01-10 20:30:00'
|
||||
),
|
||||
(
|
||||
33,
|
||||
'The game supports up to 8 players simultaneously',
|
||||
5,
|
||||
'2025-01-10 20:35:00',
|
||||
'2025-01-10 20:35:00'
|
||||
),
|
||||
(
|
||||
34,
|
||||
'I tested it earlier and it runs smoothly',
|
||||
5,
|
||||
'2025-01-10 20:40:00',
|
||||
'2025-01-10 20:40:00'
|
||||
),
|
||||
(
|
||||
35,
|
||||
'Cannot wait to see everyone online tonight!',
|
||||
5,
|
||||
'2025-01-10 20:45:00',
|
||||
'2025-01-10 20:45:00'
|
||||
),
|
||||
(
|
||||
14,
|
||||
'Sounds like fun! I love arcade games.',
|
||||
6,
|
||||
'2025-01-10 18:05:00',
|
||||
'2025-01-10 18:05:00'
|
||||
),
|
||||
(
|
||||
15,
|
||||
'Should I bring snacks?',
|
||||
6,
|
||||
'2025-01-10 18:10:00',
|
||||
'2025-01-10 18:10:00'
|
||||
),
|
||||
(
|
||||
16,
|
||||
'Snacks are always welcome!',
|
||||
7,
|
||||
'2025-01-10 18:15:00',
|
||||
'2025-01-10 18:15:00'
|
||||
),
|
||||
(
|
||||
17,
|
||||
'I can bring some drinks',
|
||||
7,
|
||||
'2025-01-10 18:20:00',
|
||||
'2025-01-10 18:20:00'
|
||||
),
|
||||
(
|
||||
18,
|
||||
'This is going to be awesome',
|
||||
7,
|
||||
'2025-01-10 19:00:00',
|
||||
'2025-01-10 19:00:00'
|
||||
),
|
||||
(
|
||||
19,
|
||||
'I agree! Cannot wait for the game night.',
|
||||
8,
|
||||
'2025-01-10 19:05:00',
|
||||
'2025-01-10 19:05:00'
|
||||
),
|
||||
(
|
||||
20,
|
||||
'Should we set up a Discord call?',
|
||||
8,
|
||||
'2025-01-10 19:10:00',
|
||||
'2025-01-10 19:10:00'
|
||||
),
|
||||
(
|
||||
21,
|
||||
'Discord would be perfect for voice chat',
|
||||
9,
|
||||
'2025-01-10 19:15:00',
|
||||
'2025-01-10 19:15:00'
|
||||
),
|
||||
(
|
||||
22,
|
||||
'I will create a server for us',
|
||||
9,
|
||||
'2025-01-10 19:20:00',
|
||||
'2025-01-10 19:20:00'
|
||||
),
|
||||
(
|
||||
23,
|
||||
'Link will be shared in a few minutes',
|
||||
9,
|
||||
'2025-01-10 19:25:00',
|
||||
'2025-01-10 19:25:00'
|
||||
),
|
||||
(
|
||||
24,
|
||||
'Thanks Ian! You are the best.',
|
||||
10,
|
||||
'2025-01-10 19:30:00',
|
||||
'2025-01-10 19:30:00'
|
||||
),
|
||||
(
|
||||
25,
|
||||
'See you all at 8 PM!',
|
||||
10,
|
||||
'2025-01-10 19:35:00',
|
||||
'2025-01-10 19:35:00'
|
||||
);
|
||||
|
|
@ -1,195 +0,0 @@
|
|||
import os
|
||||
from os import environ
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from arcade_clickhouse.tools.clickhouse import (
|
||||
DatabaseEngine,
|
||||
discover_schemas,
|
||||
discover_tables,
|
||||
execute_select_query,
|
||||
get_table_schema,
|
||||
)
|
||||
from arcade_mcp_server import Context
|
||||
from arcade_mcp_server.exceptions import RetryableToolError
|
||||
|
||||
CLICKHOUSE_DATABASE_CONNECTION_STRING = (
|
||||
environ.get("TEST_CLICKHOUSE_DATABASE_CONNECTION_STRING")
|
||||
or "clickhouse+native://localhost:9000/default"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_context():
    """Context stand-in whose ``get_secret`` always returns the test connection string."""
    fake_context = MagicMock(spec=Context)
    secret_lookup = MagicMock(return_value=CLICKHOUSE_DATABASE_CONNECTION_STRING)
    fake_context.get_secret = secret_lookup
    return fake_context
|
||||
|
||||
|
||||
# before the tests, restore the database from the dump
|
||||
@pytest_asyncio.fixture(autouse=True)
async def restore_database():
    """Recreate the ``users`` and ``messages`` tables from ``dump.sql`` before each test.

    Connects to ``localhost:8123``, drops both tables, then runs every
    non-empty ``;``-separated statement from the dump file next to this module.
    The client is closed even when one of the statements fails.
    """
    import clickhouse_connect

    # Create client for database setup
    client = clickhouse_connect.get_client(host="localhost", port=8123)
    try:
        # Clear existing tables first to avoid duplicates
        client.command("DROP TABLE IF EXISTS default.messages")
        client.command("DROP TABLE IF EXISTS default.users")

        # Read and execute the dump file. Splitting on ";" is only safe while no
        # statement in the dump contains a semicolon inside a string literal.
        dump_path = os.path.join(os.path.dirname(__file__), "dump.sql")
        with open(dump_path, encoding="utf-8") as f:
            queries = f.read().split(";")

        for query in queries:
            if query.strip():
                client.command(query)
    finally:
        client.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(autouse=True)
async def cleanup_engines():
    """Clean up database engines after each test to prevent connection leaks."""
    # Everything before the yield runs before the test; there is no setup here.
    yield
    # Clean up all cached engines after each test
    await DatabaseEngine.cleanup()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_discover_schemas(mock_context) -> None:
    """The schema list is the single placeholder name ``default``."""
    schemas = await discover_schemas(mock_context)
    assert schemas == ["default"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_discover_tables(mock_context) -> None:
    """Exactly the two tables created from the dump are discovered."""
    discovered = sorted(await discover_tables(mock_context))
    assert discovered == ["messages", "users"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_table_schema(mock_context) -> None:
    """Both tables report their columns in order, with key columns marked."""
    expected_by_table = {
        "users": [
            "id: UInt32 (PRIMARY KEY)",
            "name: String",
            "email: String",
            "password_hash: String",
            "created_at: DateTime (PRIMARY KEY)",
            "updated_at: DateTime",
            "status: String",
        ],
        "messages": [
            "id: UInt32 (PRIMARY KEY)",
            "body: String",
            "user_id: UInt32",
            "created_at: DateTime (PRIMARY KEY)",
            "updated_at: DateTime",
        ],
    }

    # Checked in insertion order: users first, then messages.
    for table, expected_columns in expected_by_table.items():
        actual_columns = await get_table_schema(mock_context, "default", table)
        assert actual_columns == expected_columns
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_execute_select_query(mock_context) -> None:
    """A filtered single-row query and an ordered query with an offset both work."""
    columns = "id, name, email"

    # Specific user, limited to one row
    filtered = await execute_select_query(
        mock_context,
        select_clause=columns,
        from_clause="users",
        where_clause="id = 1",
        limit=1,
    )
    assert filtered == ["(1, 'Alice', 'alice@example.com')"]

    # Second row by id, reached through the offset
    paged = await execute_select_query(
        mock_context,
        select_clause=columns,
        from_clause="users",
        order_by_clause="id",
        limit=1,
        offset=1,
    )
    assert paged == ["(2, 'Bob', 'bob@example.com')"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_execute_select_query_with_keywords(mock_context) -> None:
    """Clauses that already carry their SELECT / FROM keyword are still accepted."""
    query_args = {
        "select_clause": "SELECT id, name, email",
        "from_clause": "FROM users",
        "limit": 1,
    }
    rows = await execute_select_query(mock_context, **query_args)
    assert rows == ["(1, 'Alice', 'alice@example.com')"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_execute_select_query_with_join(mock_context) -> None:
    """Joining users to messages returns columns from both tables in one row."""
    joined_rows = await execute_select_query(
        mock_context,
        select_clause="u.id, u.name, u.email, m.id, m.body",
        from_clause="users u",
        join_clause="messages m ON u.id = m.user_id",
        limit=1,
    )
    expected_rows = ["(1, 'Alice', 'alice@example.com', 1, 'Hello everyone!')"]
    assert joined_rows == expected_rows
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_execute_select_query_with_group_by(mock_context) -> None:
    """Count messages per user and check the two users with the most messages."""
    expected_top_two = ["('Evan', 13)", "('Alice', 3)"]

    top_posters = await execute_select_query(
        mock_context,
        select_clause="u.name, COUNT(m.id) AS message_count",
        from_clause="messages m",
        join_clause="users u ON m.user_id = u.id",
        group_by_clause="u.name",
        order_by_clause="message_count DESC",
        limit=2,
    )

    assert top_posters == expected_top_two
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_execute_select_query_with_no_results(mock_context) -> None:
    """A filter that matches nothing returns an empty list instead of raising."""
    rows = await execute_select_query(
        mock_context,
        select_clause="id, name, email",
        from_clause="users",
        where_clause="id = 9999999999",
    )
    # Reaching this line already shows that the call did not raise.
    assert rows == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_execute_select_query_with_problem(mock_context) -> None:
    """A failing query raises RetryableToolError carrying the '*' guidance text."""
    # The filter compares the id column with the string 'foo'.
    # NOTE(review): whether the error is triggered by that comparison or by the
    # '*' select clause is not visible from this test - confirm against the tool.
    bad_query = {
        "select_clause": "*",
        "from_clause": "users",
        "where_clause": "id = 'foo'",
    }
    with pytest.raises(RetryableToolError) as exc_info:
        await execute_select_query(mock_context, **bad_query)

    assert "Do not use * in the select clause" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_execute_select_query_rejects_non_select(mock_context) -> None:
    """An INSERT statement passed as the select clause is refused with a retryable error."""
    insert_statement = "INSERT INTO users (name, email, password_hash) VALUES ('Luigi', 'luigi@example.com', 'password')"

    with pytest.raises(RetryableToolError) as exc_info:
        await execute_select_query(
            mock_context,
            select_clause=insert_statement,
            from_clause="users",
        )

    assert "Only SELECT queries are allowed" in str(exc_info.value)
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
#!/bin/bash

# Start a detached ClickHouse server container named "some-clickhouse-server",
# raising the open-file limit and publishing ports 8123, 8443 and 9000.
# NOTE(review): the yandex/clickhouse-server image name is believed to have been
# superseded by clickhouse/clickhouse-server - confirm before reusing this script.
docker run -d \
  --name some-clickhouse-server \
  --ulimit nofile=262144:262144 \
  -p 8123:8123 \
  -p 8443:8443 \
  -p 9000:9000 \
  yandex/clickhouse-server
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
files: ^.*/linkedin/.*
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: "v4.4.0"
|
||||
hooks:
|
||||
- id: check-case-conflict
|
||||
- id: check-merge-conflict
|
||||
- id: check-toml
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.6.7
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
|
|
@ -1,46 +0,0 @@
|
|||
target-version = "py310"
|
||||
line-length = 100
|
||||
fix = true
|
||||
|
||||
[lint]
|
||||
select = [
|
||||
# flake8-2020
|
||||
"YTT",
|
||||
# flake8-bandit
|
||||
"S",
|
||||
# flake8-bugbear
|
||||
"B",
|
||||
# flake8-builtins
|
||||
"A",
|
||||
# flake8-comprehensions
|
||||
"C4",
|
||||
# flake8-debugger
|
||||
"T10",
|
||||
# flake8-simplify
|
||||
"SIM",
|
||||
# isort
|
||||
"I",
|
||||
# mccabe
|
||||
"C90",
|
||||
# pycodestyle
|
||||
"E", "W",
|
||||
# pyflakes
|
||||
"F",
|
||||
# pygrep-hooks
|
||||
"PGH",
|
||||
# pyupgrade
|
||||
"UP",
|
||||
# ruff
|
||||
"RUF",
|
||||
# tryceratops
|
||||
"TRY",
|
||||
]
|
||||
|
||||
[lint.per-file-ignores]
|
||||
"*" = ["TRY003", "B904"]
|
||||
"**/tests/*" = ["S101", "E501"]
|
||||
"**/evals/*" = ["S101", "E501"]
|
||||
|
||||
[format]
|
||||
preview = true
|
||||
skip-magic-trailing-comma = false
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2025, Arcade AI
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
|
@ -1,55 +0,0 @@
|
|||
# Developer tasks for the LinkedIn toolkit. Targets whose rule line ends in a
# "## description" comment are listed by `make help`.

.PHONY: help

# Print the annotated targets, sorted by name, with their descriptions.
help:
	@echo "🛠️ linkedin Commands:\n"
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'

.PHONY: install
install: ## Install the uv environment and install all packages with dependencies
	@echo "🚀 Creating virtual environment and installing all packages using uv"
	@uv sync --active --all-extras --no-sources
	@if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi
	@echo "✅ All packages and dependencies installed via uv"

# Same as `install`, but without --no-sources so [tool.uv.sources] entries apply.
.PHONY: install-local
install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources
	@echo "🚀 Creating virtual environment and installing all packages using uv"
	@uv sync --active --all-extras
	@if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi
	@echo "✅ All packages and dependencies installed via uv"

.PHONY: build
build: clean-build ## Build wheel file using uv
	@echo "🚀 Creating wheel file"
	uv build

.PHONY: clean-build
clean-build: ## clean build artifacts
	@echo "🗑️ Cleaning dist directory"
	rm -rf dist

.PHONY: test
test: ## Test the code with pytest
	@echo "🚀 Testing code: Running pytest"
	@uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml

.PHONY: coverage
coverage: ## Generate coverage report
	@echo "coverage report"
	@uv run --no-sources coverage report
	@echo "Generating coverage report"
	@uv run --no-sources coverage html

.PHONY: bump-version
bump-version: ## Bump the version in the pyproject.toml file by a patch version
	@echo "🚀 Bumping version in pyproject.toml"
	uv version --no-sources --bump patch

# pre-commit only runs when the toolkit ships a .pre-commit-config.yaml; mypy always runs.
.PHONY: check
check: ## Run code quality tools.
	@if [ -f .pre-commit-config.yaml ]; then\
		echo "🚀 Linting code: Running pre-commit";\
		uv run --no-sources pre-commit run -a;\
	fi
	@echo "🚀 Static type checking: Running mypy"
	@uv run --no-sources mypy --config-file=pyproject.toml
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
import sys
|
||||
from typing import cast
|
||||
|
||||
from arcade_mcp_server import MCPApp
|
||||
from arcade_mcp_server.mcp_app import TransportType
|
||||
|
||||
import arcade_linkedin
|
||||
|
||||
# MCP application for the LinkedIn toolkit; every tool found in the
# arcade_linkedin package is registered on it below.
app = MCPApp(
    name="LinkedIn",
    instructions=(
        "Use this server when you need to interact with LinkedIn "
        "to help users create and share posts on their LinkedIn profile."
    ),
)
app.add_tools_from_module(arcade_linkedin)
|
||||
|
||||
|
||||
def main() -> None:
    """Run the LinkedIn MCP app using optional positional CLI arguments.

    Arguments, in order: transport (default "stdio"), host (default
    "127.0.0.1") and port (default 8000). The port is converted with int().
    """
    cli_args = sys.argv[1:]
    supplied = len(cli_args)

    transport = cli_args[0] if supplied >= 1 else "stdio"
    host = cli_args[1] if supplied >= 2 else "127.0.0.1"
    port = int(cli_args[2]) if supplied >= 3 else 8000

    # cast() only informs the type checker; the transport string is passed through unchanged.
    app.run(transport=cast(TransportType, transport), host=host, port=port)


if __name__ == "__main__":
    main()
|
||||
|
|
@ -1 +0,0 @@
|
|||
# Base URL of the LinkedIn API (v2); request helpers append the endpoint path to it.
LINKEDIN_BASE_URL = "https://api.linkedin.com/v2"
|
||||
|
|
@ -1,76 +0,0 @@
|
|||
from typing import Annotated
|
||||
|
||||
from arcade_mcp_server import Context, tool
|
||||
from arcade_mcp_server.auth import LinkedIn
|
||||
from arcade_mcp_server.exceptions import ToolExecutionError
|
||||
from arcade_mcp_server.metadata import (
|
||||
Behavior,
|
||||
Classification,
|
||||
Operation,
|
||||
ServiceDomain,
|
||||
ToolMetadata,
|
||||
)
|
||||
|
||||
from arcade_linkedin.tools.utils import _handle_linkedin_api_error, _send_linkedin_request
|
||||
|
||||
|
||||
@tool(
    requires_auth=LinkedIn(scopes=["w_member_social"]),
    metadata=ToolMetadata(
        classification=Classification(service_domains=[ServiceDomain.SOCIAL_MEDIA]),
        behavior=Behavior(
            operations=[Operation.CREATE],
            read_only=False,
            destructive=False,
            idempotent=False,
            open_world=True,
        ),
    ),
)
async def create_text_post(
    context: Context,
    text: Annotated[str, "The text content of the post"],
) -> Annotated[str, "URL of the shared post"]:
    """Share a new text post to LinkedIn."""
    # Creating a post needs the LinkedIn user ID in addition to the access token.
    # Arcade Engine fetches the current user's info from LinkedIn and stores it in
    # context.authorization.user_info, where LinkedIn names the user ID "sub". See:
    # https://learn.microsoft.com/en-us/linkedin/consumer/integrations/self-serve/sign-in-with-linkedin-v2#api-request-to-retreive-member-details
    authorization = context.authorization
    user_id = authorization.user_info.get("sub") if authorization else None
    if not user_id:
        raise ToolExecutionError(
            "User ID not found.",
            developer_message="User ID not found in `context.authorization.user_info.sub`",
        )

    # Public, immediately published text-only share authored by the current member.
    share_content = {
        "shareCommentary": {"text": text},
        "shareMediaCategory": "NONE",
    }
    payload = {
        "author": f"urn:li:person:{user_id}",
        "lifecycleState": "PUBLISHED",
        "specificContent": {"com.linkedin.ugc.ShareContent": share_content},
        "visibility": {"com.linkedin.ugc.MemberNetworkVisibility": "PUBLIC"},
    }

    response = await _send_linkedin_request(context, "POST", "/ugcPosts", json_data=payload)

    if 200 <= response.status_code < 300:
        share_id = response.json().get("id")
        return f"https://www.linkedin.com/feed/update/{share_id}/"

    # Raises for status codes >= 400; any other non-2xx status falls through.
    _handle_linkedin_api_error(response)

    return ""
|
||||
|
|
@ -1,68 +0,0 @@
|
|||
import httpx
|
||||
from arcade_mcp_server import Context
|
||||
from arcade_mcp_server.exceptions import ToolExecutionError
|
||||
|
||||
from arcade_linkedin.tools.constants import LINKEDIN_BASE_URL
|
||||
|
||||
|
||||
async def _send_linkedin_request(
    context: Context,
    method: str,
    endpoint: str,
    params: dict | None = None,
    json_data: dict | None = None,
) -> httpx.Response:
    """
    Send an asynchronous request to the LinkedIn API.

    Args:
        context: The tool context containing the authorization token. When no
            authorization or token is present, an empty bearer token is sent.
        method: The HTTP method (GET, POST, PUT, DELETE, etc.).
        endpoint: The API endpoint path (e.g., "/ugcPosts"), appended to
            LINKEDIN_BASE_URL.
        params: Query parameters to include in the request.
        json_data: JSON data to include in the request body.

    Returns:
        The response object from the API request.

    Raises:
        ToolExecutionError: If the request could not be sent
            (httpx.RequestError); the original error is chained as the cause.
        httpx.HTTPStatusError: Raised by raise_for_status() for error status
            codes. NOTE(review): this is not an httpx.RequestError, so it is
            not converted here - confirm callers (or the framework) expect it.
    """
    url = f"{LINKEDIN_BASE_URL}{endpoint}"
    authorization = context.authorization
    token = authorization.token if authorization and authorization.token else ""
    headers = {"Authorization": f"Bearer {token}"}

    async with httpx.AsyncClient() as client:
        try:
            response = await client.request(
                method, url, headers=headers, params=params, json=json_data
            )
            response.raise_for_status()
        except httpx.RequestError as e:
            # Chain explicitly so the underlying transport error stays attached as __cause__.
            raise ToolExecutionError(f"Failed to send request to LinkedIn API: {e}") from e

    return response
|
||||
|
||||
|
||||
def _handle_linkedin_api_error(response: httpx.Response) -> None:
    """
    Handle errors from the LinkedIn API by mapping common status codes to ToolExecutionErrors.

    Status codes 401, 403 and 429 get a fixed message; any other status code
    of 400 or above is reported with the code and the response text. Status
    codes below 400 return without raising.

    Args:
        response: The response object from the API request.

    Raises:
        ToolExecutionError: If the response contains an error status code.
    """
    # The 403 text previously mentioned Spotify Premium, a copy-paste from
    # another toolkit; it now describes the LinkedIn case.
    known_error_messages = {
        401: "Unauthorized: Invalid or expired token",
        403: "Forbidden: Insufficient permissions to perform this action on LinkedIn",
        429: "Too Many Requests: Rate limit exceeded",
    }

    status_code = response.status_code
    if status_code in known_error_messages:
        raise ToolExecutionError(known_error_messages[status_code])
    elif status_code >= 400:
        raise ToolExecutionError(f"Error: {response.status_code} - {response.text}")
|
||||
|
|
@ -1,24 +0,0 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from arcade_mcp_server import Context
|
||||
|
||||
|
||||
@pytest.fixture
def tool_context():
    """Fixture for the tool Context with mock authorization."""
    # Authorization stand-in exposing a token and a user_info mapping with a "sub" entry.
    mock_authorization = MagicMock()
    mock_authorization.user_info = {"sub": "test_user"}
    mock_authorization.token = "test_token"  # noqa: S105

    mock_context = MagicMock(spec=Context)
    mock_context.authorization = mock_authorization
    return mock_context
|
||||
|
||||
|
||||
@pytest.fixture
def mock_httpx_client(mocker):
    """Fixture to mock the httpx.AsyncClient."""
    # Patch the class, then hand back the object that `async with httpx.AsyncClient()` yields.
    patched_client_cls = mocker.patch("httpx.AsyncClient", autospec=True)
    return patched_client_cls.return_value.__aenter__.return_value
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
from arcade_core import ToolCatalog
|
||||
from arcade_evals import (
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
ExpectedToolCall,
|
||||
SimilarityCritic,
|
||||
tool_eval,
|
||||
)
|
||||
|
||||
import arcade_linkedin
|
||||
from arcade_linkedin.tools.share import create_text_post
|
||||
|
||||
# Score thresholds shared by the evaluation suite below.
rubric = EvalRubric(fail_threshold=0.85, warn_threshold=0.95)

# Tool catalog populated from the arcade_linkedin package.
catalog = ToolCatalog()
catalog.add_module(arcade_linkedin)
|
||||
|
||||
|
||||
@tool_eval()
def linkedin_eval_suite() -> EvalSuite:
    """Build the LinkedIn evaluation suite, which holds one transcription-posting case."""
    # The request is a spoken transcription; the expected post has "hash tag Y2K" written as "#Y2K".
    spoken_request = "post this transcription to linkedin. there may be some things that you need to clean up since it was spoken.: 'It is with great pleasure that I announce that I am now a member of the LinkedIn community! I'd like to thank the LinkedIn team for their support and encouragement in my journey to success. hash tag Y2K'"
    cleaned_post = "It is with great pleasure that I announce that I am now a member of the LinkedIn community! I'd like to thank the LinkedIn team for their support and encouragement in my journey to success. #Y2K"

    suite = EvalSuite(
        name="LinkedIn Tools Evaluation",
        system_message="You are an AI assistant with access to LinkedIn tools. Use them to help the user with their tasks.",
        catalog=catalog,
        rubric=rubric,
    )

    expected_call = ExpectedToolCall(func=create_text_post, args={"text": cleaned_post})
    suite.add_case(
        name="Run code",
        user_message=spoken_request,
        expected_tool_calls=[expected_call],
        critics=[SimilarityCritic(critic_field="text", weight=1.0)],
    )

    return suite
|
||||
|
|
@ -1,59 +0,0 @@
|
|||
[build-system]
|
||||
requires = [ "hatchling",]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "arcade_linkedin"
|
||||
version = "0.3.0"
|
||||
description = "Arcade.dev LLM tools for LinkedIn"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"arcade-mcp-server>=1.17.0,<2.0.0",
|
||||
"httpx>=0.27.2,<1.0.0",
|
||||
]
|
||||
[[project.authors]]
|
||||
name = "Arcade"
|
||||
email = "dev@arcade.dev"
|
||||
|
||||
[project.scripts]
|
||||
arcade-linkedin = "arcade_linkedin.__main__:main"
|
||||
arcade_linkedin = "arcade_linkedin.__main__:main"
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"arcade-mcp[all]>=1.2.0,<2.0.0",
|
||||
"pytest>=8.3.0,<8.4.0",
|
||||
"pytest-cov>=4.0.0,<4.1.0",
|
||||
"pytest-asyncio>=0.24.0,<0.25.0",
|
||||
"pytest-mock>=3.11.1,<3.12.0",
|
||||
"mypy>=1.5.1,<1.6.0",
|
||||
"pre-commit>=3.4.0,<3.5.0",
|
||||
"tox>=4.11.1,<4.12.0",
|
||||
"ruff>=0.7.4,<0.8.0",
|
||||
]
|
||||
|
||||
# Use local path sources for arcade libs when working locally
|
||||
[tool.uv.sources]
|
||||
arcade-mcp = {path = "../../", editable = true}
|
||||
arcade-mcp-server = { path = "../../libs/arcade-mcp-server/", editable = true }
|
||||
|
||||
[tool.mypy]
|
||||
files = [ "arcade_linkedin/**/*.py",]
|
||||
python_version = "3.10"
|
||||
disallow_untyped_defs = "True"
|
||||
disallow_any_unimported = "True"
|
||||
no_implicit_optional = "True"
|
||||
check_untyped_defs = "True"
|
||||
warn_return_any = "True"
|
||||
warn_unused_ignores = "True"
|
||||
show_error_codes = "True"
|
||||
ignore_missing_imports = "True"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = [ "tests",]
|
||||
|
||||
[tool.coverage.report]
|
||||
skip_empty = true
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = [ "arcade_linkedin",]
|
||||
|
|
@ -1,35 +0,0 @@
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from arcade_mcp_server.exceptions import ToolExecutionError
|
||||
|
||||
from arcade_linkedin.tools.share import create_text_post
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_text_post_success(tool_context, mock_httpx_client):
    """Test successful creation of a LinkedIn text post."""
    # Fake a 201 response whose JSON body carries the new share id.
    created_response = MagicMock()
    created_response.status_code = 201
    created_response.json.return_value = {"id": "1234567890"}
    # request() is awaited by the code under test, so it must be an AsyncMock.
    mock_httpx_client.request = AsyncMock(return_value=created_response)

    post_url = await create_text_post(tool_context, "Hello, LinkedIn!")

    assert post_url == "https://www.linkedin.com/feed/update/1234567890/"
    mock_httpx_client.request.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_text_post_no_user_id(tool_context):
    """Test error when user ID is not found in the context."""
    # An empty user_info mapping means there is no "sub" entry to read.
    tool_context.authorization.user_info = {}

    with pytest.raises(ToolExecutionError, match="User ID not found"):
        await create_text_post(tool_context, "Hello, LinkedIn!")
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
files: ^.*/math/.*
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: "v4.4.0"
|
||||
hooks:
|
||||
- id: check-case-conflict
|
||||
- id: check-merge-conflict
|
||||
- id: check-toml
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.6.7
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
|
|
@ -1,47 +0,0 @@
|
|||
target-version = "py310"
|
||||
line-length = 100
|
||||
fix = true
|
||||
|
||||
[lint]
|
||||
select = [
|
||||
# flake8-2020
|
||||
"YTT",
|
||||
# flake8-bandit
|
||||
"S",
|
||||
# flake8-bugbear
|
||||
"B",
|
||||
# flake8-builtins
|
||||
"A",
|
||||
# flake8-comprehensions
|
||||
"C4",
|
||||
# flake8-debugger
|
||||
"T10",
|
||||
# flake8-simplify
|
||||
"SIM",
|
||||
# isort
|
||||
"I",
|
||||
# mccabe
|
||||
"C90",
|
||||
# pycodestyle
|
||||
"E", "W",
|
||||
# pyflakes
|
||||
"F",
|
||||
# pygrep-hooks
|
||||
"PGH",
|
||||
# pyupgrade
|
||||
"UP",
|
||||
# ruff
|
||||
"RUF",
|
||||
# tryceratops
|
||||
"TRY",
|
||||
]
|
||||
|
||||
[lint.per-file-ignores]
|
||||
"*" = ["TRY003", "B904"]
|
||||
"**/tests/*" = ["S101", "E501"]
|
||||
"**/evals/*" = ["S101", "E501"]
|
||||
|
||||
|
||||
[format]
|
||||
preview = true
|
||||
skip-magic-trailing-comma = false
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2025, Arcade AI
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
|
@ -1,55 +0,0 @@
|
|||
# Developer tasks for the Math toolkit. Targets whose rule line ends in a
# "## description" comment are listed by `make help`.

.PHONY: help

# Print the annotated targets, sorted by name, with their descriptions.
help:
	@echo "🛠️ math Commands:\n"
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'

.PHONY: install
install: ## Install the uv environment and install all packages with dependencies
	@echo "🚀 Creating virtual environment and installing all packages using uv"
	@uv sync --active --all-extras --no-sources
	@if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi
	@echo "✅ All packages and dependencies installed via uv"

# Same as `install`, but without --no-sources so [tool.uv.sources] entries apply.
.PHONY: install-local
install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources
	@echo "🚀 Creating virtual environment and installing all packages using uv"
	@uv sync --active --all-extras
	@if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi
	@echo "✅ All packages and dependencies installed via uv"

.PHONY: build
build: clean-build ## Build wheel file using uv
	@echo "🚀 Creating wheel file"
	uv build

.PHONY: clean-build
clean-build: ## clean build artifacts
	@echo "🗑️ Cleaning dist directory"
	rm -rf dist

.PHONY: test
test: ## Test the code with pytest
	@echo "🚀 Testing code: Running pytest"
	@uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml

.PHONY: coverage
coverage: ## Generate coverage report
	@echo "coverage report"
	@uv run --no-sources coverage report
	@echo "Generating coverage report"
	@uv run --no-sources coverage html

.PHONY: bump-version
bump-version: ## Bump the version in the pyproject.toml file by a patch version
	@echo "🚀 Bumping version in pyproject.toml"
	uv version --no-sources --bump patch

# pre-commit only runs when the toolkit ships a .pre-commit-config.yaml; mypy always runs.
.PHONY: check
check: ## Run code quality tools.
	@if [ -f .pre-commit-config.yaml ]; then\
		echo "🚀 Linting code: Running pre-commit";\
		uv run --no-sources pre-commit run -a;\
	fi
	@echo "🚀 Static type checking: Running mypy"
	@uv run --no-sources mypy --config-file=pyproject.toml
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
import sys
|
||||
from typing import cast
|
||||
|
||||
from arcade_mcp_server import MCPApp
|
||||
from arcade_mcp_server.mcp_app import TransportType
|
||||
|
||||
import arcade_math
|
||||
|
||||
# MCP application for the Math toolkit; every tool found in the
# arcade_math package is registered on it below.
app = MCPApp(
    name="Math",
    instructions=(
        "Use this server when you need to perform mathematical calculations "
        "to help users with arithmetic, trigonometry, statistics, exponents, "
        "rounding, and other math operations."
    ),
)
app.add_tools_from_module(arcade_math)
|
||||
|
||||
|
||||
def main() -> None:
    """Run the Math MCP app using optional positional CLI arguments.

    Arguments, in order: transport (default "stdio"), host (default
    "127.0.0.1") and port (default 8000). The port is converted with int().
    """
    cli_args = sys.argv[1:]
    supplied = len(cli_args)

    transport = cli_args[0] if supplied >= 1 else "stdio"
    host = cli_args[1] if supplied >= 2 else "127.0.0.1"
    port = int(cli_args[2]) if supplied >= 3 else 8000

    # cast() only informs the type checker; the transport string is passed through unchanged.
    app.run(transport=cast(TransportType, transport), host=host, port=port)


if __name__ == "__main__":
    main()
|
||||
|
|
@ -1,65 +0,0 @@
|
|||
from arcade_math.tools.arithmetic import (
|
||||
add,
|
||||
divide,
|
||||
mod,
|
||||
multiply,
|
||||
subtract,
|
||||
sum_list,
|
||||
sum_range,
|
||||
)
|
||||
from arcade_math.tools.exponents import (
|
||||
log,
|
||||
power,
|
||||
)
|
||||
from arcade_math.tools.miscellaneous import (
|
||||
abs_val,
|
||||
factorial,
|
||||
sqrt,
|
||||
)
|
||||
from arcade_math.tools.random import (
|
||||
generate_random_float,
|
||||
generate_random_int,
|
||||
)
|
||||
from arcade_math.tools.rational import (
|
||||
gcd,
|
||||
lcm,
|
||||
)
|
||||
from arcade_math.tools.rounding import (
|
||||
ceil,
|
||||
floor,
|
||||
round_num,
|
||||
)
|
||||
from arcade_math.tools.statistics import (
|
||||
avg,
|
||||
median,
|
||||
)
|
||||
from arcade_math.tools.trigonometry import (
|
||||
deg_to_rad,
|
||||
rad_to_deg,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"abs_val",
|
||||
"add",
|
||||
"avg",
|
||||
"ceil",
|
||||
"deg_to_rad",
|
||||
"divide",
|
||||
"factorial",
|
||||
"floor",
|
||||
"gcd",
|
||||
"generate_random_float",
|
||||
"generate_random_int",
|
||||
"lcm",
|
||||
"log",
|
||||
"median",
|
||||
"mod",
|
||||
"multiply",
|
||||
"power",
|
||||
"rad_to_deg",
|
||||
"round_num",
|
||||
"sqrt",
|
||||
"subtract",
|
||||
"sum_list",
|
||||
"sum_range",
|
||||
]
|
||||
|
|
@ -1,161 +0,0 @@
|
|||
import decimal
|
||||
from decimal import Decimal
|
||||
from typing import Annotated
|
||||
|
||||
from arcade_mcp_server import tool
|
||||
from arcade_mcp_server.metadata import Behavior, ToolMetadata
|
||||
|
||||
decimal.getcontext().prec = 100
|
||||
|
||||
|
||||
@tool(
    metadata=ToolMetadata(
        behavior=Behavior(
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=False,
        ),
    ),
)
def add(
    a: Annotated[str, "The first number as a string"],
    b: Annotated[str, "The second number as a string"],
) -> Annotated[str, "The sum of the two numbers as a string"]:
    """
    Add two numbers together
    """
    # Parse both operands as Decimal so the sum is computed in decimal arithmetic, not float.
    return str(Decimal(a) + Decimal(b))
|
||||
|
||||
|
||||
@tool(
    metadata=ToolMetadata(
        behavior=Behavior(
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=False,
        ),
    ),
)
def subtract(
    a: Annotated[str, "The first number as a string"],
    b: Annotated[str, "The second number as a string"],
) -> Annotated[str, "The difference of the two numbers as a string"]:
    """
    Subtract two numbers
    """
    # Parse both operands as Decimal so the difference is computed in decimal arithmetic.
    return str(Decimal(a) - Decimal(b))
|
||||
|
||||
|
||||
@tool(
    metadata=ToolMetadata(
        behavior=Behavior(
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=False,
        ),
    ),
)
def multiply(
    a: Annotated[str, "The first number as a string"],
    b: Annotated[str, "The second number as a string"],
) -> Annotated[str, "The product of the two numbers as a string"]:
    """
    Multiply two numbers together
    """
    # Parse both operands as Decimal so the product is computed in decimal arithmetic.
    return str(Decimal(a) * Decimal(b))
|
||||
|
||||
|
||||
@tool(
    metadata=ToolMetadata(
        behavior=Behavior(
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=False,
        ),
    ),
)
def divide(
    a: Annotated[str, "The first number as a string"],
    b: Annotated[str, "The second number as a string"],
) -> Annotated[str, "The quotient of the two numbers as a string"]:
    """
    Divide two numbers
    """
    # Decimal division; a zero divisor is not special-cased here, so whatever the
    # decimal module raises for it propagates to the caller.
    return str(Decimal(a) / Decimal(b))
|
||||
|
||||
|
||||
@tool(
    metadata=ToolMetadata(
        behavior=Behavior(
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=False,
        ),
    ),
)
def sum_list(
    numbers: Annotated[list[str], "The list of numbers as strings"],
) -> Annotated[str, "The sum of the numbers in the list as a string"]:
    """
    Sum all numbers in a list
    """
    # Convert each entry to Decimal and add them up; an empty list gives "0".
    total = sum(map(Decimal, numbers))
    return str(total)
|
||||
|
||||
|
||||
@tool(
    metadata=ToolMetadata(
        behavior=Behavior(
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=False,
        ),
    ),
)
def sum_range(
    start: Annotated[str, "The start of the range to sum as a string"],
    end: Annotated[str, "The end of the range to sum as a string"],
) -> Annotated[str, "The sum of the numbers in the list as a string"]:
    """
    Sum all numbers from start through end
    """
    # Both bounds are parsed as integers and the range is inclusive of `end`.
    first, last = int(start), int(end)

    # An empty range (end below start) sums to zero, matching sum(range(...)).
    if last < first:
        return "0"

    # Closed-form arithmetic series: avoids building and walking a list whose
    # length grows with the size of the range. The product is always even, so
    # the integer division is exact.
    term_count = last - first + 1
    return str((first + last) * term_count // 2)
|
||||
|
||||
|
||||
@tool(
    metadata=ToolMetadata(
        behavior=Behavior(
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=False,
        ),
    ),
)
def mod(
    a: Annotated[str, "The dividend as a string"],
    b: Annotated[str, "The divisor as a string"],
) -> Annotated[str, "The remainder after dividing a by b as a string"]:
    """
    Calculate the remainder (modulus) of one number divided by another
    """
    # Remainder computed with Decimal's % operator on the parsed operands.
    dividend = Decimal(a)
    divisor = Decimal(b)
    return str(dividend % divisor)
|
||||
|
|
@ -1,51 +0,0 @@
|
|||
import decimal
|
||||
import math
|
||||
from decimal import Decimal
|
||||
from typing import Annotated
|
||||
|
||||
from arcade_mcp_server import tool
|
||||
from arcade_mcp_server.metadata import Behavior, ToolMetadata
|
||||
|
||||
decimal.getcontext().prec = 100
|
||||
|
||||
|
||||
@tool(
    metadata=ToolMetadata(
        behavior=Behavior(
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=False,
        ),
    ),
)
def log(
    a: Annotated[str, "The number to take the logarithm of as a string"],
    base: Annotated[str, "The logarithmic base as a string"],
) -> Annotated[str, "The logarithm of the number with the specified base as a string"]:
    """
    Calculate the logarithm of a number with a given base
    """
    # The inputs are parsed as Decimal, but math.log works on floats, so the
    # result has float precision rather than Decimal precision.
    value = Decimal(a)
    log_base = Decimal(base)
    return str(math.log(value, log_base))
|
||||
|
||||
|
||||
@tool(
|
||||
metadata=ToolMetadata(
|
||||
behavior=Behavior(
|
||||
read_only=True,
|
||||
destructive=False,
|
||||
idempotent=True,
|
||||
open_world=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
def power(
    a: Annotated[str, "The base number as a string"],
    b: Annotated[str, "The exponent as a string"],
) -> Annotated[str, "The result of raising a to the power of b as a string"]:
    """
    Calculate one number raised to the power of another
    """
    # Exponentiation is done on Decimal operands, so the result is rounded
    # to the active decimal context rather than to float precision.
    base_value = Decimal(a)
    exponent = Decimal(b)
    return str(base_value**exponent)
|
||||
|
|
@ -1,70 +0,0 @@
|
|||
import decimal
|
||||
import math
|
||||
from decimal import Decimal
|
||||
from typing import Annotated
|
||||
|
||||
from arcade_mcp_server import tool
|
||||
from arcade_mcp_server.metadata import Behavior, ToolMetadata
|
||||
|
||||
decimal.getcontext().prec = 100
|
||||
|
||||
|
||||
@tool(
|
||||
metadata=ToolMetadata(
|
||||
behavior=Behavior(
|
||||
read_only=True,
|
||||
destructive=False,
|
||||
idempotent=True,
|
||||
open_world=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
def abs_val(
    a: Annotated[str, "The number as a string"],
) -> Annotated[str, "The absolute value of the number as a string"]:
    """
    Calculate the absolute value of a number
    """
    # Work on a Decimal so the digits of the input are preserved.
    magnitude = abs(Decimal(a))
    return str(magnitude)
|
||||
|
||||
|
||||
@tool(
|
||||
metadata=ToolMetadata(
|
||||
behavior=Behavior(
|
||||
read_only=True,
|
||||
destructive=False,
|
||||
idempotent=True,
|
||||
open_world=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
def factorial(
    a: Annotated[str, "The non-negative integer to compute the factorial for as a string"],
) -> Annotated[str, "The factorial of the number as a string"]:
    """
    Compute the factorial of a non-negative integer
    Returns "1" for "0"
    """
    # int() rejects decimal text such as "1.0"; math.factorial rejects
    # negative values.
    n = int(a)
    return str(math.factorial(n))
|
||||
|
||||
|
||||
@tool(
|
||||
metadata=ToolMetadata(
|
||||
behavior=Behavior(
|
||||
read_only=True,
|
||||
destructive=False,
|
||||
idempotent=True,
|
||||
open_world=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
def sqrt(
    a: Annotated[str, "The number to square root as a string"],
) -> Annotated[str, "The square root of the number as a string"]:
    """
    Get the square root of a number
    """
    # Decimal.sqrt() rounds to the active decimal context instead of going
    # through float.
    return str(Decimal(a).sqrt())
|
||||
|
|
@ -1,57 +0,0 @@
|
|||
import random
|
||||
from typing import Annotated
|
||||
|
||||
from arcade_mcp_server import tool
|
||||
from arcade_mcp_server.metadata import Behavior, ToolMetadata
|
||||
|
||||
|
||||
@tool(
|
||||
metadata=ToolMetadata(
|
||||
behavior=Behavior(
|
||||
read_only=True,
|
||||
destructive=False,
|
||||
idempotent=False,
|
||||
open_world=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
def generate_random_int(
|
||||
min_value: Annotated[str, "The minimum value of the random integer as a string"],
|
||||
max_value: Annotated[str, "The maximum value of the random integer as a string"],
|
||||
seed: Annotated[
|
||||
str | None,
|
||||
"The seed for the random number generator as a string."
|
||||
" If None, the current system time is used.",
|
||||
] = None,
|
||||
) -> Annotated[str, "A random integer between min_value and max_value as a string"]:
|
||||
"""Generate a random integer between min_value and max_value (inclusive)."""
|
||||
if seed is not None:
|
||||
random.seed(int(seed))
|
||||
|
||||
return str(random.randint(int(min_value), int(max_value))) # noqa: S311
|
||||
|
||||
|
||||
@tool(
|
||||
metadata=ToolMetadata(
|
||||
behavior=Behavior(
|
||||
read_only=True,
|
||||
destructive=False,
|
||||
idempotent=False,
|
||||
open_world=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
def generate_random_float(
|
||||
min_value: Annotated[str, "The minimum value of the random float as a string"],
|
||||
max_value: Annotated[str, "The maximum value of the random float as a string"],
|
||||
seed: Annotated[
|
||||
str | None,
|
||||
"The seed for the random number generator as a string."
|
||||
" If None, the current system time is used.",
|
||||
] = None,
|
||||
) -> Annotated[str, "A random float between min_value and max_value as a string"]:
|
||||
"""Generate a random float between min_value and max_value."""
|
||||
if seed is not None:
|
||||
random.seed(int(seed))
|
||||
|
||||
return str(random.uniform(float(min_value), float(max_value))) # noqa: S311
|
||||
|
|
@ -1,49 +0,0 @@
|
|||
import math
|
||||
from typing import Annotated
|
||||
|
||||
from arcade_mcp_server import tool
|
||||
from arcade_mcp_server.metadata import Behavior, ToolMetadata
|
||||
|
||||
|
||||
@tool(
|
||||
metadata=ToolMetadata(
|
||||
behavior=Behavior(
|
||||
read_only=True,
|
||||
destructive=False,
|
||||
idempotent=True,
|
||||
open_world=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
def gcd(
    a: Annotated[str, "First integer as a string"],
    b: Annotated[str, "Second integer as a string"],
) -> Annotated[str, "The greatest common divisor of a and b as a string"]:
    """
    Calculate the greatest common divisor (GCD) of two integers.
    """
    # int() rejects non-integer text such as "15.0".
    first, second = int(a), int(b)
    divisor = math.gcd(first, second)
    return str(divisor)
|
||||
|
||||
|
||||
@tool(
|
||||
metadata=ToolMetadata(
|
||||
behavior=Behavior(
|
||||
read_only=True,
|
||||
destructive=False,
|
||||
idempotent=True,
|
||||
open_world=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
def lcm(
    a: Annotated[str, "First integer as a string"],
    b: Annotated[str, "Second integer as a string"],
) -> Annotated[str, "The least common multiple of a and b as a string"]:
    """
    Calculate the least common multiple (LCM) of two integers.
    Returns "0" if either integer is 0.
    """
    # math.lcm (Python 3.9+) already returns a non-negative result and 0 when
    # either argument is 0, so no manual zero guard or abs() is needed.
    return str(math.lcm(int(a), int(b)))
|
||||
|
|
@ -1,75 +0,0 @@
|
|||
import decimal
|
||||
import math
|
||||
from decimal import Decimal
|
||||
from typing import Annotated
|
||||
|
||||
from arcade_mcp_server import tool
|
||||
from arcade_mcp_server.metadata import Behavior, ToolMetadata
|
||||
|
||||
decimal.getcontext().prec = 100
|
||||
|
||||
|
||||
@tool(
|
||||
metadata=ToolMetadata(
|
||||
behavior=Behavior(
|
||||
read_only=True,
|
||||
destructive=False,
|
||||
idempotent=True,
|
||||
open_world=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
def ceil(
    a: Annotated[str, "The number to round up as a string"],
) -> Annotated[str, "The smallest integer greater than or equal to the number as a string"]:
    """
    Return the ceiling of a number
    """
    # Parsing as Decimal keeps the input exact; math.ceil then yields an int.
    number = Decimal(a)
    return str(math.ceil(number))
|
||||
|
||||
|
||||
@tool(
|
||||
metadata=ToolMetadata(
|
||||
behavior=Behavior(
|
||||
read_only=True,
|
||||
destructive=False,
|
||||
idempotent=True,
|
||||
open_world=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
def floor(
    a: Annotated[str, "The number to round down as a string"],
) -> Annotated[str, "The largest integer less than or equal to the number as a string"]:
    """
    Return the floor of a number
    """
    # Parsing as Decimal keeps the input exact; math.floor then yields an int.
    number = Decimal(a)
    return str(math.floor(number))
|
||||
|
||||
|
||||
@tool(
|
||||
metadata=ToolMetadata(
|
||||
behavior=Behavior(
|
||||
read_only=True,
|
||||
destructive=False,
|
||||
idempotent=True,
|
||||
open_world=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
def round_num(
    value: Annotated[str, "The number to round as a string"],
    ndigits: Annotated[str, "The number of digits after the decimal point as a string"],
) -> Annotated[str, "The number rounded to the specified number of digits as a string"]:
    """
    Round a number to a specified number of positive digits
    """
    ndigits_int = int(ndigits)
    number = Decimal(value)
    if ndigits_int >= 0:
        # Decimal rounding keeps the requested number of fractional digits.
        return str(round(number, ndigits_int))
    # Negative ndigits rounds to tens, hundreds, ...: drop the fractional part
    # (int() truncates toward zero) and use integer rounding. Truncating the
    # Decimal rather than a float keeps values beyond float precision exact.
    return str(round(int(number), ndigits_int))
|
||||
|
|
@ -1,53 +0,0 @@
|
|||
import decimal
|
||||
from decimal import Decimal
|
||||
from statistics import median as stats_median
|
||||
from typing import Annotated
|
||||
|
||||
from arcade_mcp_server import tool
|
||||
from arcade_mcp_server.metadata import Behavior, ToolMetadata
|
||||
|
||||
decimal.getcontext().prec = 100
|
||||
|
||||
|
||||
@tool(
|
||||
metadata=ToolMetadata(
|
||||
behavior=Behavior(
|
||||
read_only=True,
|
||||
destructive=False,
|
||||
idempotent=True,
|
||||
open_world=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
def avg(
    numbers: Annotated[list[str], "The list of numbers as strings"],
) -> Annotated[str, "The average (mean) of the numbers in the list as a string"]:
    """
    Calculate the average (mean) of a list of numbers.
    Returns "0.0" if the list is empty.
    """
    # Convert up front so invalid entries fail before any arithmetic.
    values = [Decimal(item) for item in numbers]
    if not values:
        return "0.0"
    mean = sum(values) / len(values)
    return str(mean)
|
||||
|
||||
|
||||
@tool(
|
||||
metadata=ToolMetadata(
|
||||
behavior=Behavior(
|
||||
read_only=True,
|
||||
destructive=False,
|
||||
idempotent=True,
|
||||
open_world=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
def median(
    numbers: Annotated[list[str], "A list of numbers as strings"],
) -> Annotated[str, "The median value of the numbers in the list as a string"]:
    """
    Calculate the median of a list of numbers.
    Returns "0.0" if the list is empty.
    """
    # Convert up front so invalid entries fail before the median is taken.
    values = [Decimal(item) for item in numbers]
    if not values:
        return "0.0"
    return str(stats_median(values))
|
||||
|
|
@ -1,49 +0,0 @@
|
|||
import decimal
|
||||
import math
|
||||
from decimal import Decimal
|
||||
from typing import Annotated
|
||||
|
||||
from arcade_mcp_server import tool
|
||||
from arcade_mcp_server.metadata import Behavior, ToolMetadata
|
||||
|
||||
decimal.getcontext().prec = 100
|
||||
|
||||
|
||||
@tool(
|
||||
metadata=ToolMetadata(
|
||||
behavior=Behavior(
|
||||
read_only=True,
|
||||
destructive=False,
|
||||
idempotent=True,
|
||||
open_world=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
def deg_to_rad(
    degrees: Annotated[str, "Angle in degrees as a string"],
) -> Annotated[str, "Angle in radians as a string"]:
    """
    Convert an angle from degrees to radians.
    """
    # The input is parsed as Decimal, but math.radians computes in floating
    # point, so the returned string carries float precision.
    angle = Decimal(degrees)
    return str(math.radians(angle))
|
||||
|
||||
|
||||
@tool(
|
||||
metadata=ToolMetadata(
|
||||
behavior=Behavior(
|
||||
read_only=True,
|
||||
destructive=False,
|
||||
idempotent=True,
|
||||
open_world=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
def rad_to_deg(
    radians: Annotated[str, "Angle in radians as a string"],
) -> Annotated[str, "Angle in degrees as a string"]:
    """
    Convert an angle from radians to degrees.
    """
    # The input is parsed as Decimal, but math.degrees computes in floating
    # point, so the returned string carries float precision.
    angle = Decimal(radians)
    return str(math.degrees(angle))
|
||||
|
|
@ -1,137 +0,0 @@
|
|||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from arcade_core import ToolCatalog
|
||||
from arcade_evals import (
|
||||
BinaryCritic,
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
ExpectedToolCall,
|
||||
tool_eval,
|
||||
)
|
||||
|
||||
import arcade_math
|
||||
from arcade_math.tools.arithmetic import (
|
||||
add,
|
||||
divide,
|
||||
mod,
|
||||
multiply,
|
||||
subtract,
|
||||
sum_list,
|
||||
sum_range,
|
||||
)
|
||||
from arcade_math.tools.exponents import (
|
||||
log,
|
||||
power,
|
||||
)
|
||||
from arcade_math.tools.miscellaneous import (
|
||||
abs_val,
|
||||
factorial,
|
||||
sqrt,
|
||||
)
|
||||
from arcade_math.tools.rational import (
|
||||
gcd,
|
||||
lcm,
|
||||
)
|
||||
from arcade_math.tools.rounding import (
|
||||
ceil,
|
||||
floor,
|
||||
round_num,
|
||||
)
|
||||
from arcade_math.tools.statistics import (
|
||||
avg,
|
||||
median,
|
||||
)
|
||||
from arcade_math.tools.trigonometry import (
|
||||
deg_to_rad,
|
||||
rad_to_deg,
|
||||
)
|
||||
|
||||
# Type alias for test case tuples: (function, prompt_template, params)
TestCase = tuple[Callable[..., Any], str, dict[str, Any]]

# Evaluation rubric shared by the suite and by every case added to it.
rubric = EvalRubric(
    fail_threshold=0.85,
    warn_threshold=0.95,
)


# Catalog holding every tool exposed by the arcade_math package.
catalog = ToolCatalog()
catalog.add_module(arcade_math)
|
||||
|
||||
|
||||
@tool_eval()
def math_eval_suite() -> EvalSuite:
    """Build the evaluation suite with one case per math tool.

    Each case formats a prompt template with its parameters, uses the result
    as both the case name and the user message, and expects a single call to
    the matching tool with exactly those parameters.
    """
    suite = EvalSuite(
        name="Math Tools Evaluation",
        system_message="You're an AI assistant with access to math tools. Use them to help the user with their math-related tasks.",
        catalog=catalog,
        rubric=rubric,
    )

    list_param = ["1", "2", "3", "4", "5"]
    funcs_to_expression_and_params: list[TestCase] = [
        # unary
        (sqrt, "What's the square root of {a}?", {"a": "25"}),
        (abs_val, "What's the absolute value of {a}?", {"a": "-10"}),
        (factorial, "What's the factorial of {a}?", {"a": "5"}),
        (deg_to_rad, "Convert {degrees} from degrees to radians", {"degrees": "180"}),
        # Prompt typo fixed ("radias" -> "radians").
        (rad_to_deg, "Convert {radians} from radians to degrees", {"radians": "3.14"}),
        (ceil, "Compute the ceiling of {a}", {"a": "3.14"}),
        (floor, "Compute the floor of {a}", {"a": "3.14"}),
        # binary
        (add, "Add {a} and {b}", {"a": "12345", "b": "987654321"}),
        (subtract, "Subtract {b} from {a}", {"a": "987654321", "b": "12345"}),
        (multiply, "Multiply {a} and {b}", {"a": "12345", "b": "567890"}),
        (divide, "What is {a} divided by {b}?", {"a": "1234123479", "b": "123"}),
        (
            sum_range,
            "What's the sum of all numbers from {start} to {end}?",
            {"start": "10", "end": "345"},
        ),
        (mod, "What's the remainder of dividing {a} by {b}?", {"a": "234", "b": "17"}),
        (power, "Raise {a} to the power of {b}", {"a": "2", "b": "8"}),
        (log, "What's the logarithm of {a} with base {base}?", {"a": "8", "base": "2"}),
        (
            round_num,
            "Round {value} to {ndigits} decimal places",
            {"value": "12.23746234", "ndigits": "3"},
        ),
        (gcd, "Find the greatest common divisor of {a} and {b}", {"a": "50", "b": "10"}),
        # Prompt typo fixed ("FInd" -> "Find").
        (lcm, "Find the least common multiple of {a} and {b}", {"a": "7", "b": "13"}),
        # n-nary
        (
            sum_list,
            f"Calculate the sum of these numbers: {' '.join(list_param)}",
            {"numbers": list_param},
        ),
        (
            avg,
            f"Find the average of these numbers: {' '.join(list_param)}",
            {"numbers": list_param},
        ),
        (
            median,
            f"Find the median of these numbers: {' '.join(list_param)}",
            {"numbers": list_param},
        ),
    ]

    for func, expression, params in funcs_to_expression_and_params:
        parametrized_expression = expression.format(**params)
        num_params = len(params)
        suite.add_case(
            name=parametrized_expression,
            user_message=parametrized_expression,
            expected_tool_calls=[
                ExpectedToolCall(
                    func=func,
                    args=params,
                )
            ],
            rubric=rubric,
            # One critic per argument, weighted equally so the weights sum to 1.
            critics=[BinaryCritic(critic_field=param, weight=1.0 / num_params) for param in params],
        )

    return suite
|
||||
|
|
@ -1,58 +0,0 @@
|
|||
[build-system]
requires = [ "hatchling",]
build-backend = "hatchling.build"

[project]
name = "arcade_math"
version = "1.2.0"
description = "Arcade.dev LLM tools for doing math"
requires-python = ">=3.10"
dependencies = [
    "arcade-mcp-server>=1.17.0,<2.0.0",
]
[[project.authors]]
name = "Arcade"
email = "dev@arcade.dev"

[project.optional-dependencies]
dev = [
    "arcade-mcp[all]>=1.2.0,<2.0.0",
    "pytest>=8.3.0,<8.4.0",
    "pytest-cov>=4.0.0,<4.1.0",
    "pytest-asyncio>=0.24.0,<0.25.0",
    "pytest-mock>=3.11.1,<3.12.0",
    "mypy>=1.5.1,<1.6.0",
    "pre-commit>=3.4.0,<3.5.0",
    "tox>=4.11.1,<4.12.0",
    "ruff>=0.7.4,<0.8.0",
]

[project.scripts]
# Both spellings are registered and point at the same entry point.
arcade-math = "arcade_math.__main__:main"
arcade_math = "arcade_math.__main__:main"

# Use local path sources for arcade libs when working locally
[tool.uv.sources]
arcade-mcp = {path = "../../", editable = true}
arcade-mcp-server = { path = "../../libs/arcade-mcp-server/", editable = true }

[tool.mypy]
files = [ "arcade_math/**/*.py",]
python_version = "3.10"
# Native TOML booleans (lower case) instead of the string "True".
disallow_untyped_defs = true
disallow_any_unimported = true
no_implicit_optional = true
check_untyped_defs = true
warn_return_any = true
warn_unused_ignores = true
show_error_codes = true
ignore_missing_imports = true

[tool.pytest.ini_options]
testpaths = [ "tests",]

[tool.coverage.report]
skip_empty = true

[tool.hatch.build.targets.wheel]
packages = [ "arcade_math",]
|
||||
|
|
@ -1,147 +0,0 @@
|
|||
import pytest
|
||||
from arcade_mcp_server.exceptions import ToolExecutionError
|
||||
|
||||
from arcade_math.tools.arithmetic import (
|
||||
add,
|
||||
divide,
|
||||
mod,
|
||||
multiply,
|
||||
subtract,
|
||||
sum_list,
|
||||
sum_range,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "a, b, expected",
    [
        ("1", "2", "3"),
        ("-1", "1", "0"),
        ("0.5", "10.9", "11.4"),
        # Big ints
        ("12345678901234567890", "9876543210987654321", "22222222112222222211"),
        # Big floats
        (
            "12345678901234567890.120",
            "9876543210987654321.987",
            "22222222112222222212.107",
        ),
    ],
)
def test_add(a, b, expected):
    # add() takes and returns decimal strings; compare the exact text.
    assert add(a, b) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "a, b, expected",
    [
        ("1", "2", "-1"),
        ("-1", "1", "-2"),
        ("0.5", "10.9", "-10.4"),
        # Big ints
        ("12345678901234567890", "12323456679012345668", "22222222222222222"),
        # Big floats
        (
            "12345678901234567890.120",
            "12343557689113355768.9079",
            "2121212121212121.2121",
        ),
    ],
)
def test_subtract(a, b, expected):
    # subtract(a, b) is expected to return a - b as an exact decimal string.
    assert subtract(a, b) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "a, b, expected",
    [
        ("-1", "2", "-2"),
        # The sign of the zero product is preserved in the expected text.
        ("-10", "0", "-0"),
        ("0.5", "10.9", "5.45"),
        # Big ints
        (
            "12345678901234567890",
            "18000000162000001474380013420000",
            "222222222222222222222222222261233060226101083800000",
        ),
        # Big floats
        (
            "12345678901234567890.120",
            "12345678901234567890.120",
            "152415787532388367504868162811315348393.614400",
        ),
    ],
)
def test_multiply(a, b, expected):
    # multiply() takes and returns decimal strings; compare the exact text.
    assert multiply(a, b) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "a, b, expected",
    [
        ("-1", "2", "-0.5"),
        ("-10", "1", "-10"),
        # Non-terminating quotient: the expected text is the full rounded
        # expansion, split over two adjacent string literals.
        (
            "0.5",
            "10.9",
            "0.0458715596330275229357798165137614678899082568807339"
            "4495412844036697247706422018348623853211009174312",
        ),
        # Big ints
        ("152407406035740740602050", "12345678901234567890", "12345"),
        # Big floats
        (
            "152407406035740740603531.400",
            "12345678901234567890.120",
            "12345",
        ),
    ],
)
def test_divide(a, b, expected):
    # divide(a, b) is expected to return a / b as a decimal string.
    assert divide(a, b) == expected
|
||||
|
||||
|
||||
def test_zero_division():
    """Dividing by any spelling of zero must raise ToolExecutionError."""
    # Renamed from ``text_zero_division``: pytest only collects functions
    # whose names start with ``test``, so the old name was never run.
    for zero in ("0", "0.0", "0.000000"):
        with pytest.raises(ToolExecutionError):
            divide("1", zero)
|
||||
|
||||
|
||||
def test_sum_list():
    """sum_list totals string-encoded numbers; an empty list sums to "0"."""
    cases = [
        (["1", "2", "3", "4", "5", "6"], "21"),
        ([], "0"),
        (["-1", "-2", "-3", "-4", "-5", "-6"], "-21"),
        (["0.1", "0.2", "0.3", "0.3", "0.5", "0.7"], "2.1"),
    ]
    for numbers, expected in cases:
        assert sum_list(numbers) == expected
|
||||
|
||||
|
||||
def test_sum_range():
    """sum_range is inclusive, yields "0" for reversed bounds, rejects non-integers."""
    expected_sums = [
        ("8", "2", "0"),
        ("-8", "2", "-33"),
        ("8", "-2", "0"),
        ("2", "3", "5"),
        ("0", "10", "55"),
    ]
    for start, end, expected in expected_sums:
        assert sum_range(start, end) == expected

    # Non-integer bounds must be rejected. The original list repeated the
    # ("-1", "0.5") case; the duplicate is dropped.
    for start, end in [("2", "0.5"), ("-1", "0.5"), ("2.", "0.5")]:
        with pytest.raises(ToolExecutionError):
            sum_range(start, end)
|
||||
|
||||
|
||||
def test_mod():
    """mod returns the remainder as text, including the sign and scale of zero results."""
    cases = [
        ("-1", "0.5", "-0.0"),
        ("-8", "2", "-0"),
        ("0", "10", "0"),
        ("2", "0.5", "0.0"),
        ("2", "3", "2"),
        ("2.", "-0.5", "0.0"),
        ("2.1234", "0.6", "0.3234"),
        ("2.1234", "1", "0.1234"),
        ("2.1234", "3", "2.1234"),
        ("8", "-2", "0"),
        ("8", "2", "0"),
    ]
    for a, b, expected in cases:
        assert mod(a, b) == expected
|
||||
|
|
@ -1,41 +0,0 @@
|
|||
import pytest
|
||||
from arcade_mcp_server.exceptions import ToolExecutionError
|
||||
|
||||
from arcade_math.tools.exponents import (
|
||||
log,
|
||||
power,
|
||||
)
|
||||
|
||||
|
||||
def test_log():
    """log returns float-formatted text and rejects non-positive arguments."""
    cases = [
        ("8", "2", "3.0"),
        ("2", "3", "0.6309297535714574"),
        ("2", "0.5", "-1.0"),
    ]
    for a, base, expected in cases:
        assert log(a, base) == expected

    # Zero and negative arguments are outside the logarithm's domain.
    for a, base in [("-1", "0.5"), ("0", "10")]:
        with pytest.raises(ToolExecutionError):
            log(a, base)
|
||||
|
||||
|
||||
def test_power():
    # Integer exponents give exact results.
    assert power("-8", "2") == "64"
    assert power("0", "10") == "0"
    # Fractional exponents give long expansions; the expected text is split
    # over two adjacent string literals.
    assert (
        power("2", "0.5") == "1.41421356237309504880168872420969807856"
        "9671875376948073176679737990732478462107038850387534327641573"
    )
    assert power("2", "3") == "8"
    assert (
        power("2.", "-0.5") == "0.707106781186547524400844362104849039"
        "2848359376884740365883398689953662392310535194251937671638207864"
    )
    assert (
        power("2.1234", "0.6") == "1.571155202490495156807227174573016145"
        "282682479346448636509576776014844055570115193494685328114403375"
    )
    assert power("2.1234", "1") == "2.1234"
    assert power("2.1234", "3") == "9.574044440904"
    assert power("8", "-2") == "0.015625"
    assert power("8", "2") == "64"
    # A negative base with a fractional exponent is rejected.
    with pytest.raises(ToolExecutionError):
        power("-1", "0.5")
|
||||
|
|
@ -1,81 +0,0 @@
|
|||
import pytest
|
||||
from arcade_mcp_server.exceptions import ToolExecutionError
|
||||
|
||||
from arcade_math.tools.miscellaneous import (
|
||||
abs_val,
|
||||
factorial,
|
||||
sqrt,
|
||||
)
|
||||
|
||||
|
||||
def test_abs_val():
    """abs_val strips the sign and keeps the digits unchanged."""
    cases = [
        ("2", "2"),
        ("-1", "1"),
        ("-1.12341234", "1.12341234"),
    ]
    for value, expected in cases:
        assert abs_val(value) == expected
|
||||
|
||||
|
||||
def test_factorial():
    """factorial accepts non-negative integer text only."""
    cases = [
        ("1", "1"),
        ("0", "1"),
        ("-0", "1"),
        ("23", "25852016738884976640000"),
        ("24", "620448401733239439360000"),
        ("10", "3628800"),
    ]
    for value, expected in cases:
        assert factorial(value) == expected

    # Negative integers and any decimal-point spelling are rejected.
    rejected = ["-1", "-10", "0.0000", "-0.0", "1.0", "-1.0", "23.0"]
    for value in rejected:
        with pytest.raises(ToolExecutionError):
            factorial(value)
|
||||
|
||||
|
||||
def test_sqrt():
    # Exact roots, including signed zero.
    assert sqrt("1") == "1"
    assert sqrt("0") == "0"
    assert sqrt("-0") == "-0"
    # Irrational roots: the expected text is the full rounded expansion,
    # split over two adjacent string literals.
    assert (
        sqrt("23") == "4.79583152331271954159743806416269391999670704190"
        "4129346485309114448257235907464082492191446436918861"
    )
    assert (
        sqrt("24") == "4.89897948556635619639456814941178278393189496131"
        "3340256865385134501920754914630053079718866209280470"
    )
    assert (
        sqrt("10") == "3.16227766016837933199889354443271853371955513932"
        "5216826857504852792594438639238221344248108379300295"
    )
    # Exact roots of decimal inputs: the result keeps half of the input's
    # fractional digits.
    assert sqrt("0.0") == "0.0"
    assert sqrt("0.0000") == "0.00"
    assert sqrt("-0.0") == "-0.0"
    assert sqrt("1.0") == "1.0"
    assert (
        sqrt("3.14") == "1.772004514666935040199112509753631525073608516"
        "162942966817771970290992972348902551472561151153909188"
    )
    assert (
        sqrt("0.4") == "0.6324555320336758663997787088865437067439110278"
        "650433653715009705585188877278476442688496216758600590"
    )
    assert (
        sqrt("10.0") == "3.162277660168379331998893544432718533719555139"
        "325216826857504852792594438639238221344248108379300295"
    )
    # Negative inputs are rejected.
    with pytest.raises(ToolExecutionError):
        sqrt("-1")
    with pytest.raises(ToolExecutionError):
        sqrt("-10")
    with pytest.raises(ToolExecutionError):
        sqrt("-1.0")
    with pytest.raises(ToolExecutionError):
        sqrt("-1.3")
    with pytest.raises(ToolExecutionError):
        sqrt("-10.0")
|
||||
|
|
@ -1,31 +0,0 @@
|
|||
import pytest
|
||||
from arcade_mcp_server.exceptions import ToolExecutionError
|
||||
|
||||
from arcade_math.tools.rational import (
|
||||
gcd,
|
||||
lcm,
|
||||
)
|
||||
|
||||
|
||||
def test_gcd():
    """gcd is non-negative, treats zero as neutral, and rejects decimal text."""
    cases = [
        ("-15", "-5", "5"),
        ("15", "0", "15"),
        ("15", "-2", "1"),
        ("15", "-0", "15"),
        ("15", "5", "5"),
        ("7", "13", "1"),
        ("-13", "13", "13"),
    ]
    for a, b, expected in cases:
        assert gcd(a, b) == expected

    with pytest.raises(ToolExecutionError):
        gcd("15.0", "5.0")
|
||||
|
||||
|
||||
def test_lcm():
    """lcm is non-negative, is "0" when either input is zero, and rejects decimal text."""
    cases = [
        ("-15", "-5", "15"),
        ("15", "0", "0"),
        ("15", "-2", "30"),
        ("15", "-0", "0"),
        ("15", "5", "15"),
        ("7", "13", "91"),
        ("-13", "13", "13"),
    ]
    for a, b, expected in cases:
        assert lcm(a, b) == expected

    with pytest.raises(ToolExecutionError):
        lcm("15.0", "5.0")
|
||||
|
|
@ -1,54 +0,0 @@
|
|||
from arcade_math.tools.rounding import (
|
||||
ceil,
|
||||
floor,
|
||||
round_num,
|
||||
)
|
||||
|
||||
|
||||
def test_ceil():
    """ceil rounds up to integer text; every spelling of zero becomes "0"."""
    cases = [
        ("1", "1"),
        ("-1", "-1"),
        ("0", "0"),
        ("-0", "0"),
        ("0.0", "0"),
        ("0.0000", "0"),
        ("-0.0", "0"),
        ("1.0", "1"),
        ("-1.0", "-1"),
        ("3.14", "4"),
        ("0.4", "1"),
        ("-1.3", "-1"),
    ]
    for value, expected in cases:
        assert ceil(value) == expected
|
||||
|
||||
|
||||
def test_floor():
    """floor rounds down to integer text; every spelling of zero becomes "0"."""
    cases = [
        ("1", "1"),
        ("-1", "-1"),
        ("0", "0"),
        ("-0", "0"),
        ("10", "10"),
        ("0.0", "0"),
        ("0.0000", "0"),
        ("-0.0", "0"),
        ("1.0", "1"),
        ("-1.0", "-1"),
        ("3.14", "3"),
        ("0.4", "0"),
        ("-1.3", "-2"),
    ]
    for value, expected in cases:
        assert floor(value) == expected
|
||||
|
||||
|
||||
def test_round_num():
    """round_num handles negative, zero and positive digit counts."""
    # TODO(mateo): ok with scientific notation? ok with negative round digits?
    cases = [
        ("1.2345", "-2", "0"),
        ("1.2345", "-1", "0"),
        ("1.2345", "0", "1"),
        ("1.2345", "1", "1.2"),
        ("1.2345", "2", "1.23"),
        ("1.2345", "3", "1.234"),
        ("1.2345", "8", "1.23450000"),
        ("1.654321", "-2", "0"),
        ("1.654321", "-1", "0"),
        ("1.654321", "0", "2"),
        ("1.654321", "1", "1.7"),
        ("1.654321", "2", "1.65"),
        ("1.654321", "3", "1.654"),
        ("1.654321", "8", "1.65432100"),
    ]
    for value, ndigits, expected in cases:
        assert round_num(value, ndigits) == expected
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
from arcade_math.tools.statistics import (
|
||||
avg,
|
||||
median,
|
||||
)
|
||||
|
||||
|
||||
def test_avg():
    """avg returns the mean as text and "0.0" for an empty list."""
    cases = [
        (["1", "2", "3", "4", "5", "6"], "3.5"),
        ([], "0.0"),
        (["-1", "-2", "-3", "-4", "-5", "-6"], "-3.5"),
        (["0.1", "0.2", "0.3", "0.3", "0.5", "0.7"], "0.35"),
    ]
    for numbers, expected in cases:
        assert avg(numbers) == expected
|
||||
|
||||
|
||||
def test_median():
    """median returns the middle value as text and "0.0" for an empty list."""
    cases = [
        (["1", "2", "3", "4", "5", "6"], "3.5"),
        ([], "0.0"),
        (["-1", "-2", "-3", "-4", "-5", "-6"], "-3.5"),
        (["0.1", "0.2", "0.3", "0.3", "0.5", "0.7"], "0.3"),
    ]
    for numbers, expected in cases:
        assert median(numbers) == expected
|
||||
|
|
@ -1,45 +0,0 @@
|
|||
from arcade_math.tools.trigonometry import (
|
||||
deg_to_rad,
|
||||
rad_to_deg,
|
||||
)
|
||||
|
||||
|
||||
def test_deg_to_rad():
    """deg_to_rad returns float-formatted radians, preserving the sign of zero."""
    cases = [
        ("1", "0.017453292519943295"),
        ("-1", "-0.017453292519943295"),
        ("0", "0.0"),
        ("-0", "-0.0"),
        ("23", "0.4014257279586958"),
        ("24", "0.4188790204786391"),
        ("-10", "-0.17453292519943295"),
        ("10", "0.17453292519943295"),
        ("180", "3.141592653589793"),
        ("0.0", "0.0"),
        ("0.0000", "0.0"),
        ("-0.0", "-0.0"),
        ("1.0", "0.017453292519943295"),
        ("-1.0", "-0.017453292519943295"),
        ("23.0", "0.4014257279586958"),
        ("0.4", "0.006981317007977318"),
        ("-10.0", "-0.17453292519943295"),
        ("10.0", "0.17453292519943295"),
    ]
    for degrees, expected in cases:
        assert deg_to_rad(degrees) == expected
|
||||
|
||||
|
||||
def test_rad_to_deg():
    """rad_to_deg returns float-formatted degrees, preserving the sign of zero."""
    cases = [
        ("1", "57.29577951308232"),
        ("-1", "-57.29577951308232"),
        ("0", "0.0"),
        ("-0", "-0.0"),
        ("23", "1317.8029288008934"),
        ("24", "1375.0987083139757"),
        ("-10", "-572.9577951308232"),
        ("10", "572.9577951308232"),
        ("0.0", "0.0"),
        ("0.0000", "0.0"),
        ("-0.0", "-0.0"),
        ("1.0", "57.29577951308232"),
        ("-1.0", "-57.29577951308232"),
        ("3.14", "179.9087476710785"),
        ("0.4", "22.918311805232932"),
        ("-10.0", "-572.9577951308232"),
        ("10.0", "572.9577951308232"),
    ]
    for radians, expected in cases:
        assert rad_to_deg(radians) == expected
|
||||
|
|
@ -1,53 +0,0 @@
|
|||
.PHONY: help

# Print every target that carries a "## description" suffix, sorted by name.
# NOTE(review): the banner says "github" - confirm it matches the toolkit
# this Makefile ships with.
help:
	@echo "🛠️ github Commands:\n"
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'

.PHONY: install
install: ## Install the uv environment and install all packages with dependencies
	@echo "🚀 Creating virtual environment and installing all packages using uv"
	@uv sync --active --all-extras --no-sources
	@uv run pre-commit install
	@echo "✅ All packages and dependencies installed via uv"

# Same as install, but without --no-sources.
.PHONY: install-local
install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources
	@echo "🚀 Creating virtual environment and installing all packages using uv"
	@uv sync --active --all-extras
	@uv run pre-commit install
	@echo "✅ All packages and dependencies installed via uv"

# Help text corrected: the recipe runs uv, not poetry.
.PHONY: build
build: clean-build ## Build wheel file using uv
	@echo "🚀 Creating wheel file"
	uv build

.PHONY: clean-build
clean-build: ## clean build artifacts
	@echo "🗑️ Cleaning dist directory"
	rm -rf dist

.PHONY: test
test: ## Test the code with pytest
	@echo "🚀 Testing code: Running pytest"
	@uv run pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml

.PHONY: coverage
coverage: ## Generate coverage report
	@echo "coverage report"
	coverage report
	@echo "Generating coverage report"
	coverage html

.PHONY: bump-version
bump-version: ## Bump the version in the pyproject.toml file by a patch version
	@echo "🚀 Bumping version in pyproject.toml"
	uv version --bump patch

.PHONY: check
check: ## Run code quality tools.
	@echo "🚀 Linting code: Running pre-commit"
	@uv run pre-commit run -a
	@echo "🚀 Static type checking: Running mypy"
	@uv run mypy --config-file=pyproject.toml
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
import sys
|
||||
from typing import cast
|
||||
|
||||
from arcade_mcp_server import MCPApp
|
||||
from arcade_mcp_server.mcp_app import TransportType
|
||||
|
||||
import arcade_mongodb
|
||||
|
||||
# Module-level MCP application object; every tool exported by the
# arcade_mongodb package is registered on it at import time.
app = MCPApp(
    name="MongoDB",
    instructions=(
        "Use this server when you need to interact with MongoDB to help users "
        "query, explore, and manage their MongoDB databases and collections."
    ),
)
app.add_tools_from_module(arcade_mongodb)


def main() -> None:
    """Run the server using optional positional CLI arguments.

    Arguments, in order: transport (default "stdio"), host (default
    "127.0.0.1") and port (default 8000, converted with int()).
    """
    cli_args = sys.argv[1:]
    transport = cli_args[0] if cli_args else "stdio"
    host = cli_args[1] if len(cli_args) > 1 else "127.0.0.1"
    port = int(cli_args[2]) if len(cli_args) > 2 else 8000

    # cast() only informs the type checker; the string is passed through as-is.
    app.run(transport=cast(TransportType, transport), host=host, port=port)


if __name__ == "__main__":
    main()
|
||||
|
|
@ -1,118 +0,0 @@
|
|||
from typing import Any, ClassVar
|
||||
|
||||
from arcade_mcp_server.exceptions import RetryableToolError
|
||||
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
|
||||
from pymongo.errors import ServerSelectionTimeoutError
|
||||
|
||||
MAX_RECORDS_RETURNED = 1000
|
||||
TEST_QUERY = {"ping": 1}
|
||||
|
||||
|
||||
class DatabaseEngine:
    """Process-wide cache of Motor clients keyed by connection string.

    All members are class-level; the class is used as a namespace and is not
    instantiated by the code in this module.
    """

    _instance: ClassVar[None] = None
    # connection string -> cached client
    _clients: ClassVar[dict[str, AsyncIOMotorClient]] = {}

    @classmethod
    async def get_instance(cls, connection_string: str) -> AsyncIOMotorClient:
        """Return a cached client for ``connection_string`` after pinging it.

        The client is created on first use. If the ping raises
        ServerSelectionTimeoutError the client is closed, recreated and pinged
        once more.

        Raises:
            RetryableToolError: If the second ping fails for any reason.
        """
        if connection_string not in cls._clients:
            cls._clients[connection_string] = AsyncIOMotorClient(connection_string)

        # First attempt: ping the admin database through the cached client.
        try:
            await cls._clients[connection_string].admin.command(TEST_QUERY)
        except ServerSelectionTimeoutError:
            # Drop the stale client and build a fresh one before retrying.
            cls._clients[connection_string].close()
            cls._clients[connection_string] = AsyncIOMotorClient(connection_string)
        else:
            return cls._clients[connection_string]

        # Second and final attempt with the replacement client.
        try:
            await cls._clients[connection_string].admin.command(TEST_QUERY)
        except Exception as e:
            raise RetryableToolError(
                f"Connection failed: {e}",
                developer_message="Connection to MongoDB failed.",
                additional_prompt_content="Check the connection string and try again.",
            ) from e
        return cls._clients[connection_string]

    @classmethod
    async def get_database(cls, connection_string: str, database_name: str) -> Any:
        """Return an async context manager that yields the named database.

        Leaving the context does not close anything; the client stays cached.
        """
        client = await cls.get_instance(connection_string)

        class DatabaseContextManager:
            def __init__(self, client: AsyncIOMotorClient, database_name: str) -> None:
                self.client = client
                self.database_name = database_name
                self.database = client[database_name]

            async def __aenter__(self) -> AsyncIOMotorDatabase:
                return self.database

            async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
                # Nothing to release here: clients are owned by the class-level cache.
                return None

        return DatabaseContextManager(client, database_name)

    @classmethod
    async def cleanup(cls) -> None:
        """Close every cached client, then empty the cache."""
        for cached_client in cls._clients.values():
            cached_client.close()
        cls._clients.clear()

    @classmethod
    def clear_cache(cls) -> None:
        """Empty the cache WITHOUT closing the clients it held."""
        cls._clients.clear()

    @classmethod
    def sanitize_query_params(
        cls,
        database_name: str,
        collection_name: str,
        filter_dict: dict[str, Any] | None,
        projection: dict[str, Any] | None,
        sort: list[dict[str, Any]] | None,
        limit: int,
        skip: int,
    ) -> tuple[
        str, str, dict[str, Any], dict[str, Any] | None, list[dict[str, Any]] | None, int, int
    ]:
        """Validate query parameters and return them as a tuple.

        A ``None`` filter becomes ``{}``; every other value is returned as given.

        Raises:
            RetryableToolError: For an empty database or collection name, a
                limit above MAX_RECORDS_RETURNED, a negative skip, or a
                non-positive limit (checked in that order).
        """
        if not database_name:
            raise RetryableToolError(
                "Database name is required.",
                developer_message="Database name cannot be empty.",
            )
        if not collection_name:
            raise RetryableToolError(
                "Collection name is required.",
                developer_message="Collection name cannot be empty.",
            )

        effective_filter = {} if filter_dict is None else filter_dict

        if limit > MAX_RECORDS_RETURNED:
            raise RetryableToolError(
                f"Limit is too high. Maximum is {MAX_RECORDS_RETURNED}.",
            )
        if skip < 0:
            raise RetryableToolError(
                "Skip must be greater than or equal to 0.",
                developer_message="Skip must be greater than or equal to 0.",
            )
        if limit <= 0:
            raise RetryableToolError(
                "Limit must be greater than 0.",
                developer_message="Limit must be greater than 0.",
            )

        return (
            database_name,
            collection_name,
            effective_filter,
            projection,
            sort,
            limit,
            skip,
        )
|
||||
|
|
@ -1,434 +0,0 @@
|
|||
import json
|
||||
from typing import Annotated, Any
|
||||
|
||||
from arcade_mcp_server import Context, tool
|
||||
from arcade_mcp_server.exceptions import RetryableToolError
|
||||
from arcade_mcp_server.metadata import Behavior, Operation, ToolMetadata
|
||||
|
||||
from ..database_engine import MAX_RECORDS_RETURNED, DatabaseEngine
|
||||
from .utils import (
|
||||
_infer_schema_from_docs,
|
||||
_parse_json_list_parameter,
|
||||
_parse_json_parameter,
|
||||
_serialize_document,
|
||||
)
|
||||
|
||||
# class UserStatus(str, Enum):
|
||||
# """User status enumeration."""
|
||||
|
||||
# ACTIVE = "active"
|
||||
# INACTIVE = "inactive"
|
||||
# BANNED = "banned"
|
||||
|
||||
|
||||
@tool(
    requires_secrets=["MONGODB_CONNECTION_STRING"],
    metadata=ToolMetadata(
        behavior=Behavior(
            operations=[Operation.READ],
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=True,
        ),
    ),
)
async def discover_databases(
    context: Context,
) -> list[str]:
    """Discover all the databases in the MongoDB instance."""
    connection_string = context.get_secret("MONGODB_CONNECTION_STRING")
    client = await DatabaseEngine.get_instance(connection_string)

    # The built-in admin, config and local databases are hidden from the result.
    hidden = ["admin", "config", "local"]
    return [name for name in await client.list_database_names() if name not in hidden]
|
||||
|
||||
|
||||
@tool(
    requires_secrets=["MONGODB_CONNECTION_STRING"],
    metadata=ToolMetadata(
        behavior=Behavior(
            operations=[Operation.READ],
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=True,
        ),
    ),
)
async def discover_collections(
    context: Context,
    database_name: Annotated[str, "The database name to discover collections in"],
) -> list[str]:
    """Discover all the collections in the MongoDB database when the list of collections is not known.

    ALWAYS use this tool before any other tool that requires a collection name.
    """
    connection_string = context.get_secret("MONGODB_CONNECTION_STRING")
    async with await DatabaseEngine.get_database(connection_string, database_name) as db:
        # Copy into a plain list before returning.
        return list(await db.list_collection_names())
|
||||
|
||||
|
||||
@tool(
    requires_secrets=["MONGODB_CONNECTION_STRING"],
    metadata=ToolMetadata(
        behavior=Behavior(
            operations=[Operation.READ],
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=True,
        ),
    ),
)
async def get_collection_schema(
    context: Context,
    database_name: Annotated[str, "The database name to get the collection schema of"],
    collection_name: Annotated[str, "The collection to get the schema of"],
    sample_size: Annotated[
        int,
        f"The number of documents to sample for schema discovery (default: {MAX_RECORDS_RETURNED})",
    ] = MAX_RECORDS_RETURNED,
) -> dict[str, Any]:
    """
    Get the schema/structure of a MongoDB collection by sampling documents.

    Since MongoDB is schema-less, this tool samples a configurable number of documents
    to infer the schema structure and data types.

    This tool should ALWAYS be used before executing any query. All collections in the query must be discovered first using the <discover_collections> tool.
    """
    connection_string = context.get_secret("MONGODB_CONNECTION_STRING")
    async with await DatabaseEngine.get_database(connection_string, database_name) as db:
        # A $sample aggregation stage picks the documents to inspect.
        sampling_pipeline = [{"$sample": {"size": sample_size}}]
        sample_docs = [doc async for doc in db[collection_name].aggregate(sampling_pipeline)]

        if not sample_docs:
            return {"message": "Collection is empty", "schema": {}}

        return {
            "total_documents_sampled": len(sample_docs),
            "sample_size_requested": sample_size,
            "schema": _infer_schema_from_docs(sample_docs),
        }
|
||||
|
||||
|
||||
@tool(
    requires_secrets=["MONGODB_CONNECTION_STRING"],
    metadata=ToolMetadata(
        behavior=Behavior(
            operations=[Operation.READ],
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=True,
        ),
    ),
)
async def find_documents(
    context: Context,
    database_name: Annotated[str, "The database name to query"],
    collection_name: Annotated[str, "The collection name to query"],
    filter_dict: Annotated[
        str | None,
        'MongoDB filter/query as JSON string. Leave None for no filter (find all documents). Example: \'{"status": "active", "age": {"$gte": 18}}\'',
    ] = None,
    projection: Annotated[
        str | None,
        'Fields to include/exclude as JSON string. Use 1 to include, 0 to exclude. Example: \'{"name": 1, "email": 1, "_id": 0}\'. Leave None to include all fields.',
    ] = None,
    sort: Annotated[
        list[str] | None,
        'Sort criteria as list of JSON strings, each containing \'field\' and \'direction\' keys. Use 1 for ascending, -1 for descending. Example: [\'{"field": "name", "direction": 1}\', \'{"field": "created_at", "direction": -1}\']',
    ] = None,
    limit: Annotated[
        int,
        f"The maximum number of documents to return. Default: {MAX_RECORDS_RETURNED}.",
    ] = MAX_RECORDS_RETURNED,
    skip: Annotated[int, "The number of documents to skip. Default: 0."] = 0,
) -> list[str]:
    """
    Find documents in a MongoDB collection.

    ONLY use this tool if you have already loaded the schema of the collection you need to query.
    Use the <get_collection_schema> tool to load the schema if not already known.

    Returns a list of JSON strings, where each string represents a document from the collection (tools cannot return complex types).

    When running queries, follow these rules which will help avoid errors:
    * Always specify projection to limit fields returned if you don't need all data.
    * Always sort your results by the most important fields first. If you aren't sure, sort by '_id'.
    * Use appropriate MongoDB query operators for complex filtering ($gte, $lte, $in, $regex, etc.).
    * Be mindful of case sensitivity when querying string fields.
    * Use indexes when possible (typically on _id and commonly queried fields).
    """
    # Bound up front so the generic error handler can always format them.
    parsed_filter = parsed_projection = parsed_sort = None

    try:
        # Decode the JSON-string arguments (the helpers also reject write operators).
        parsed_filter = _parse_json_parameter(filter_dict, "filter_dict")
        parsed_projection = _parse_json_parameter(projection, "projection")
        parsed_sort = _parse_json_list_parameter(sort, "sort")

        sanitized = DatabaseEngine.sanitize_query_params(
            database_name=database_name,
            collection_name=collection_name,
            filter_dict=parsed_filter,
            projection=parsed_projection,
            sort=parsed_sort,
            limit=limit,
            skip=skip,
        )
        (
            database_name,
            collection_name,
            parsed_filter,
            parsed_projection,
            parsed_sort,
            limit,
            skip,
        ) = sanitized

        connection_string = context.get_secret("MONGODB_CONNECTION_STRING")
        async with await DatabaseEngine.get_database(connection_string, database_name) as db:
            cursor = db[collection_name].find(parsed_filter, parsed_projection)

            if parsed_sort:
                # cursor.sort() is given (field, direction) tuples, not dicts.
                cursor = cursor.sort(
                    [(str(item["field"]), int(item["direction"])) for item in parsed_sort]
                )

            cursor = cursor.skip(skip).limit(limit)

            # Each document is made JSON-safe and then encoded to a string.
            return [json.dumps(_serialize_document(doc)) async for doc in cursor]

    except RetryableToolError:
        # Already a tool-facing error (e.g. JSON validation); pass it through untouched.
        raise
    except Exception as e:
        raise RetryableToolError(
            f"Query failed: {e}",
            developer_message=f"Query failed with parameters: database_name={database_name}, collection_name={collection_name}, filter_dict={parsed_filter}, projection={parsed_projection}, sort={parsed_sort}, limit={limit}, skip={skip}.",
            additional_prompt_content="Load the collection schema <get_collection_schema> or use the <discover_collections> tool to discover the collections and try again.",
            retry_after_ms=10,
        ) from e
|
||||
|
||||
|
||||
@tool(
    requires_secrets=["MONGODB_CONNECTION_STRING"],
    metadata=ToolMetadata(
        behavior=Behavior(
            operations=[Operation.READ],
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=True,
        ),
    ),
)
async def count_documents(
    context: Context,
    database_name: Annotated[str, "The database name to query"],
    collection_name: Annotated[str, "The collection name to query"],
    filter_dict: Annotated[
        str | None,
        'MongoDB filter/query as JSON string. Leave None for no filter (count all documents). Example: \'{"status": "active"}\'',
    ] = None,
) -> int:
    """Count documents in a MongoDB collection matching the given filter."""
    # Bound up front so the generic error handler can always format it.
    parsed_filter = None

    try:
        # A missing (or falsy) decoded filter falls back to {}.
        parsed_filter = _parse_json_parameter(filter_dict, "filter_dict") or {}

        connection_string = context.get_secret("MONGODB_CONNECTION_STRING")
        async with await DatabaseEngine.get_database(connection_string, database_name) as db:
            matching = await db[collection_name].count_documents(parsed_filter)
            return int(matching)

    except RetryableToolError:
        # Already a tool-facing error (e.g. JSON validation); pass it through untouched.
        raise
    except Exception as e:
        raise RetryableToolError(
            f"Count query failed: {e}",
            developer_message=f"Count query failed with parameters: database_name={database_name}, collection_name={collection_name}, filter_dict={parsed_filter}.",
            additional_prompt_content="Check the collection name and filter criteria and try again.",
            retry_after_ms=10,
        ) from e
|
||||
|
||||
|
||||
@tool(
    requires_secrets=["MONGODB_CONNECTION_STRING"],
    metadata=ToolMetadata(
        behavior=Behavior(
            operations=[Operation.READ],
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=True,
        ),
    ),
)
async def aggregate_documents(
    context: Context,
    database_name: Annotated[str, "The database name to query"],
    collection_name: Annotated[str, "The collection name to query"],
    pipeline: Annotated[
        list[str],
        'MongoDB aggregation pipeline as a list of JSON strings, each representing a stage. Example: [\'{"$match": {"status": "active"}}\', \'{"$group": {"_id": "$category", "count": {"$sum": 1}}}\']',
    ],
    limit: Annotated[
        int,
        f"The maximum number of results to return from the aggregation. Default: {MAX_RECORDS_RETURNED}.",
    ] = MAX_RECORDS_RETURNED,
) -> list[str]:
    """
    Execute a MongoDB aggregation pipeline on a collection.

    ONLY use this tool if you have already loaded the schema of the collection you need to query.
    Use the <get_collection_schema> tool to load the schema if not already known.

    Returns a list of JSON strings, where each string represents a result document from the aggregation (tools cannot return complex types).

    Aggregation pipelines allow for complex data processing including:
    * $match - filter documents
    * $group - group documents and perform calculations
    * $project - reshape documents
    * $sort - sort documents
    * $limit - limit results
    * $lookup - join with other collections
    * And many more stages
    """
    # Bound up front so the generic error handler can always format it.
    parsed_pipeline = None

    try:
        # Decode each stage; the helper also rejects write/JavaScript stages.
        parsed_pipeline = _parse_json_list_parameter(pipeline, "pipeline")

        if parsed_pipeline is None:
            raise RetryableToolError(  # noqa: TRY301
                "Pipeline cannot be empty",
                developer_message="The pipeline parameter is required and cannot be None",
            )

        async with await DatabaseEngine.get_database(
            context.get_secret("MONGODB_CONNECTION_STRING"), database_name
        ) as db:
            collection = db[collection_name]

            # Only a $limit in the FINAL stage bounds the result size. A $limit
            # earlier in the pipeline can be followed by stages such as $unwind
            # or $lookup that multiply documents again, so it must not be
            # treated as an existing cap.
            pipeline_with_limit = parsed_pipeline.copy()
            final_stage = pipeline_with_limit[-1] if pipeline_with_limit else None
            ends_with_limit = isinstance(final_stage, dict) and "$limit" in final_stage
            if not ends_with_limit:
                pipeline_with_limit.append({"$limit": limit})

            # Execute the aggregation; each result is made JSON-safe, then encoded.
            cursor = collection.aggregate(pipeline_with_limit)
            return [json.dumps(_serialize_document(doc)) async for doc in cursor]

    except RetryableToolError:
        # Re-raise RetryableToolError as-is to preserve JSON validation messages
        raise
    except Exception as e:
        raise RetryableToolError(
            f"Aggregation query failed: {e}",
            developer_message=f"Aggregation query failed with parameters: database_name={database_name}, collection_name={collection_name}, pipeline={parsed_pipeline}, limit={limit}.",
            additional_prompt_content="Check the aggregation pipeline syntax and collection schema, then try again.",
            retry_after_ms=10,
        ) from e
|
||||
|
||||
|
||||
# @tool(requires_secrets=["MONGODB_CONNECTION_STRING"])
|
||||
# async def update_user_status(
|
||||
# context: ToolContext,
|
||||
# database_name: Annotated[str, "The database name containing the users collection"],
|
||||
# collection_name: Annotated[str, "The collection name containing user documents"],
|
||||
# user_id: Annotated[str, "The _id of the user to update"],
|
||||
# status: Annotated[UserStatus, "The new status for the user"],
|
||||
# ) -> dict[str, Any]:
|
||||
# """
|
||||
# [CUSTOM TOOL]
|
||||
# Update the status of a user in the MongoDB collection.
|
||||
|
||||
# This tool updates a user document by setting the status field to the specified value.
|
||||
# The status must be one of: active, inactive, or banned.
|
||||
|
||||
# Returns information about the update operation including the number of documents modified.
|
||||
# """
|
||||
|
||||
# try:
|
||||
# async with await DatabaseEngine.get_database(
|
||||
# context.get_secret("MONGODB_CONNECTION_STRING"), database_name
|
||||
# ) as db:
|
||||
# collection = db[collection_name]
|
||||
|
||||
# # cast the user_id to int if it looks like an integer
|
||||
# if isinstance(user_id, str) and user_id.isdigit():
|
||||
# user_id = int(user_id)
|
||||
|
||||
# result = await collection.update_one(
|
||||
# {"_id": user_id}, {"$set": {"status": status.value}}
|
||||
# )
|
||||
|
||||
# print(result)
|
||||
|
||||
# if result.matched_count == 0:
|
||||
# return {
|
||||
# "success": False,
|
||||
# "message": f"No user found with _id: {user_id}",
|
||||
# "matched_count": 0,
|
||||
# "modified_count": 0,
|
||||
# }
|
||||
|
||||
# return {
|
||||
# "success": True,
|
||||
# "message": f"User status updated to '{status.value}'",
|
||||
# "user_id": user_id,
|
||||
# "new_status": status.value,
|
||||
# "matched_count": result.matched_count,
|
||||
# "modified_count": result.modified_count,
|
||||
# }
|
||||
|
||||
# except Exception as e:
|
||||
# raise RetryableToolError(
|
||||
# f"Failed to update user status: {e}",
|
||||
# developer_message=f"Update operation failed with parameters: database_name={database_name}, collection_name={collection_name}, user_id={user_id}, status={status}.",
|
||||
# additional_prompt_content="Check the database name, collection name, and user ID, then try again.",
|
||||
# retry_after_ms=10,
|
||||
# ) from e
|
||||
|
|
@ -1,281 +0,0 @@
|
|||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from arcade_mcp_server.exceptions import RetryableToolError
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
def _validate_no_write_operations(obj: Any, parameter_name: str, path: str = "") -> None:
|
||||
"""
|
||||
Recursively validate that an object doesn't contain MongoDB write operations.
|
||||
|
||||
Args:
|
||||
obj: The object to validate
|
||||
parameter_name: Name of the parameter for error messages
|
||||
path: Current path in the object (for nested validation)
|
||||
|
||||
Raises:
|
||||
RetryableToolError: If write operations are detected
|
||||
"""
|
||||
# MongoDB write/update operators that should be blocked
|
||||
WRITE_OPERATORS = {
|
||||
# Update operators
|
||||
"$set",
|
||||
"$unset",
|
||||
"$inc",
|
||||
"$mul",
|
||||
"$rename",
|
||||
"$min",
|
||||
"$max",
|
||||
"$currentDate",
|
||||
"$addToSet",
|
||||
"$pop",
|
||||
"$pull",
|
||||
"$push",
|
||||
"$pullAll",
|
||||
"$each",
|
||||
"$slice",
|
||||
"$sort",
|
||||
"$position",
|
||||
"$bit",
|
||||
"$isolated",
|
||||
# Array update operators
|
||||
"$",
|
||||
"$[]",
|
||||
"$[<identifier>]",
|
||||
# Pipeline update operators
|
||||
"$addFields",
|
||||
"$replaceRoot",
|
||||
"$replaceWith",
|
||||
# Aggregation stages that can modify (in case they're misused)
|
||||
"$out",
|
||||
"$merge",
|
||||
# Other potentially dangerous operators
|
||||
"$where", # Can execute JavaScript
|
||||
}
|
||||
|
||||
if isinstance(obj, dict):
|
||||
for key, value in obj.items():
|
||||
current_path = f"{path}.{key}" if path else key
|
||||
|
||||
# Special check for $where operator which can execute JavaScript (check this first)
|
||||
if key == "$where":
|
||||
raise RetryableToolError(
|
||||
f"JavaScript execution operator '$where' not allowed in {parameter_name}",
|
||||
developer_message=f"Found '$where' operator at path '{current_path}' in parameter '{parameter_name}'. JavaScript execution is not allowed for security reasons.",
|
||||
additional_prompt_content=f"The {parameter_name} parameter cannot use the $where operator. Use other query operators instead.",
|
||||
)
|
||||
|
||||
# Check if this key is a write operator
|
||||
if key in WRITE_OPERATORS:
|
||||
raise RetryableToolError(
|
||||
f"Write operation '{key}' not allowed in {parameter_name}",
|
||||
developer_message=f"Found write operation '{key}' at path '{current_path}' in parameter '{parameter_name}'. Only read operations are allowed.",
|
||||
additional_prompt_content=f"The {parameter_name} parameter cannot contain write operations like '{key}'. Use only query/read operations such as $match, $gte, $lte, $in, $regex, etc.",
|
||||
)
|
||||
|
||||
# Recursively validate nested objects
|
||||
_validate_no_write_operations(value, parameter_name, current_path)
|
||||
|
||||
elif isinstance(obj, list):
|
||||
for i, item in enumerate(obj):
|
||||
current_path = f"{path}[{i}]" if path else f"[{i}]"
|
||||
_validate_no_write_operations(item, parameter_name, current_path)
|
||||
|
||||
|
||||
def _parse_json_parameter(
|
||||
json_string: str | None, parameter_name: str, validate_read_only: bool = True
|
||||
) -> Any | None:
|
||||
"""
|
||||
Parse a JSON string parameter with proper error handling and optional write operation validation.
|
||||
|
||||
Args:
|
||||
json_string: The JSON string to parse (can be None)
|
||||
parameter_name: Name of the parameter for error messages
|
||||
validate_read_only: Whether to validate that no write operations are present
|
||||
|
||||
Returns:
|
||||
Parsed JSON object or None if json_string is None
|
||||
|
||||
Raises:
|
||||
RetryableToolError: If JSON parsing fails or write operations are detected
|
||||
"""
|
||||
if json_string is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
parsed_obj = json.loads(json_string)
|
||||
|
||||
# Validate that no write operations are present
|
||||
if validate_read_only and parsed_obj is not None:
|
||||
_validate_no_write_operations(parsed_obj, parameter_name)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
raise RetryableToolError(
|
||||
f"Invalid JSON in {parameter_name}: {e}",
|
||||
developer_message=f"Failed to parse JSON string for parameter '{parameter_name}': {json_string}. Error: {e}",
|
||||
additional_prompt_content=f"Please provide valid JSON for the {parameter_name} parameter. Check for proper escaping of quotes and valid JSON syntax.",
|
||||
) from e
|
||||
else:
|
||||
return parsed_obj
|
||||
|
||||
|
||||
def _validate_aggregation_pipeline(pipeline: list[Any], parameter_name: str) -> None:
|
||||
"""
|
||||
Validate that an aggregation pipeline only contains read operations.
|
||||
|
||||
Args:
|
||||
pipeline: The aggregation pipeline to validate
|
||||
parameter_name: Name of the parameter for error messages
|
||||
|
||||
Raises:
|
||||
RetryableToolError: If write operations are detected in the pipeline
|
||||
"""
|
||||
# MongoDB aggregation stages that can modify data
|
||||
WRITE_STAGES = {
|
||||
"$out",
|
||||
"$merge", # These stages write to collections
|
||||
}
|
||||
|
||||
# Aggregation stages that are potentially dangerous
|
||||
DANGEROUS_STAGES = {
|
||||
"$where", # Can execute JavaScript
|
||||
}
|
||||
|
||||
for i, stage in enumerate(pipeline):
|
||||
if isinstance(stage, dict):
|
||||
for stage_name in stage:
|
||||
if stage_name in WRITE_STAGES:
|
||||
raise RetryableToolError(
|
||||
f"Write stage '{stage_name}' not allowed in {parameter_name}",
|
||||
developer_message=f"Found write stage '{stage_name}' at pipeline index {i} in parameter '{parameter_name}'. Only read operations are allowed.",
|
||||
additional_prompt_content=f"The {parameter_name} parameter cannot contain write stages like '{stage_name}'. Use only read stages such as $match, $group, $project, $sort, $limit, etc.",
|
||||
)
|
||||
|
||||
if stage_name in DANGEROUS_STAGES:
|
||||
raise RetryableToolError(
|
||||
f"Dangerous stage '{stage_name}' not allowed in {parameter_name}",
|
||||
developer_message=f"Found dangerous stage '{stage_name}' at pipeline index {i} in parameter '{parameter_name}'. JavaScript execution is not allowed for security reasons.",
|
||||
additional_prompt_content=f"The {parameter_name} parameter cannot use the {stage_name} stage. Use other aggregation stages instead.",
|
||||
)
|
||||
|
||||
# Also validate the stage content for write operations
|
||||
_validate_no_write_operations(
|
||||
stage[stage_name], f"{parameter_name}[{i}].{stage_name}"
|
||||
)
|
||||
|
||||
|
||||
def _parse_json_list_parameter(
|
||||
json_strings: list[str] | None, parameter_name: str, validate_read_only: bool = True
|
||||
) -> list[Any] | None:
|
||||
"""
|
||||
Parse a list of JSON strings with proper error handling and optional write operation validation.
|
||||
|
||||
Args:
|
||||
json_strings: List of JSON strings to parse (can be None)
|
||||
parameter_name: Name of the parameter for error messages
|
||||
validate_read_only: Whether to validate that no write operations are present
|
||||
|
||||
Returns:
|
||||
List of parsed JSON objects or None if json_strings is None
|
||||
|
||||
Raises:
|
||||
RetryableToolError: If JSON parsing fails for any string or write operations are detected
|
||||
"""
|
||||
if json_strings is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
parsed_list = [json.loads(json_str) for json_str in json_strings]
|
||||
|
||||
# Validate that no write operations are present
|
||||
if validate_read_only and parsed_list is not None:
|
||||
# Special handling for pipeline parameters
|
||||
if parameter_name == "pipeline":
|
||||
_validate_aggregation_pipeline(parsed_list, parameter_name)
|
||||
else:
|
||||
# For non-pipeline lists, validate each item
|
||||
for i, item in enumerate(parsed_list):
|
||||
_validate_no_write_operations(item, f"{parameter_name}[{i}]")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
raise RetryableToolError(
|
||||
f"Invalid JSON in {parameter_name}: {e}",
|
||||
developer_message=f"Failed to parse JSON string list for parameter '{parameter_name}': {json_strings}. Error: {e}",
|
||||
additional_prompt_content=f"Please provide valid JSON strings for the {parameter_name} parameter. Each string must be valid JSON with proper escaping of quotes.",
|
||||
) from e
|
||||
else:
|
||||
return parsed_list
|
||||
|
||||
|
||||
def _infer_schema_from_docs(docs: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
"""Infer schema structure from a list of documents."""
|
||||
schema: dict[str, Any] = {}
|
||||
|
||||
for doc in docs:
|
||||
_update_schema_with_doc(schema, doc)
|
||||
|
||||
# Convert sets to lists for serialization
|
||||
for key in schema:
|
||||
if isinstance(schema[key]["types"], set):
|
||||
schema[key]["types"] = list(schema[key]["types"])
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def _update_schema_with_doc(schema: dict[str, Any], doc: dict[str, Any], prefix: str = "") -> None:
|
||||
"""Recursively update schema with document structure."""
|
||||
for key, value in doc.items():
|
||||
full_key = f"{prefix}.{key}" if prefix else key
|
||||
|
||||
if full_key not in schema:
|
||||
schema[full_key] = {
|
||||
"types": set(),
|
||||
"sample_values": [],
|
||||
"null_count": 0,
|
||||
"total_count": 0,
|
||||
}
|
||||
|
||||
schema[full_key]["total_count"] += 1
|
||||
|
||||
if value is None:
|
||||
schema[full_key]["null_count"] += 1
|
||||
schema[full_key]["types"].add("null")
|
||||
else:
|
||||
value_type = type(value).__name__
|
||||
schema[full_key]["types"].add(value_type)
|
||||
|
||||
# Store sample values (limit to 3 unique samples)
|
||||
if (
|
||||
len(schema[full_key]["sample_values"]) < 3
|
||||
and value not in schema[full_key]["sample_values"]
|
||||
):
|
||||
schema[full_key]["sample_values"].append(value)
|
||||
|
||||
# Handle nested objects
|
||||
if isinstance(value, dict):
|
||||
_update_schema_with_doc(schema, value, full_key)
|
||||
elif isinstance(value, list) and value and isinstance(value[0], dict):
|
||||
# Handle arrays of objects by sampling the first few
|
||||
for i, item in enumerate(value[:3]): # Sample first 3 array items
|
||||
if isinstance(item, dict):
|
||||
_update_schema_with_doc(schema, item, f"{full_key}[{i}]")
|
||||
|
||||
|
||||
def _serialize_document(doc: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Convert MongoDB document to JSON-serializable format."""
|
||||
|
||||
if isinstance(doc, dict):
|
||||
result = {}
|
||||
for key, value in doc.items():
|
||||
result[key] = _serialize_document(value)
|
||||
return result
|
||||
elif isinstance(doc, list):
|
||||
return [_serialize_document(item) for item in doc]
|
||||
elif isinstance(doc, ObjectId):
|
||||
return str(doc)
|
||||
elif isinstance(doc, datetime):
|
||||
return doc.isoformat()
|
||||
else:
|
||||
return doc
|
||||
|
|
@ -1,190 +0,0 @@
|
|||
# RUN ME WITH `uv run arcade evals evals --host api.arcade.dev`
|
||||
|
||||
import arcade_mongodb
|
||||
from arcade_core import ToolCatalog
|
||||
from arcade_evals import (
|
||||
BinaryCritic,
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
ExpectedToolCall,
|
||||
SimilarityCritic,
|
||||
tool_eval,
|
||||
)
|
||||
from arcade_mongodb.tools.mongodb import (
|
||||
aggregate_documents,
|
||||
count_documents,
|
||||
discover_collections,
|
||||
discover_databases,
|
||||
find_documents,
|
||||
get_collection_schema,
|
||||
)
|
||||
|
||||
# Evaluation rubric
|
||||
rubric = EvalRubric(
|
||||
fail_threshold=0.85,
|
||||
warn_threshold=0.95,
|
||||
)
|
||||
|
||||
|
||||
catalog = ToolCatalog()
|
||||
catalog.add_module(arcade_mongodb)
|
||||
|
||||
|
||||
@tool_eval()
def mongodb_eval_suite() -> EvalSuite:
    """Assemble the evaluation suite for the MongoDB tools.

    Seven cases are registered in order: database discovery, collection
    discovery, schema lookup, a limited find, a plain count, a filtered count
    and an aggregation. Each case names the tool call the model is expected to
    make and, where arguments matter, the critics that score those arguments.
    """

    def user_hint(text: str) -> list[dict[str, str]]:
        # A fresh single-message list per case, so cases never share objects.
        return [{"role": "user", "content": text}]

    def location_critics(weight: float) -> list[BinaryCritic]:
        # Exact-match critics for the database/collection pair, equally weighted.
        return [
            BinaryCritic(critic_field="database_name", weight=weight),
            BinaryCritic(critic_field="collection_name", weight=weight),
        ]

    # Arguments shared by every case that targets local.startup_log.
    startup_log = {"database_name": "local", "collection_name": "startup_log"}

    system_prompt = (
        "You are an AI assistant with access to MongoDB tools. "
        "Use them to help the user with their tasks."
    )
    suite = EvalSuite(
        name="MongoDB Tools Evaluation",
        system_message=system_prompt,
        catalog=catalog,
        rubric=rubric,
    )

    # No arguments to score, so this case has no critics.
    suite.add_case(
        name="Discover databases",
        user_message="What databases are available in my MongoDB instance?",
        expected_tool_calls=[ExpectedToolCall(func=discover_databases, args={})],
        rubric=rubric,
    )

    suite.add_case(
        name="Discover collections",
        user_message="What collections are in the 'admin' database?",
        expected_tool_calls=[
            ExpectedToolCall(func=discover_collections, args={"database_name": "admin"})
        ],
        rubric=rubric,
        critics=[BinaryCritic(critic_field="database_name", weight=1.0)],
    )

    schema_call = ExpectedToolCall(
        func=get_collection_schema,
        args={"database_name": "admin", "collection_name": "system.users"},
    )
    suite.add_case(
        name="Get collection schema (single tool call)",
        user_message="Get the schema of the 'system.users' collection in the 'admin' database.",
        expected_tool_calls=[schema_call],
        rubric=rubric,
        critics=location_critics(0.5),
    )

    find_call = ExpectedToolCall(func=find_documents, args={**startup_log, "limit": 5})
    suite.add_case(
        name="Find documents (direct call)",
        user_message="Find documents in the 'startup_log' collection of the 'local' database, limited to 5 results.",
        additional_messages=user_hint(
            "You can call find_documents directly without discovering collections first for this test."
        ),
        expected_tool_calls=[find_call],
        rubric=rubric,
        critics=[
            *location_critics(0.33),
            BinaryCritic(critic_field="limit", weight=0.34),
        ],
    )

    count_call = ExpectedToolCall(func=count_documents, args={**startup_log})
    suite.add_case(
        name="Count documents",
        user_message="Count all documents in the 'startup_log' collection of the 'local' database.",
        additional_messages=user_hint(
            "You can call count_documents directly without discovering collections first for this test."
        ),
        expected_tool_calls=[count_call],
        rubric=rubric,
        critics=location_critics(0.5),
    )

    filtered_count_call = ExpectedToolCall(
        func=count_documents,
        args={**startup_log, "filter_dict": '{"level": "INFO"}'},
    )
    suite.add_case(
        name="Count documents with filter",
        user_message="Count documents in the 'startup_log' collection of the 'local' database where the level is 'INFO'.",
        additional_messages=user_hint(
            "You can call count_documents directly without discovering collections first for this test."
        ),
        expected_tool_calls=[filtered_count_call],
        rubric=rubric,
        critics=[
            *location_critics(0.25),
            SimilarityCritic(critic_field="filter_dict", weight=0.5),
        ],
    )

    aggregate_call = ExpectedToolCall(
        func=aggregate_documents,
        args={
            **startup_log,
            "pipeline": ['{"$group": {"_id": "$level", "count": {"$sum": 1}}}'],
        },
    )
    suite.add_case(
        name="Aggregate documents",
        user_message="Group documents in the 'startup_log' collection of the 'local' database by level and count them.",
        additional_messages=user_hint(
            "You can call aggregate_documents directly without discovering collections first for this test."
        ),
        expected_tool_calls=[aggregate_call],
        rubric=rubric,
        critics=[
            *location_critics(0.2),
            SimilarityCritic(critic_field="pipeline", weight=0.6),
        ],
    )

    return suite
|
||||
|
|
@ -1,62 +0,0 @@
|
|||
[build-system]
|
||||
requires = [ "hatchling",]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "arcade_mongodb"
|
||||
version = "0.3.0"
|
||||
description = "Tools to query and explore a MongoDB database"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"arcade-mcp-server>=1.17.0,<2.0.0",
|
||||
"pymongo>=4.10.1",
|
||||
"pydantic>=2.11.7",
|
||||
"motor>=3.6.0",
|
||||
]
|
||||
[[project.authors]]
|
||||
name = "evantahler"
|
||||
email = "support@arcade.dev"
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"arcade-mcp[all]>=1.2.0,<2.0.0",
|
||||
"pytest>=8.3.0,<8.4.0",
|
||||
"pytest-cov>=4.0.0,<4.1.0",
|
||||
"pytest-mock>=3.11.1,<3.12.0",
|
||||
"pytest-asyncio>=0.24.0,<0.25.0",
|
||||
"mypy>=1.5.1,<1.6.0",
|
||||
"pre-commit>=3.4.0,<3.5.0",
|
||||
"tox>=4.11.1,<4.12.0",
|
||||
"ruff>=0.7.4,<0.8.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
arcade-mongodb = "arcade_mongodb.__main__:main"
|
||||
arcade_mongodb = "arcade_mongodb.__main__:main"
|
||||
|
||||
# Use local path sources for arcade libs when working locally
|
||||
[tool.uv.sources]
|
||||
arcade-mcp = { path = "../../", editable = true }
|
||||
arcade-mcp-server = { path = "../../libs/arcade-mcp-server/", editable = true }
|
||||
|
||||
[tool.mypy]
|
||||
files = [ "arcade_mongodb/**/*.py",]
|
||||
python_version = "3.10"
|
||||
disallow_untyped_defs = "True"
|
||||
disallow_any_unimported = "True"
|
||||
no_implicit_optional = "True"
|
||||
check_untyped_defs = "True"
|
||||
warn_return_any = "True"
|
||||
warn_unused_ignores = "True"
|
||||
show_error_codes = "True"
|
||||
ignore_missing_imports = "True"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = [ "tests",]
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
|
||||
[tool.coverage.report]
|
||||
skip_empty = true
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = [ "arcade_mongodb",]
|
||||
|
|
@ -1,45 +0,0 @@
|
|||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from os import environ
|
||||
|
||||
import pytest_asyncio
|
||||
from arcade_mongodb.database_engine import DatabaseEngine
|
||||
|
||||
TEST_MONGODB_CONNECTION_STRING = (
|
||||
environ.get("TEST_MONGODB_CONNECTION_STRING") or "mongodb://localhost:27017"
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(autouse=True)
async def restore_database():
    """Reload the test data by running ``dump.js`` through ``mongosh`` before each test.

    The script path is resolved next to this file and executed against
    ``TEST_MONGODB_CONNECTION_STRING``.

    Raises:
        RuntimeError: If ``mongosh`` is not found on PATH, or if it exits with
            a non-zero status (the message includes its stderr output).
    """
    # os.path.join avoids producing "/dump.js" when dirname() is empty.
    dump_file = os.path.join(os.path.dirname(__file__), "dump.js")

    # Execute the MongoDB dump script to restore test data
    mongosh_path = shutil.which("mongosh")
    if not mongosh_path:
        raise RuntimeError("mongosh executable not found in PATH")

    # check=False on purpose: with check=True a failing mongosh raises
    # CalledProcessError first, so the branch below that reports stderr
    # could never run.
    result = subprocess.run(
        [mongosh_path, TEST_MONGODB_CONNECTION_STRING, dump_file],
        check=False,
        capture_output=True,
        text=True,
    )

    if result.returncode != 0:
        print(f"Error loading test data: {result.stderr}")
        raise RuntimeError(f"Failed to load test data: {result.stderr}")

    yield  # This allows tests to run

    # Optional cleanup could go here if needed
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(autouse=True)
async def cleanup_engines():
    """Await ``DatabaseEngine.cleanup()`` after each test to prevent connection leaks.

    Nothing is set up before the test; the fixture only acts on teardown.
    """
    yield  # the test body runs here
    await DatabaseEngine.cleanup()
|
||||
|
|
@ -1,378 +0,0 @@
|
|||
// MongoDB test data dump - equivalent to PostgreSQL dump.sql
|
||||
// This script sets up test data for the MongoDB toolkit
|
||||
|
||||
// Switch to test database
|
||||
use('test_database');
|
||||
|
||||
// Clear existing data
|
||||
db.users.drop();
|
||||
db.messages.drop();
|
||||
|
||||
// Create users collection with data
|
||||
db.users.insertMany([
|
||||
{
|
||||
_id: 1,
|
||||
name: 'Alice',
|
||||
email: 'alice@example.com',
|
||||
password_hash: '$argon2id$v=19$m=65536,t=2,p=1$tMg1Rd3IEDnp3iFKrqsF4Dsbw6/Cbf6seRB/H5bhaPg$zZj5yn4x3D3O3mDHcW2aczQNiYfAs3cw21XMEIgkF0E',
|
||||
created_at: new Date('2024-09-01T20:49:38.759Z'),
|
||||
updated_at: new Date('2024-09-02T03:49:39.927Z'),
|
||||
status: 'active'
|
||||
},
|
||||
{
|
||||
_id: 2,
|
||||
name: 'Bob',
|
||||
email: 'bob@example.com',
|
||||
password_hash: '$argon2id$v=19$m=65536,t=2,p=1$CvOMK1WUd99R7kYXpiBPNYw4OQP53pYIgeMnwz92mrE$HPthId4phMoPT1TWuCRHHCr9BSQA8XoUkQuB1HZsqTY',
|
||||
created_at: new Date('2024-09-02T17:49:23.377Z'),
|
||||
updated_at: new Date('2024-09-02T17:49:23.377Z'),
|
||||
status: 'active'
|
||||
},
|
||||
{
|
||||
_id: 3,
|
||||
name: 'Charlie',
|
||||
email: 'charlie@example.com',
|
||||
password_hash: '$argon2id$v=19$m=65536,t=2,p=1$paCAAD1HVZkncP/WvecuUO6zFXp2/8BISpgr5rXRxps$M5kBFc9JHHGNw9SXnPu2ggpJY0mFFCska7TXMrllndo',
|
||||
created_at: new Date('2024-09-03T10:30:15.123Z'),
|
||||
updated_at: new Date('2024-09-03T10:30:15.123Z'),
|
||||
status: 'active'
|
||||
},
|
||||
{
|
||||
_id: 4,
|
||||
name: 'Diana',
|
||||
email: 'diana@example.com',
|
||||
password_hash: '$argon2id$v=19$m=65536,t=2,p=1$xyz123ABC456DEF789GHI$SampleHashForDiana123',
|
||||
created_at: new Date('2024-09-04T14:20:30.654Z'),
|
||||
updated_at: new Date('2024-09-04T14:20:30.654Z'),
|
||||
status: 'active'
|
||||
},
|
||||
{
|
||||
_id: 5,
|
||||
name: 'Evan',
|
||||
email: 'evan@example.com',
|
||||
password_hash: '$argon2id$v=19$m=65536,t=2,p=1$evanHash123$EvanPasswordHash456',
|
||||
created_at: new Date('2024-09-05T09:15:45.987Z'),
|
||||
updated_at: new Date('2024-09-05T09:15:45.987Z'),
|
||||
status: 'active'
|
||||
},
|
||||
{
|
||||
_id: 6,
|
||||
name: 'Fiona',
|
||||
email: 'fiona@example.com',
|
||||
password_hash: '$argon2id$v=19$m=65536,t=2,p=1$fionaHash456$FionaPasswordHash789',
|
||||
created_at: new Date('2024-09-06T16:45:12.345Z'),
|
||||
updated_at: new Date('2024-09-06T16:45:12.345Z'),
|
||||
status: 'active'
|
||||
},
|
||||
{
|
||||
_id: 7,
|
||||
name: 'George',
|
||||
email: 'george@example.com',
|
||||
password_hash: '$argon2id$v=19$m=65536,t=2,p=1$georgeHash789$GeorgePasswordHash012',
|
||||
created_at: new Date('2024-09-07T11:30:25.876Z'),
|
||||
updated_at: new Date('2024-09-07T11:30:25.876Z'),
|
||||
status: 'active'
|
||||
},
|
||||
{
|
||||
_id: 8,
|
||||
name: 'Helen',
|
||||
email: 'helen@example.com',
|
||||
password_hash: '$argon2id$v=19$m=65536,t=2,p=1$helenHash012$HelenPasswordHash345',
|
||||
created_at: new Date('2024-09-08T13:25:40.234Z'),
|
||||
updated_at: new Date('2024-09-08T13:25:40.234Z'),
|
||||
status: 'active'
|
||||
},
|
||||
{
|
||||
_id: 9,
|
||||
name: 'Ian',
|
||||
email: 'ian@example.com',
|
||||
password_hash: '$argon2id$v=19$m=65536,t=2,p=1$ianHash345$IanPasswordHash678',
|
||||
created_at: new Date('2024-09-09T08:40:55.765Z'),
|
||||
updated_at: new Date('2024-09-09T08:40:55.765Z'),
|
||||
status: 'active'
|
||||
},
|
||||
{
|
||||
_id: 10,
|
||||
name: 'Julia',
|
||||
email: 'julia@example.com',
|
||||
password_hash: '$argon2id$v=19$m=65536,t=2,p=1$juliaHash678$JuliaPasswordHash901',
|
||||
created_at: new Date('2024-09-10T15:55:18.123Z'),
|
||||
updated_at: new Date('2024-09-10T15:55:18.123Z'),
|
||||
status: 'active'
|
||||
}
|
||||
]);
|
||||
|
||||
// Create messages collection with data
|
||||
db.messages.insertMany([
|
||||
// User 1 (Alice) - 3 messages
|
||||
{
|
||||
_id: 1,
|
||||
body: 'Hello everyone!',
|
||||
user_id: 1,
|
||||
created_at: new Date('2025-01-10T10:00:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T10:00:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 2,
|
||||
body: 'How is everyone doing today?',
|
||||
user_id: 1,
|
||||
created_at: new Date('2025-01-10T11:30:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T11:30:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 3,
|
||||
body: 'Great to see you all here!',
|
||||
user_id: 1,
|
||||
created_at: new Date('2025-01-10T14:15:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T14:15:00.000Z')
|
||||
},
|
||||
// User 2 (Bob) - 2 messages
|
||||
{
|
||||
_id: 4,
|
||||
body: 'Hi Alice! Doing well, thanks for asking.',
|
||||
user_id: 2,
|
||||
created_at: new Date('2025-01-10T11:35:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T11:35:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 5,
|
||||
body: 'Anyone up for a game later?',
|
||||
user_id: 2,
|
||||
created_at: new Date('2025-01-10T16:20:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T16:20:00.000Z')
|
||||
},
|
||||
// User 3 (Charlie) - 3 messages
|
||||
{
|
||||
_id: 6,
|
||||
body: 'Count me in for the game!',
|
||||
user_id: 3,
|
||||
created_at: new Date('2025-01-10T16:25:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T16:25:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 7,
|
||||
body: 'What time works for everyone?',
|
||||
user_id: 3,
|
||||
created_at: new Date('2025-01-10T16:30:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T16:30:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 8,
|
||||
body: 'I can play around 8 PM',
|
||||
user_id: 3,
|
||||
created_at: new Date('2025-01-10T17:00:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T17:00:00.000Z')
|
||||
},
|
||||
// User 4 (Diana) - 2 messages
|
||||
{
|
||||
_id: 9,
|
||||
body: '8 PM works for me too!',
|
||||
user_id: 4,
|
||||
created_at: new Date('2025-01-10T17:05:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T17:05:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 10,
|
||||
body: 'What game should we play?',
|
||||
user_id: 4,
|
||||
created_at: new Date('2025-01-10T17:10:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T17:10:00.000Z')
|
||||
},
|
||||
// User 5 (Evan) - 13 messages (including 10 additional ones)
|
||||
{
|
||||
_id: 11,
|
||||
body: 'I suggest we try the new arcade game!',
|
||||
user_id: 5,
|
||||
created_at: new Date('2025-01-10T17:15:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T17:15:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 12,
|
||||
body: 'It has great multiplayer features',
|
||||
user_id: 5,
|
||||
created_at: new Date('2025-01-10T17:20:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T17:20:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 13,
|
||||
body: 'Perfect timing for a weekend session',
|
||||
user_id: 5,
|
||||
created_at: new Date('2025-01-10T18:00:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T18:00:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 26,
|
||||
body: 'Just finished setting up the game server!',
|
||||
user_id: 5,
|
||||
created_at: new Date('2025-01-10T20:00:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T20:00:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 27,
|
||||
body: 'Everyone should be able to connect now',
|
||||
user_id: 5,
|
||||
created_at: new Date('2025-01-10T20:05:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T20:05:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 28,
|
||||
body: 'I added some custom maps too',
|
||||
user_id: 5,
|
||||
created_at: new Date('2025-01-10T20:10:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T20:10:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 29,
|
||||
body: 'The graphics look amazing on this new version',
|
||||
user_id: 5,
|
||||
created_at: new Date('2025-01-10T20:15:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T20:15:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 30,
|
||||
body: 'Hope you all enjoy the new features',
|
||||
user_id: 5,
|
||||
created_at: new Date('2025-01-10T20:20:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T20:20:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 31,
|
||||
body: 'I also set up a leaderboard system',
|
||||
user_id: 5,
|
||||
created_at: new Date('2025-01-10T20:25:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T20:25:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 32,
|
||||
body: 'We can track high scores now',
|
||||
user_id: 5,
|
||||
created_at: new Date('2025-01-10T20:30:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T20:30:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 33,
|
||||
body: 'The game supports up to 8 players simultaneously',
|
||||
user_id: 5,
|
||||
created_at: new Date('2025-01-10T20:35:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T20:35:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 34,
|
||||
body: 'I tested it earlier and it runs smoothly',
|
||||
user_id: 5,
|
||||
created_at: new Date('2025-01-10T20:40:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T20:40:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 35,
|
||||
body: 'Cannot wait to see everyone online tonight!',
|
||||
user_id: 5,
|
||||
created_at: new Date('2025-01-10T20:45:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T20:45:00.000Z')
|
||||
},
|
||||
// User 6 (Fiona) - 2 messages
|
||||
{
|
||||
_id: 14,
|
||||
body: 'Sounds like fun! I love arcade games.',
|
||||
user_id: 6,
|
||||
created_at: new Date('2025-01-10T18:05:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T18:05:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 15,
|
||||
body: 'Should I bring snacks?',
|
||||
user_id: 6,
|
||||
created_at: new Date('2025-01-10T18:10:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T18:10:00.000Z')
|
||||
},
|
||||
// User 7 (George) - 3 messages
|
||||
{
|
||||
_id: 16,
|
||||
body: 'Snacks are always welcome!',
|
||||
user_id: 7,
|
||||
created_at: new Date('2025-01-10T18:15:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T18:15:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 17,
|
||||
body: 'I can bring some drinks',
|
||||
user_id: 7,
|
||||
created_at: new Date('2025-01-10T18:20:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T18:20:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 18,
|
||||
body: 'This is going to be awesome',
|
||||
user_id: 7,
|
||||
created_at: new Date('2025-01-10T19:00:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T19:00:00.000Z')
|
||||
},
|
||||
// User 8 (Helen) - 2 messages
|
||||
{
|
||||
_id: 19,
|
||||
body: 'I agree! Cannot wait for the game night.',
|
||||
user_id: 8,
|
||||
created_at: new Date('2025-01-10T19:05:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T19:05:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 20,
|
||||
body: 'Should we set up a Discord call?',
|
||||
user_id: 8,
|
||||
created_at: new Date('2025-01-10T19:10:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T19:10:00.000Z')
|
||||
},
|
||||
// User 9 (Ian) - 3 messages
|
||||
{
|
||||
_id: 21,
|
||||
body: 'Discord would be perfect for voice chat',
|
||||
user_id: 9,
|
||||
created_at: new Date('2025-01-10T19:15:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T19:15:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 22,
|
||||
body: 'I will create a server for us',
|
||||
user_id: 9,
|
||||
created_at: new Date('2025-01-10T19:20:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T19:20:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 23,
|
||||
body: 'Link will be shared in a few minutes',
|
||||
user_id: 9,
|
||||
created_at: new Date('2025-01-10T19:25:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T19:25:00.000Z')
|
||||
},
|
||||
// User 10 (Julia) - 2 messages
|
||||
{
|
||||
_id: 24,
|
||||
body: 'Thanks Ian! You are the best.',
|
||||
user_id: 10,
|
||||
created_at: new Date('2025-01-10T19:30:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T19:30:00.000Z')
|
||||
},
|
||||
{
|
||||
_id: 25,
|
||||
body: 'See you all at 8 PM!',
|
||||
user_id: 10,
|
||||
created_at: new Date('2025-01-10T19:35:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T19:35:00.000Z')
|
||||
},{
|
||||
_id: 99,
|
||||
body: 'You are a mean jerk, you shithead!',
|
||||
user_id: 10,
|
||||
created_at: new Date('2025-01-10T19:35:00.000Z'),
|
||||
updated_at: new Date('2025-01-10T19:35:00.000Z')
|
||||
}
|
||||
]);
|
||||
|
||||
// Create indexes for better performance (equivalent to PostgreSQL indexes)
|
||||
db.users.createIndex({ "name": 1 }, { unique: true });
|
||||
db.users.createIndex({ "email": 1 }, { unique: true });
|
||||
db.messages.createIndex({ "user_id": 1 });
|
||||
db.messages.createIndex({ "created_at": 1 });
|
||||
|
||||
print("MongoDB test data setup completed successfully!");
|
||||
print("Users collection: " + db.users.countDocuments());
|
||||
print("Messages collection: " + db.messages.countDocuments());
|
||||
|
|
@ -1,221 +0,0 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from arcade_core.errors import ToolExecutionError
|
||||
from arcade_mcp_server import Context
|
||||
from arcade_mcp_server.exceptions import RetryableToolError
|
||||
from arcade_mongodb.tools.mongodb import aggregate_documents, count_documents, find_documents
|
||||
|
||||
from .conftest import TEST_MONGODB_CONNECTION_STRING
|
||||
|
||||
|
||||
@pytest.fixture
def mock_context():
    """Provide a ``Context`` mock whose ``get_secret`` returns the test connection string."""
    fake_context = MagicMock(spec=Context)
    fake_context.get_secret = MagicMock(return_value=TEST_MONGODB_CONNECTION_STRING)
    return fake_context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_invalid_json_in_filter_dict(mock_context) -> None:
    """find_documents rejects a malformed filter_dict with a descriptive RetryableToolError."""
    bad_filter = '{"status": "active",}'  # trailing comma makes this invalid JSON

    with pytest.raises(RetryableToolError) as caught:
        await find_documents(
            mock_context,
            database_name="test_database",
            collection_name="users",
            filter_dict=bad_filter,
            limit=1,
        )

    error = caught.value

    # The message must identify this as a JSON problem in filter_dict.
    assert "Invalid JSON in filter_dict" in str(error)

    # Developer and prompt hints must name the parameter and mention JSON.
    assert "filter_dict" in error.developer_message
    assert "JSON" in error.additional_prompt_content

    # An underlying error must be chained as the cause.
    assert error.__cause__ is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_invalid_json_in_projection(mock_context) -> None:
    """find_documents rejects a malformed projection with a descriptive RetryableToolError."""
    bad_projection = '{"name": 1, "email": 1,}'  # trailing comma makes this invalid JSON

    with pytest.raises(RetryableToolError) as caught:
        await find_documents(
            mock_context,
            database_name="test_database",
            collection_name="users",
            projection=bad_projection,
            limit=1,
        )

    error = caught.value

    # The message must identify this as a JSON problem in projection.
    assert "Invalid JSON in projection" in str(error)

    # Developer and prompt hints must name the parameter and mention JSON.
    assert "projection" in error.developer_message
    assert "JSON" in error.additional_prompt_content

    # An underlying error must be chained as the cause.
    assert error.__cause__ is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_invalid_json_in_sort(mock_context) -> None:
    """find_documents rejects a malformed sort entry with a descriptive RetryableToolError."""
    bad_sort = ['{"field": "name", "direction": 1,}']  # trailing comma makes this invalid JSON

    with pytest.raises(RetryableToolError) as caught:
        await find_documents(
            mock_context,
            database_name="test_database",
            collection_name="users",
            sort=bad_sort,
            limit=1,
        )

    error = caught.value

    # The message must identify this as a JSON problem in sort.
    assert "Invalid JSON in sort" in str(error)

    # Developer and prompt hints must name the parameter and mention JSON.
    assert "sort" in error.developer_message
    assert "JSON" in error.additional_prompt_content

    # An underlying error must be chained as the cause.
    assert error.__cause__ is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_invalid_json_in_count_filter(mock_context) -> None:
    """count_documents rejects a malformed filter_dict with a descriptive RetryableToolError."""
    bad_filter = '{"status": "active",}'  # trailing comma makes this invalid JSON

    with pytest.raises(RetryableToolError) as caught:
        await count_documents(
            mock_context,
            database_name="test_database",
            collection_name="users",
            filter_dict=bad_filter,
        )

    error = caught.value

    # The message must identify this as a JSON problem in filter_dict.
    assert "Invalid JSON in filter_dict" in str(error)

    # Developer and prompt hints must name the parameter and mention JSON.
    assert "filter_dict" in error.developer_message
    assert "JSON" in error.additional_prompt_content

    # An underlying error must be chained as the cause.
    assert error.__cause__ is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_invalid_json_in_pipeline(mock_context) -> None:
    """aggregate_documents rejects a malformed pipeline stage with a descriptive RetryableToolError."""
    bad_pipeline = ['{"$match": {"status": "active",}}']  # trailing comma makes this invalid JSON

    with pytest.raises(RetryableToolError) as caught:
        await aggregate_documents(
            mock_context,
            database_name="test_database",
            collection_name="users",
            pipeline=bad_pipeline,
        )

    error = caught.value

    # The message must identify this as a JSON problem in pipeline.
    assert "Invalid JSON in pipeline" in str(error)

    # Developer and prompt hints must name the parameter and mention JSON.
    assert "pipeline" in error.developer_message
    assert "JSON" in error.additional_prompt_content

    # An underlying error must be chained as the cause.
    assert error.__cause__ is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_malformed_json_string(mock_context) -> None:
    """Several kinds of malformed filter_dict JSON each raise a descriptive RetryableToolError."""
    # Pairs of (malformed JSON, fragment expected somewhere in the reported error).
    malformed_inputs = [
        ('{"unclosed": "string}', "Unterminated string"),
        ('{"missing_quotes": value}', "Expecting"),
        ('{missing_closing_brace: "value"}', "Expecting"),
        ('[{"array": "with"}, {"missing": }]', "Expecting"),
    ]

    for bad_json, fragment in malformed_inputs:
        with pytest.raises(RetryableToolError) as caught:
            await find_documents(
                mock_context,
                database_name="test_database",
                collection_name="users",
                filter_dict=bad_json,
                limit=1,
            )

        error = caught.value
        message = str(error)

        # The message must identify this as a JSON problem in filter_dict.
        assert "Invalid JSON in filter_dict" in message

        # The parser's detail may appear in either the message or the developer message.
        if fragment:
            assert fragment in message or fragment in error.developer_message

        # Helpful context must accompany the error.
        assert "filter_dict" in error.developer_message
        assert "JSON" in error.additional_prompt_content
        assert "escaping" in error.additional_prompt_content

        # An underlying error must be chained as the cause.
        assert error.__cause__ is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_valid_json_does_not_error(mock_context) -> None:
    """Well-formed JSON arguments never produce an "Invalid JSON" error.

    Other tool errors are tolerated; only JSON-parsing complaints fail the test.
    """
    try:
        documents = await find_documents(
            mock_context,
            database_name="test_database",
            collection_name="users",
            filter_dict='{"status": "active"}',
            projection='{"name": 1, "_id": 0}',
            sort=['{"field": "name", "direction": 1}'],
            limit=1,
        )
    except (ToolExecutionError, RetryableToolError) as e:
        # Neither the error itself nor its direct cause may blame JSON parsing.
        outer_text = str(e)
        cause_text = "" if not e.__cause__ else str(e.__cause__)
        assert "Invalid JSON" not in outer_text
        assert "Invalid JSON" not in cause_text
    else:
        # Reaching this point means JSON parsing succeeded.
        assert isinstance(documents, list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_duplicate_keys_are_valid_json(mock_context) -> None:
    """A filter with repeated keys is accepted, since Python's JSON parser allows duplicates.

    Other tool errors are tolerated; only JSON-parsing complaints fail the test.
    """
    repeated_key_filter = '{"duplicate": "key", "duplicate": "key"}'  # valid JSON - last value wins

    try:
        documents = await find_documents(
            mock_context,
            database_name="test_database",
            collection_name="users",
            filter_dict=repeated_key_filter,
            limit=1,
        )
    except (ToolExecutionError, RetryableToolError) as e:
        # Neither the error itself nor its direct cause may blame JSON parsing.
        outer_text = str(e)
        cause_text = "" if not e.__cause__ else str(e.__cause__)
        assert "Invalid JSON" not in outer_text
        assert "Invalid JSON" not in cause_text
    else:
        # Parsing succeeded; the result may be empty but must be a list.
        assert isinstance(documents, list)
|
||||
|
|
@ -1,292 +0,0 @@
|
|||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from arcade_mcp_server import Context
|
||||
from arcade_mcp_server.exceptions import RetryableToolError
|
||||
from arcade_mongodb.database_engine import DatabaseEngine
|
||||
from arcade_mongodb.tools.mongodb import (
|
||||
# UserStatus,
|
||||
aggregate_documents,
|
||||
count_documents,
|
||||
discover_collections,
|
||||
discover_databases,
|
||||
find_documents,
|
||||
get_collection_schema,
|
||||
# update_user_status,
|
||||
)
|
||||
|
||||
from .conftest import TEST_MONGODB_CONNECTION_STRING
|
||||
|
||||
|
||||
@pytest.fixture
def mock_context():
    """Provide a Context stand-in whose get_secret always returns the test connection string."""
    fake_context = MagicMock(spec=Context)
    fake_context.get_secret = MagicMock(return_value=TEST_MONGODB_CONNECTION_STRING)
    return fake_context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_discover_databases(mock_context) -> None:
    """discover_databases returns a list that excludes the admin, config and local databases."""
    found = await discover_databases(mock_context)
    assert isinstance(found, list)

    # System databases must be filtered out of the listing.
    system_names = ["admin", "config", "local"]
    assert all(name not in system_names for name in found)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_discover_collections(mock_context) -> None:
    """The test database exposes both the users and the messages collections."""
    found = await discover_collections(mock_context, "test_database")
    for expected_name in ("users", "messages"):
        assert expected_name in found
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_collection_schema(mock_context) -> None:
    """Sampling the users collection reports the sample size and every expected field."""
    outcome = await get_collection_schema(mock_context, "test_database", "users", sample_size=10)

    for top_level_key in ("schema", "total_documents_sampled"):
        assert top_level_key in outcome
    # The fixture data holds exactly 10 users.
    assert outcome["total_documents_sampled"] == 10

    fields = outcome["schema"]
    expected_fields = (
        "_id",
        "name",
        "email",
        "password_hash",
        "status",
        "created_at",
        "updated_at",
    )
    for field_name in expected_fields:
        assert field_name in fields
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_documents_basic(mock_context) -> None:
    """An unfiltered find returns all ten users, each carrying a name and an email."""
    raw_docs = await find_documents(
        mock_context,
        database_name="test_database",
        collection_name="users",
        limit=10,
    )

    assert len(raw_docs) == 10

    # Results arrive as JSON strings; decode them before inspecting fields.
    decoded = [json.loads(raw) for raw in raw_docs]
    for required_key in ("name", "email"):
        assert all(required_key in entry for entry in decoded)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_documents_with_filter(mock_context) -> None:
    """Filtering on status=active returns only active users."""
    raw_docs = await find_documents(
        mock_context,
        database_name="test_database",
        collection_name="users",
        filter_dict='{"status": "active"}',
        limit=10,
    )

    # Every user in the fixture dump is active, so all ten match.
    assert len(raw_docs) == 10
    statuses = [json.loads(raw)["status"] for raw in raw_docs]
    assert all(status == "active" for status in statuses)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_documents_with_projection(mock_context) -> None:
    """A name/email projection strips every other field, including _id."""
    raw_docs = await find_documents(
        mock_context,
        database_name="test_database",
        collection_name="users",
        projection='{"name": 1, "email": 1, "_id": 0}',
        limit=10,
    )

    assert len(raw_docs) == 10
    for raw in raw_docs:
        entry = json.loads(raw)
        # Projected fields are present ...
        assert "name" in entry
        assert "email" in entry
        # ... and excluded or unrequested fields are absent.
        assert "_id" not in entry
        assert "password_hash" not in entry
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_documents_with_sort(mock_context) -> None:
    """Sorting by _id descending with limit 3 returns ids 10, 9, 8 in that order."""
    raw_docs = await find_documents(
        mock_context,
        database_name="test_database",
        collection_name="users",
        sort=['{"field": "_id", "direction": -1}'],
        limit=3,
    )

    assert len(raw_docs) == 3
    returned_ids = [json.loads(raw)["_id"] for raw in raw_docs]
    assert returned_ids == [10, 9, 8]  # highest ids first
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_count_documents(mock_context) -> None:
    """Counting users yields 10 both without a filter and when filtering on active status."""
    # Unfiltered count over the whole collection.
    total = await count_documents(
        mock_context,
        database_name="test_database",
        collection_name="users",
    )
    assert total == 10

    # Count restricted to active users.
    active_total = await count_documents(
        mock_context,
        database_name="test_database",
        collection_name="users",
        filter_dict='{"status": "active"}',
    )
    assert active_total == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_aggregate_documents(mock_context) -> None:
    """Grouping users by status produces a single "active" bucket holding all ten users."""
    group_stage = '{"$group": {"_id": "$status", "count": {"$sum": 1}}}'
    sort_stage = '{"$sort": {"count": -1}}'

    buckets = await aggregate_documents(
        mock_context,
        database_name="test_database",
        collection_name="users",
        pipeline=[group_stage, sort_stage],
    )

    # Only active users exist, so there is exactly one group.
    assert len(buckets) == 1
    top_bucket = json.loads(buckets[0])
    assert top_bucket["_id"] == "active"
    assert top_bucket["count"] == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_documents_with_skip_and_limit(mock_context) -> None:
    """Two consecutive name-sorted pages of size 2 return Alice/Bob and then Charlie/Diana."""

    async def fetch_page(offset: int):
        # Same query each time; only the skip offset changes between pages.
        return await find_documents(
            mock_context,
            database_name="test_database",
            collection_name="users",
            sort=['{"field": "name", "direction": 1}'],
            limit=2,
            skip=offset,
        )

    first_page = await fetch_page(0)
    second_page = await fetch_page(2)

    assert len(first_page) == 2
    assert len(second_page) == 2

    first_names = [json.loads(raw)["name"] for raw in first_page]
    second_names = [json.loads(raw)["name"] for raw in second_page]

    assert first_names == ["Alice", "Bob"]
    assert second_names == ["Charlie", "Diana"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_error_handling_invalid_database(mock_context) -> None:
    """Listing collections of an unknown database returns an empty list instead of raising."""
    found = await discover_collections(mock_context, "nonexistent_database")
    assert found == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_error_handling_invalid_collection(mock_context) -> None:
    """Querying a collection that does not exist returns an empty result list."""
    raw_docs = await find_documents(
        mock_context,
        database_name="test_database",
        collection_name="nonexistent_collection",
        limit=10,
    )
    assert raw_docs == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sanitize_query_params() -> None:
    """sanitize_query_params rejects an empty database, an empty collection and an oversized limit."""
    # Each case: positional arguments for sanitize_query_params and the expected error fragment.
    rejected_inputs = [
        (("", "users", {}, None, None, 10, 0), "Database name is required"),
        (("test_db", "", {}, None, None, 10, 0), "Collection name is required"),
        (("test_db", "users", {}, None, None, 2000, 0), "Limit is too high"),  # limit too high
    ]

    for call_args, expected_fragment in rejected_inputs:
        with pytest.raises(RetryableToolError) as raised:
            DatabaseEngine.sanitize_query_params(*call_args)
        assert expected_fragment in str(raised.value)
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_update_user_status_success(mock_context) -> None:
|
||||
# """Test successful user status update."""
|
||||
# # First, find a user to update
|
||||
# users = await find_documents(
|
||||
# mock_context, database_name="test_database", collection_name="users", limit=1
|
||||
# )
|
||||
# assert len(users) > 0
|
||||
|
||||
# user_doc = json.loads(users[0])
|
||||
# user_id = user_doc["_id"]
|
||||
|
||||
# # Update user status to inactive
|
||||
# result = await update_user_status(
|
||||
# mock_context,
|
||||
# database_name="test_database",
|
||||
# collection_name="users",
|
||||
# user_id=user_id,
|
||||
# status=UserStatus.INACTIVE,
|
||||
# )
|
||||
|
||||
# assert result["success"] is True
|
||||
# assert result["user_id"] == user_id
|
||||
# assert result["new_status"] == "inactive"
|
||||
# assert result["matched_count"] == 1
|
||||
# assert result["modified_count"] == 1
|
||||
|
||||
# # Verify the update by finding the user again
|
||||
# # Convert user_id to int since the test data uses integer IDs
|
||||
# user_id_int = int(user_id)
|
||||
# updated_users = await find_documents(
|
||||
# mock_context,
|
||||
# database_name="test_database",
|
||||
# collection_name="users",
|
||||
# filter_dict=f'{{"_id": {user_id_int}}}',
|
||||
# limit=1,
|
||||
# )
|
||||
# assert len(updated_users) == 1
|
||||
# updated_user = json.loads(updated_users[0])
|
||||
# assert updated_user["status"] == "inactive"
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_update_user_status_user_not_found(mock_context) -> None:
|
||||
# """Test updating status for non-existent user."""
|
||||
# result = await update_user_status(
|
||||
# mock_context,
|
||||
# database_name="test_database",
|
||||
# collection_name="users",
|
||||
# user_id="nonexistent_user_id",
|
||||
# status=UserStatus.BANNED,
|
||||
# )
|
||||
|
||||
# assert result["success"] is False
|
||||
# assert "No user found with _id" in result["message"]
|
||||
# assert result["matched_count"] == 0
|
||||
# assert result["modified_count"] == 0
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
#!/bin/bash
# Installs mongosh from the MongoDB 6.0 apt repository (used to load sample
# data) and then starts a MongoDB server container published on port 27017.

# Stop at the first failing command, unset variable or failed pipeline stage so
# a broken install step is not hidden by the commands that follow it.
set -euo pipefail

# install mongosh to load sample data
sudo apt-get update
sudo apt-get install -y wget gnupg

# apt-key is deprecated: store the repository key in its own keyring and bind
# the source entry to it with signed-by instead of trusting it globally.
wget -qO - https://www.mongodb.org/static/pgp/server-6.0.asc | sudo gpg --batch --yes --dearmor -o /usr/share/keyrings/mongodb-server-6.0.gpg
echo "deb [ arch=amd64,arm64 signed-by=/usr/share/keyrings/mongodb-server-6.0.gpg ] https://repo.mongodb.org/apt/ubuntu jammy/mongodb-org/6.0 multiverse" | sudo tee /etc/apt/sources.list.d/mongodb-org-6.0.list
sudo apt-get update
sudo apt-get install -y mongodb-mongosh

# Run mongodb container
docker run -d --name some-mongodb-server -p 27017:27017 mongo
|
||||
|
|
@ -1,248 +0,0 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from arcade_mcp_server import Context
|
||||
from arcade_mcp_server.exceptions import RetryableToolError
|
||||
from arcade_mongodb.tools.mongodb import aggregate_documents, count_documents, find_documents
|
||||
|
||||
from .conftest import TEST_MONGODB_CONNECTION_STRING
|
||||
|
||||
|
||||
@pytest.fixture
def mock_context():
    """Build a mocked Context that hands out the test MongoDB connection string as its secret."""
    ctx = MagicMock(spec=Context)
    secret_lookup = MagicMock(return_value=TEST_MONGODB_CONNECTION_STRING)
    ctx.get_secret = secret_lookup
    return ctx
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_filter_dict_blocks_set_operation(mock_context) -> None:
    """A $set operator inside filter_dict is rejected with a descriptive error."""
    write_filter = '{"$set": {"status": "modified"}}'  # write operation

    with pytest.raises(RetryableToolError) as raised:
        await find_documents(
            mock_context,
            database_name="test_database",
            collection_name="users",
            filter_dict=write_filter,
            limit=1,
        )

    failure = raised.value
    assert "Write operation '$set' not allowed in filter_dict" in str(failure)
    assert "$set" in failure.developer_message
    assert "Only read operations are allowed" in failure.developer_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_filter_dict_blocks_update_operations(mock_context) -> None:
    """Each common update operator placed in filter_dict is rejected."""
    for operator in ("$inc", "$unset", "$push", "$pull", "$rename", "$currentDate"):
        with pytest.raises(RetryableToolError) as raised:
            await find_documents(
                mock_context,
                database_name="test_database",
                collection_name="users",
                filter_dict=f'{{"{operator}": {{"field": "value"}}}}',
                limit=1,
            )

        # The error must name the specific operator that was refused.
        expected = f"Write operation '{operator}' not allowed in filter_dict"
        assert expected in str(raised.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_projection_blocks_write_operations(mock_context) -> None:
    """A write operator smuggled into the projection argument is rejected."""
    tainted_projection = '{"$set": {"modified": true}, "name": 1}'  # write op in projection

    with pytest.raises(RetryableToolError) as raised:
        await find_documents(
            mock_context,
            database_name="test_database",
            collection_name="users",
            projection=tainted_projection,
            limit=1,
        )

    assert "Write operation '$set' not allowed in projection" in str(raised.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sort_blocks_write_operations(mock_context) -> None:
    """A write operator inside a sort entry is rejected and reported with its index."""
    tainted_sort = ['{"field": "name", "direction": 1, "$inc": {"counter": 1}}']  # write op in sort

    with pytest.raises(RetryableToolError) as raised:
        await find_documents(
            mock_context,
            database_name="test_database",
            collection_name="users",
            sort=tainted_sort,
            limit=1,
        )

    assert "Write operation '$inc' not allowed in sort[0]" in str(raised.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_count_filter_blocks_write_operations(mock_context) -> None:
    """count_documents refuses a filter that contains a write operator."""
    tainted_filter = '{"status": "active", "$unset": {"password": ""}}'  # write operation

    with pytest.raises(RetryableToolError) as raised:
        await count_documents(
            mock_context,
            database_name="test_database",
            collection_name="users",
            filter_dict=tainted_filter,
        )

    assert "Write operation '$unset' not allowed in filter_dict" in str(raised.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_aggregation_pipeline_blocks_out_stage(mock_context) -> None:
    """An aggregation pipeline containing a $out stage is rejected."""
    stages = [
        '{"$match": {"status": "active"}}',
        '{"$out": "output_collection"}',  # write stage
    ]

    with pytest.raises(RetryableToolError) as raised:
        await aggregate_documents(
            mock_context,
            database_name="test_database",
            collection_name="users",
            pipeline=stages,
        )

    assert "Write stage '$out' not allowed in pipeline" in str(raised.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_aggregation_pipeline_blocks_merge_stage(mock_context) -> None:
    """An aggregation pipeline containing a $merge stage is rejected."""
    stages = [
        '{"$match": {"status": "active"}}',
        '{"$merge": {"into": "target_collection"}}',  # write stage
    ]

    with pytest.raises(RetryableToolError) as raised:
        await aggregate_documents(
            mock_context,
            database_name="test_database",
            collection_name="users",
            pipeline=stages,
        )

    assert "Write stage '$merge' not allowed in pipeline" in str(raised.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_where_operator_blocked(mock_context) -> None:
    """The $where operator is refused because it would execute JavaScript."""
    js_filter = '{"$where": "this.name == \'admin\'"}'  # JavaScript execution

    with pytest.raises(RetryableToolError) as raised:
        await find_documents(
            mock_context,
            database_name="test_database",
            collection_name="users",
            filter_dict=js_filter,
            limit=1,
        )

    failure = raised.value
    assert "JavaScript execution operator '$where' not allowed in filter_dict" in str(failure)
    security_note = "JavaScript execution is not allowed for security reasons"
    assert security_note in failure.developer_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_nested_write_operations_blocked(mock_context) -> None:
    """A write operator nested below the top level of filter_dict is still detected."""
    nested_filter = '{"status": "active", "nested": {"$set": {"field": "value"}}}'  # nested write op

    with pytest.raises(RetryableToolError) as raised:
        await find_documents(
            mock_context,
            database_name="test_database",
            collection_name="users",
            filter_dict=nested_filter,
            limit=1,
        )

    failure = raised.value
    assert "Write operation '$set' not allowed in filter_dict" in str(failure)
    # The developer message should include the path to the offending key.
    assert "nested.$set" in failure.developer_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_valid_read_operations_allowed(mock_context) -> None:
    """Read-only query operators and pipeline stages must not trip the write-operation guard."""
    read_only_pipeline = [
        '{"$match": {"status": "active"}}',
        '{"$group": {"_id": "$status", "count": {"$sum": 1}}}',
        '{"$sort": {"count": -1}}',
    ]

    try:
        # Plain query that uses only read operators.
        found = await find_documents(
            mock_context,
            database_name="test_database",
            collection_name="users",
            filter_dict='{"status": {"$in": ["active", "inactive"]}, "name": {"$regex": "^A"}}',
            projection='{"name": 1, "email": 1, "_id": 0}',
            sort=['{"field": "name", "direction": 1}'],
            limit=1,
        )
        assert isinstance(found, list)

        # Aggregation built solely from read-only stages.
        aggregated = await aggregate_documents(
            mock_context,
            database_name="test_database",
            collection_name="users",
            pipeline=read_only_pipeline,
        )
        assert isinstance(aggregated, list)

    except RetryableToolError as err:
        # Other retryable failures are tolerated; write-guard failures are not.
        outer_message = str(err)
        cause_message = str(err.__cause__) if err.__cause__ else ""
        assert "Write operation" not in outer_message
        assert "Write stage" not in outer_message
        assert "Write operation" not in cause_message
        assert "Write stage" not in cause_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_array_write_operations_blocked(mock_context) -> None:
    """Every array-mutating operator placed in filter_dict is rejected."""
    for operator in ("$addToSet", "$pop", "$pull", "$push", "$pullAll"):
        with pytest.raises(RetryableToolError) as raised:
            await find_documents(
                mock_context,
                database_name="test_database",
                collection_name="users",
                filter_dict=f'{{"{operator}": {{"tags": "new_tag"}}}}',
                limit=1,
            )

        # The error must name the specific operator that was refused.
        expected = f"Write operation '{operator}' not allowed in filter_dict"
        assert expected in str(raised.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_aggregation_stage_content_validated(mock_context) -> None:
    """Write operators hidden inside an otherwise read-only stage are detected as well."""
    tainted_match = '{"$match": {"status": "active", "$set": {"modified": true}}}'  # write op in $match

    with pytest.raises(RetryableToolError) as raised:
        await aggregate_documents(
            mock_context,
            database_name="test_database",
            collection_name="users",
            pipeline=[tainted_match],
        )

    assert "Write operation '$set' not allowed in pipeline[0].$match" in str(raised.value)
|
||||
|
|
@ -1,53 +0,0 @@
|
|||
# Developer tasks for this package. Targets annotated with a trailing
# "## description" are picked up and listed by `make help`.

.PHONY: help
help:
	@echo "🛠️ github Commands:\n"
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'

.PHONY: install
install: ## Install the uv environment and install all packages with dependencies
	@echo "🚀 Creating virtual environment and installing all packages using uv"
	@uv sync --active --all-extras --no-sources
	@uv run pre-commit install
	@echo "✅ All packages and dependencies installed via uv"

.PHONY: install-local
install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources
	@echo "🚀 Creating virtual environment and installing all packages using uv"
	@uv sync --active --all-extras
	@uv run pre-commit install
	@echo "✅ All packages and dependencies installed via uv"

# The wheel is produced by `uv build`; the help text names uv accordingly.
.PHONY: build
build: clean-build ## Build wheel file using uv
	@echo "🚀 Creating wheel file"
	uv build

.PHONY: clean-build
clean-build: ## clean build artifacts
	@echo "🗑️ Cleaning dist directory"
	rm -rf dist

.PHONY: test
test: ## Test the code with pytest
	@echo "🚀 Testing code: Running pytest"
	@uv run pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml

.PHONY: coverage
coverage: ## Generate coverage report
	@echo "coverage report"
	coverage report
	@echo "Generating coverage report"
	coverage html

.PHONY: bump-version
bump-version: ## Bump the version in the pyproject.toml file by a patch version
	@echo "🚀 Bumping version in pyproject.toml"
	uv version --bump patch

.PHONY: check
check: ## Run code quality tools.
	@echo "🚀 Linting code: Running pre-commit"
	@uv run pre-commit run -a
	@echo "🚀 Static type checking: Running mypy"
	@uv run mypy --config-file=pyproject.toml
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
import sys
|
||||
from typing import cast
|
||||
|
||||
from arcade_mcp_server import MCPApp
|
||||
from arcade_mcp_server.mcp_app import TransportType
|
||||
|
||||
import arcade_postgres
|
||||
|
||||
# MCP application object for the PostgreSQL server; `main` below runs it.
app = MCPApp(
    name="PostgreSQL",
    instructions=(
        "Use this server when you need to interact with PostgreSQL "
        "to help users query, explore, and manage their PostgreSQL databases."
    ),
)

# Attach the tools provided by the arcade_postgres package.
app.add_tools_from_module(arcade_postgres)
|
||||
|
||||
|
||||
def main() -> None:
    """Run the app using optional positional CLI arguments: transport, host, port.

    Missing arguments fall back to "stdio", "127.0.0.1" and 8000 respectively.
    """
    cli_args = sys.argv[1:]
    transport = cli_args[0] if len(cli_args) >= 1 else "stdio"
    host = cli_args[1] if len(cli_args) >= 2 else "127.0.0.1"
    port = int(cli_args[2]) if len(cli_args) >= 3 else 8000

    app.run(transport=cast(TransportType, transport), host=host, port=port)


if __name__ == "__main__":
    main()
|
||||
|
|
@ -1,180 +0,0 @@
|
|||
from typing import Any, ClassVar
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from arcade_mcp_server.exceptions import RetryableToolError
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
||||
|
||||
MAX_ROWS_RETURNED = 1000
|
||||
TEST_QUERY = "SELECT 1"
|
||||
|
||||
|
||||
class DatabaseEngine:
    """Cache of async SQLAlchemy engines plus a builder for guarded SELECT statements.

    Engines are stored in the class-level ``_engines`` dict keyed by the
    normalized async connection URL, so callers using the same URL share one engine.
    """

    _instance: ClassVar[None] = None
    _engines: ClassVar[dict[str, AsyncEngine]] = {}

    @classmethod
    async def get_instance(cls, connection_string: str) -> AsyncEngine:
        """Return the cached engine for ``connection_string`` after verifying it connects.

        The URL is rebuilt from scheme, netloc and path only (query parameters are
        dropped) with ``postgresql`` swapped for ``postgresql+asyncpg``. A test query
        is run; on failure the engine is disposed and the test is retried once.

        Raises:
            RetryableToolError: if the second connection attempt also fails.
        """
        parsed_url = urlparse(connection_string)

        # TODO: something strange with sslmode= and friends
        # query_params = parse_qs(parsed_url.query)
        # query_params = {
        # k: v[0] for k, v in query_params.items()
        # } # assume one value allowed for each query param

        async_connection_string = f"{parsed_url.scheme.replace('postgresql', 'postgresql+asyncpg')}://{parsed_url.netloc}{parsed_url.path}"
        key = f"{async_connection_string}"
        if key not in cls._engines:
            cls._engines[key] = create_async_engine(async_connection_string)

        # try a simple query to see if the connection is valid
        try:
            async with cls._engines[key].connect() as connection:
                await connection.execute(text(TEST_QUERY))
            return cls._engines[key]
        except Exception:
            # Dispose the engine's connections before the single retry below.
            await cls._engines[key].dispose()

        # try again
        try:
            async with cls._engines[key].connect() as connection:
                await connection.execute(text(TEST_QUERY))
            return cls._engines[key]
        except Exception as e:
            raise RetryableToolError(
                f"Connection failed: {e}",
                developer_message="Connection to postgres failed.",
                additional_prompt_content="Check the connection string and try again.",
            ) from e

    @classmethod
    async def get_engine(cls, connection_string: str) -> Any:
        """Return an async context manager that yields the cached engine.

        Leaving the context does nothing: the engine stays cached for reuse.
        """
        engine = await cls.get_instance(connection_string)

        class ConnectionContextManager:
            def __init__(self, engine: AsyncEngine) -> None:
                self.engine = engine

            async def __aenter__(self) -> AsyncEngine:
                return self.engine

            async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
                # Connection cleanup is handled by the async context manager
                pass

        return ConnectionContextManager(engine)

    @classmethod
    async def cleanup(cls) -> None:
        """Clean up all cached engines. Call this when shutting down."""
        for engine in cls._engines.values():
            await engine.dispose()
        cls._engines.clear()

    @classmethod
    def clear_cache(cls) -> None:
        """Clear the engine cache without disposing engines. Use with caution."""
        cls._engines.clear()

    @staticmethod
    def _strip_two_word_keyword(clause: str | None, first: str, second: str) -> str | None:
        """Remove a leading ``first second`` keyword pair (case-insensitive) from ``clause``.

        Returns ``clause`` unchanged when it is empty/None or does not start with the pair.
        """
        if not clause:
            return clause
        words = clause.split(None, 2)
        if len(words) >= 2 and words[0].upper() == first and words[1].upper() == second:
            # Third element (if any) is everything after the keyword pair.
            return words[2] if len(words) == 3 else ""
        return clause

    @classmethod
    def sanitize_query(  # noqa: C901
        cls,
        select_clause: str,
        from_clause: str,
        limit: int,
        offset: int,
        join_clause: str | None,
        where_clause: str | None,
        having_clause: str | None,
        group_by_clause: str | None,
        order_by_clause: str | None,
        with_clause: str | None,
    ) -> tuple[str, dict[str, Any]]:
        """Assemble a SELECT statement from its clauses and validate basic constraints.

        A leading SQL keyword (SELECT, FROM, JOIN, WHERE, GROUP BY, ORDER BY, HAVING)
        is removed from each clause when present, so callers may pass clauses with or
        without it. Clause text is interpolated into the SQL; only ``limit`` and
        ``offset`` are bound parameters.

        Returns:
            A ``(query, parameters)`` tuple where ``parameters`` holds ``limit`` and ``offset``.

        Raises:
            RetryableToolError: if the select clause starts with a non-SELECT statement
                keyword, is exactly ``*``, or if ``limit``/``offset`` are out of range.
        """
        # Remove the leading keywords from the clauses if they are present
        if select_clause.strip().split(" ")[0].upper() == "SELECT":
            select_clause = select_clause.strip()[6:]

        if from_clause.strip().split(" ")[0].upper() == "FROM":
            from_clause = from_clause.strip()[4:]

        if join_clause and join_clause.strip().split(" ")[0].upper() == "JOIN":
            join_clause = join_clause.strip()[4:]

        if where_clause and where_clause.strip().split(" ")[0].upper() == "WHERE":
            where_clause = where_clause.strip()[5:]

        # Two-word keywords cannot be detected by comparing the first
        # space-separated token (it never contains a space), so they are
        # matched word by word instead.
        group_by_clause = cls._strip_two_word_keyword(group_by_clause, "GROUP", "BY")
        order_by_clause = cls._strip_two_word_keyword(order_by_clause, "ORDER", "BY")

        if having_clause and having_clause.strip().split(" ")[0].upper() == "HAVING":
            having_clause = having_clause.strip()[6:]

        first_select_word = select_clause.strip().split(" ")[0].upper()
        if first_select_word in [
            "INSERT",
            "UPDATE",
            "DELETE",
            "CREATE",
            "ALTER",
            "DROP",
            "TRUNCATE",
            "REINDEX",
            "VACUUM",
            "ANALYZE",
            "COMMENT",
        ]:
            raise RetryableToolError(
                "Only SELECT queries are allowed.",
            )

        if select_clause.strip() == "*":
            raise RetryableToolError(
                "Do not use * in the select clause. Use a comma separated list of columns you wish to return.",
            )

        if limit > MAX_ROWS_RETURNED:
            raise RetryableToolError(
                f"Limit is too high. Maximum is {MAX_ROWS_RETURNED}.",
            )

        if offset < 0:
            raise RetryableToolError(
                "Offset must be greater than or equal to 0.",
                developer_message="Offset must be greater than or equal to 0.",
            )

        if limit <= 0:
            raise RetryableToolError(
                "Limit must be greater than 0.",
                developer_message="Limit must be greater than 0.",
            )

        # Build query with identifiers directly interpolated, but use parameters for values
        parts = []
        if with_clause:
            parts.append(f"WITH {with_clause}")
        parts.append(f"SELECT {select_clause} FROM {from_clause}")  # noqa: S608
        if join_clause:
            parts.append(f"JOIN {join_clause}")
        if where_clause:
            parts.append(f"WHERE {where_clause}")
        if group_by_clause:
            parts.append(f"GROUP BY {group_by_clause}")
        if having_clause:
            parts.append(f"HAVING {having_clause}")
        if order_by_clause:
            parts.append(f"ORDER BY {order_by_clause}")
        parts.append("LIMIT :limit OFFSET :offset")
        query = " ".join(parts)

        # Only use parameters for values, not identifiers
        parameters = {
            "limit": limit,
            "offset": offset,
        }

        return query, parameters
|
||||
|
|
@ -1,300 +0,0 @@
|
|||
from typing import Annotated, Any
|
||||
|
||||
from arcade_mcp_server import Context, tool
|
||||
from arcade_mcp_server.exceptions import RetryableToolError
|
||||
from arcade_mcp_server.metadata import Behavior, Operation, ToolMetadata
|
||||
from sqlalchemy import inspect, text
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
|
||||
from ..database_engine import MAX_ROWS_RETURNED, DatabaseEngine
|
||||
|
||||
|
||||
@tool(
    requires_secrets=["POSTGRES_DATABASE_CONNECTION_STRING"],
    metadata=ToolMetadata(
        behavior=Behavior(
            operations=[Operation.READ],
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=True,
        ),
    ),
)
async def discover_schemas(
    context: Context,
) -> list[str]:
    """Discover all the schemas in the postgres database."""
    # The connection string is supplied through the tool's declared secret.
    connection_string = context.get_secret("POSTGRES_DATABASE_CONNECTION_STRING")
    async with await DatabaseEngine.get_engine(connection_string) as db_engine:
        return await _get_schemas(db_engine)
|
||||
|
||||
|
||||
@tool(
    requires_secrets=["POSTGRES_DATABASE_CONNECTION_STRING"],
    metadata=ToolMetadata(
        behavior=Behavior(
            operations=[Operation.READ],
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=True,
        ),
    ),
)
async def discover_tables(
    context: Context,
    schema_name: Annotated[
        str, "The database schema to discover tables in (default value: 'public')"
    ] = "public",
) -> list[str]:
    """Discover all the tables in the postgres database when the list of tables is not known.

    ALWAYS use this tool before any other tool that requires a table name.
    """
    # The connection string is supplied through the tool's declared secret.
    connection_string = context.get_secret("POSTGRES_DATABASE_CONNECTION_STRING")
    async with await DatabaseEngine.get_engine(connection_string) as db_engine:
        return await _get_tables(db_engine, schema_name)
|
||||
|
||||
|
||||
@tool(
    requires_secrets=["POSTGRES_DATABASE_CONNECTION_STRING"],
    metadata=ToolMetadata(
        behavior=Behavior(
            operations=[Operation.READ],
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=True,
        ),
    ),
)
async def get_table_schema(
    context: Context,
    schema_name: Annotated[str, "The database schema to get the table schema of"],
    table_name: Annotated[str, "The table to get the schema of"],
) -> list[str]:
    """
    Get the schema/structure of a postgres table in the postgres database when the schema is not known, and the name of the table is provided.

    This tool should ALWAYS be used before executing any query. All tables in the query must be discovered first using the <DiscoverTables> tool.
    """
    connection_string = context.get_secret("POSTGRES_DATABASE_CONNECTION_STRING")
    # get_engine is awaited first; its result is then entered as an async context manager.
    async with await DatabaseEngine.get_engine(connection_string) as engine:
        column_descriptions = await _get_table_schema(engine, schema_name, table_name)
        return column_descriptions
|
||||
|
||||
|
||||
@tool(
    requires_secrets=["POSTGRES_DATABASE_CONNECTION_STRING"],
    metadata=ToolMetadata(
        behavior=Behavior(
            operations=[Operation.READ],
            read_only=True,
            destructive=False,
            idempotent=True,
            open_world=True,
        ),
    ),
)
async def execute_select_query(
    context: Context,
    select_clause: Annotated[
        str,
        "This is the part of the SQL query that comes after the SELECT keyword with a comma separated list of columns you wish to return. Do not include the SELECT keyword.",
    ],
    from_clause: Annotated[
        str,
        "This is the part of the SQL query that comes after the FROM keyword. Do not include the FROM keyword.",
    ],
    limit: Annotated[
        int,
        "The maximum number of rows to return. This is the LIMIT clause of the query. Default: 100.",
    ] = 100,
    offset: Annotated[
        int, "The number of rows to skip. This is the OFFSET clause of the query. Default: 0."
    ] = 0,
    join_clause: Annotated[
        str | None,
        "This is the part of the SQL query that comes after the JOIN keyword. Do not include the JOIN keyword. If no join is needed, leave this blank.",
    ] = None,
    where_clause: Annotated[
        str | None,
        "This is the part of the SQL query that comes after the WHERE keyword. Do not include the WHERE keyword. If no where clause is needed, leave this blank.",
    ] = None,
    having_clause: Annotated[
        str | None,
        "This is the part of the SQL query that comes after the HAVING keyword. Do not include the HAVING keyword. If no having clause is needed, leave this blank.",
    ] = None,
    group_by_clause: Annotated[
        str | None,
        "This is the part of the SQL query that comes after the GROUP BY keyword. Do not include the GROUP BY keyword. If no group by clause is needed, leave this blank.",
    ] = None,
    order_by_clause: Annotated[
        str | None,
        "This is the part of the SQL query that comes after the ORDER BY keyword. Do not include the ORDER BY keyword. If no order by clause is needed, leave this blank.",
    ] = None,
    with_clause: Annotated[
        str | None,
        "This is the part of the SQL query that comes after the WITH keyword when basing the query on a virtual table. If no WITH clause is needed, leave this blank.",
    ] = None,
) -> list[str]:
    """
    You have a connection to a postgres database.
    Execute a SELECT query and return the results against the postgres database. No other queries (INSERT, UPDATE, DELETE, etc.) are allowed.

    ONLY use this tool if you have already loaded the schema of the tables you need to query. Use the <GetTableSchema> tool to load the schema if not already known.

    The final query will be constructed as follows:
    SELECT {select_clause} FROM {from_clause} JOIN {join_clause} WHERE {where_clause} GROUP BY {group_by_clause} HAVING {having_clause} ORDER BY {order_by_clause} LIMIT {limit} OFFSET {offset}

    When running queries, follow these rules which will help avoid errors:
    * Never "select *" from a table. Always select the columns you need.
    * Always order your results by the most important columns first. If you aren't sure, order by the primary key.
    * Always use case-insensitive queries to match strings in the query.
    * Always trim strings in the query.
    * Prefer LIKE queries over direct string matches or regex queries.
    * Only join on columns that are indexed or the primary key. Do not join on arbitrary columns.
    """
    async with await DatabaseEngine.get_engine(
        context.get_secret("POSTGRES_DATABASE_CONNECTION_STRING")
    ) as engine:
        try:
            return await _execute_query(
                engine,
                select_clause=select_clause,
                from_clause=from_clause,
                limit=limit,
                offset=offset,
                join_clause=join_clause,
                where_clause=where_clause,
                having_clause=having_clause,
                group_by_clause=group_by_clause,
                order_by_clause=order_by_clause,
                with_clause=with_clause,
            )
        except Exception as e:
            # Any failure is surfaced as retryable so the caller can fix the clauses and
            # try again. The developer message echoes every clause that was supplied,
            # including group_by_clause, so a failed call can be reconstructed.
            raise RetryableToolError(
                f"Query failed: {e}",
                developer_message=f"Query failed with parameters: select_clause={select_clause}, from_clause={from_clause}, limit={limit}, offset={offset}, join_clause={join_clause}, where_clause={where_clause}, having_clause={having_clause}, group_by_clause={group_by_clause}, order_by_clause={order_by_clause}, with_clause={with_clause}.",
                additional_prompt_content="Load the database schema <GetTableSchema> or use the <DiscoverTables> tool to discover the tables and try again.",
                retry_after_ms=10,
            ) from e
|
||||
|
||||
|
||||
async def _get_schemas(engine: AsyncEngine) -> list[str]:
    """Return the schema names reported by the inspector, without ``information_schema``."""
    async with engine.connect() as conn:
        # Inspection is a synchronous API, so it is driven through run_sync.
        discovered: list[str] = await conn.run_sync(
            lambda sync_conn: list(inspect(sync_conn).get_schema_names())
        )

    return [name for name in discovered if name != "information_schema"]
|
||||
|
||||
|
||||
async def _get_tables(engine: AsyncEngine, schema_name: str) -> list[str]:
    """Return the sorted table names of ``schema_name``.

    The result is empty when the inspector does not report ``schema_name`` among the
    database's schemas.
    """
    async with engine.connect() as conn:
        # Inspection is a synchronous API, so it is driven through run_sync.
        all_schemas: list[str] = await conn.run_sync(
            lambda sync_conn: list(inspect(sync_conn).get_schema_names())
        )

        found = []
        for match in (name for name in all_schemas if name == schema_name):
            # Bind the schema as a default argument so the callable does not depend on
            # the loop variable after this iteration.
            def list_tables(sync_conn: Any, s: str = match) -> list[str]:
                return list(inspect(sync_conn).get_table_names(schema=s))

            found.extend(await conn.run_sync(list_tables))

    found.sort()
    return found
|
||||
|
||||
|
||||
async def _get_table_schema(engine: AsyncEngine, schema_name: str, table_name: str) -> list[str]:
    """Describe the columns of ``schema_name.table_name``.

    Each entry has the form ``"<column>: <type>"``, followed by ``" (PRIMARY KEY)"``
    and/or ``" (INDEXED)"`` when the column is part of the primary key or of any index.
    ``<type>`` is the name of the Python type SQLAlchemy maps the column to. When
    SQLAlchemy has no Python equivalent for the column type, the name of the SQL type
    class is used instead of failing the whole call.

    Returns:
        At most ``MAX_ROWS_RETURNED`` column descriptions, in the inspector's column order.
    """

    def reflect(sync_conn: Any) -> tuple[list[Any], Any, list[Any]]:
        # One inspector serves all three lookups, so reflection needs a single run_sync hop.
        inspector = inspect(sync_conn)
        return (
            list(inspector.get_columns(table_name, schema_name)),
            inspector.get_pk_constraint(table_name, schema_name),
            inspector.get_indexes(table_name, schema_name),
        )

    async with engine.connect() as connection:
        columns_table, pk_constraint, indexes = await connection.run_sync(reflect)

    primary_keys = set(pk_constraint.get("constrained_columns", []))

    indexed_columns = set()
    for index in indexes:
        indexed_columns.update(index.get("column_names", []))

    results = []
    for column in columns_table:
        column_name = column["name"]
        try:
            column_type = column["type"].python_type.__name__
        except NotImplementedError:
            # python_type is not implemented for every SQL type; fall back to the
            # SQL type's class name so one exotic column does not break the listing.
            column_type = type(column["type"]).__name__

        description = f"{column_name}: {column_type}"
        if column_name in primary_keys:
            description += " (PRIMARY KEY)"
        if column_name in indexed_columns:
            description += " (INDEXED)"

        results.append(description)

    return results[:MAX_ROWS_RETURNED]
|
||||
|
||||
|
||||
async def _execute_query(
    engine: AsyncEngine,
    select_clause: str,
    from_clause: str,
    limit: int,
    offset: int,
    join_clause: str | None,
    where_clause: str | None,
    having_clause: str | None,
    group_by_clause: str | None,
    order_by_clause: str | None,
    with_clause: str | None,
) -> list[str]:
    """Build the SELECT statement from its clauses, run it, and return the rows as strings.

    The clauses are passed to ``DatabaseEngine.sanitize_query``, which returns the SQL
    text together with its bound parameters. Each returned row is rendered with ``str``.

    Returns:
        At most ``MAX_ROWS_RETURNED`` stringified rows.
    """
    # Function-scope import keeps this block self-contained.
    import logging

    logger = logging.getLogger(__name__)

    async with engine.connect() as connection:
        query, parameters = DatabaseEngine.sanitize_query(
            select_clause=select_clause,
            from_clause=from_clause,
            limit=limit,
            offset=offset,
            join_clause=join_clause,
            where_clause=where_clause,
            having_clause=having_clause,
            group_by_clause=group_by_clause,
            order_by_clause=order_by_clause,
            with_clause=with_clause,
        )
        # Debug-level logging instead of print: the query text no longer goes to stdout
        # on every call, but is still available when debug logging is enabled.
        logger.debug("Query: %s", query)
        logger.debug("Parameters: %s", parameters)
        result = await connection.execute(text(query), parameters)
        rows = result.fetchall()

    # Truncate before stringifying so rows that are dropped are never rendered.
    return [str(row) for row in rows[:MAX_ROWS_RETURNED]]
|
||||
|
|
@ -1,94 +0,0 @@
|
|||
import arcade_postgres
from arcade_core import ToolCatalog
from arcade_evals import (
    BinaryCritic,
    EvalRubric,
    EvalSuite,
    ExpectedToolCall,
    SimilarityCritic,
    tool_eval,
)
from arcade_postgres.tools.postgres import (
    discover_tables,
    execute_select_query,
    get_table_schema,
)

# Evaluation rubric
rubric = EvalRubric(
    fail_threshold=0.85,
    warn_threshold=0.95,
)


catalog = ToolCatalog()
catalog.add_module(arcade_postgres)


@tool_eval()
def sql_eval_suite() -> EvalSuite:
    """Build the evaluation suite for the postgres tools.

    The cases cover running a select when the table schema is already known, listing
    tables, and fetching a table's schema with and without an explicit schema name.
    """
    suite = EvalSuite(
        name="sql Tools Evaluation",
        system_message=(
            "You are an AI assistant with access to sql tools. "
            "Use them to help the user with their tasks."
        ),
        catalog=catalog,
        rubric=rubric,
    )

    # The select tool takes the query as separate clauses, so the expectation and the
    # critics are expressed per clause rather than as one SQL string.
    suite.add_case(
        name="Get user by id (schema known)",
        user_message="Tell me the name and email of user #1 in my database. The table 'users' has the following schema: id: int, name: str, email: str, password_hash: str, created_at: datetime, updated_at: datetime",
        expected_tool_calls=[
            ExpectedToolCall(
                func=execute_select_query,
                args={
                    "select_clause": "name, email",
                    "from_clause": "users",
                    "where_clause": "id = 1",
                },
            )
        ],
        rubric=rubric,
        critics=[
            SimilarityCritic(critic_field="select_clause", weight=0.3),
            SimilarityCritic(critic_field="from_clause", weight=0.3),
            SimilarityCritic(critic_field="where_clause", weight=0.4),
        ],
    )

    suite.add_case(
        name="Discover tables",
        user_message="What tables are in my database?",
        expected_tool_calls=[
            ExpectedToolCall(func=discover_tables, args={}),
        ],
        rubric=rubric,
    )

    suite.add_case(
        name="Get table schema (schema provided)",
        user_message="What columns are in the table 'public.users' in my database?",
        expected_tool_calls=[
            ExpectedToolCall(
                func=get_table_schema, args={"schema_name": "public", "table_name": "users"}
            ),
        ],
        rubric=rubric,
        critics=[
            BinaryCritic(critic_field="schema_name", weight=0.5),
            BinaryCritic(critic_field="table_name", weight=0.5),
        ],
    )

    suite.add_case(
        name="Get table schema (schema not provided)",
        user_message="What columns are in the table 'users' in my database?",
        additional_messages=[
            {"role": "user", "content": "When not provided, the schema is 'public'."}
        ],
        expected_tool_calls=[
            ExpectedToolCall(
                func=get_table_schema, args={"schema_name": "public", "table_name": "users"}
            ),
        ],
        rubric=rubric,
        critics=[
            BinaryCritic(critic_field="schema_name", weight=0.5),
            BinaryCritic(critic_field="table_name", weight=0.5),
        ],
    )

    return suite
|
||||
|
|
@ -1,65 +0,0 @@
|
|||
[build-system]
requires = [ "hatchling",]
build-backend = "hatchling.build"

[project]
name = "arcade_postgres"
version = "0.5.0"
description = "Tools to query and explore a postgres database"
requires-python = ">=3.10"
# Each requirement is listed once (psycopg2-binary used to appear twice).
dependencies = [
    "arcade-mcp-server>=1.17.0,<2.0.0",
    "psycopg2-binary>=2.9.10",
    "pydantic>=2.11.7",
    "sqlalchemy>=2.0.41",
    "asyncpg>=0.30.0",
    "greenlet>=3.2.3",
]
[[project.authors]]
name = "evantahler"
email = "support@arcade.dev"

[project.optional-dependencies]
dev = [
    "arcade-mcp[all]>=1.2.0,<2.0.0",
    "pytest>=8.3.0,<8.4.0",
    "pytest-cov>=4.0.0,<4.1.0",
    "pytest-mock>=3.11.1,<3.12.0",
    "pytest-asyncio>=0.24.0,<0.25.0",
    "mypy>=1.5.1,<1.6.0",
    "pre-commit>=3.4.0,<3.5.0",
    "tox>=4.11.1,<4.12.0",
    "ruff>=0.7.4,<0.8.0",
]

# Both spellings of the command point at the same entry point.
[project.scripts]
arcade-postgres = "arcade_postgres.__main__:main"
arcade_postgres = "arcade_postgres.__main__:main"

# Use local path sources for arcade libs when working locally
[tool.uv.sources]
arcade-mcp = { path = "../../", editable = true }
arcade-mcp-server = { path = "../../libs/arcade-mcp-server/", editable = true }

# Flags are written as TOML booleans rather than "True" strings.
[tool.mypy]
files = [ "arcade_postgres/**/*.py",]
python_version = "3.10"
disallow_untyped_defs = true
disallow_any_unimported = true
no_implicit_optional = true
check_untyped_defs = true
warn_return_any = true
warn_unused_ignores = true
show_error_codes = true
ignore_missing_imports = true

[tool.pytest.ini_options]
testpaths = [ "tests",]
asyncio_default_fixture_loop_scope = "function"

[tool.coverage.report]
skip_empty = true

[tool.hatch.build.targets.wheel]
packages = [ "arcade_postgres",]
|
||||
|
|
@ -1,399 +0,0 @@
|
|||
DROP TABLE IF EXISTS "public"."messages";
|
||||
-- This script only contains the table creation statements and does not fully represent the table in the database. Do not use it as a backup.
|
||||
-- Sequence and defined type
|
||||
CREATE SEQUENCE IF NOT EXISTS messages_id_seq;
|
||||
-- Table Definition
|
||||
CREATE TABLE "public"."messages" (
|
||||
"id" int4 NOT NULL DEFAULT nextval('messages_id_seq'::regclass),
|
||||
"body" text NOT NULL,
|
||||
"user_id" int4 NOT NULL,
|
||||
"created_at" timestamp NOT NULL DEFAULT now(),
|
||||
"updated_at" timestamp NOT NULL DEFAULT now(),
|
||||
PRIMARY KEY ("id")
|
||||
);
|
||||
DROP TABLE IF EXISTS "public"."users";
|
||||
-- This script only contains the table creation statements and does not fully represent the table in the database. Do not use it as a backup.
|
||||
-- Sequence and defined type
|
||||
CREATE SEQUENCE IF NOT EXISTS users_id_seq;
|
||||
-- Table Definition
|
||||
CREATE TABLE "public"."users" (
|
||||
"id" int4 NOT NULL DEFAULT nextval('users_id_seq'::regclass),
|
||||
"name" varchar(256) NOT NULL,
|
||||
"email" text NOT NULL,
|
||||
"password_hash" text NOT NULL,
|
||||
"created_at" timestamp NOT NULL DEFAULT now(),
|
||||
"updated_at" timestamp NOT NULL DEFAULT now(),
|
||||
"status" varchar,
|
||||
PRIMARY KEY ("id")
|
||||
);
|
||||
INSERT INTO "public"."messages" (
|
||||
"id",
|
||||
"body",
|
||||
"user_id",
|
||||
"created_at",
|
||||
"updated_at"
|
||||
)
|
||||
VALUES -- User 1 (Alice) - 3 messages
|
||||
(
|
||||
1,
|
||||
'Hello everyone!',
|
||||
1,
|
||||
'2025-01-10 10:00:00.000000',
|
||||
'2025-01-10 10:00:00.000000'
|
||||
),
|
||||
(
|
||||
2,
|
||||
'How is everyone doing today?',
|
||||
1,
|
||||
'2025-01-10 11:30:00.000000',
|
||||
'2025-01-10 11:30:00.000000'
|
||||
),
|
||||
(
|
||||
3,
|
||||
'Great to see you all here!',
|
||||
1,
|
||||
'2025-01-10 14:15:00.000000',
|
||||
'2025-01-10 14:15:00.000000'
|
||||
),
|
||||
-- User 2 (Bob) - 2 messages
|
||||
(
|
||||
4,
|
||||
'Hi Alice! Doing well, thanks for asking.',
|
||||
2,
|
||||
'2025-01-10 11:35:00.000000',
|
||||
'2025-01-10 11:35:00.000000'
|
||||
),
|
||||
(
|
||||
5,
|
||||
'Anyone up for a game later?',
|
||||
2,
|
||||
'2025-01-10 16:20:00.000000',
|
||||
'2025-01-10 16:20:00.000000'
|
||||
),
|
||||
-- User 3 (Charlie) - 3 messages
|
||||
(
|
||||
6,
|
||||
'Count me in for the game!',
|
||||
3,
|
||||
'2025-01-10 16:25:00.000000',
|
||||
'2025-01-10 16:25:00.000000'
|
||||
),
|
||||
(
|
||||
7,
|
||||
'What time works for everyone?',
|
||||
3,
|
||||
'2025-01-10 16:30:00.000000',
|
||||
'2025-01-10 16:30:00.000000'
|
||||
),
|
||||
(
|
||||
8,
|
||||
'I can play around 8 PM',
|
||||
3,
|
||||
'2025-01-10 17:00:00.000000',
|
||||
'2025-01-10 17:00:00.000000'
|
||||
),
|
||||
-- User 4 (Diana) - 2 messages
|
||||
(
|
||||
9,
|
||||
'8 PM works for me too!',
|
||||
4,
|
||||
'2025-01-10 17:05:00.000000',
|
||||
'2025-01-10 17:05:00.000000'
|
||||
),
|
||||
(
|
||||
10,
|
||||
'What game should we play?',
|
||||
4,
|
||||
'2025-01-10 17:10:00.000000',
|
||||
'2025-01-10 17:10:00.000000'
|
||||
),
|
||||
-- User 5 (Evan) - 3 messages
|
||||
(
|
||||
11,
|
||||
'I suggest we try the new arcade game!',
|
||||
5,
|
||||
'2025-01-10 17:15:00.000000',
|
||||
'2025-01-10 17:15:00.000000'
|
||||
),
|
||||
(
|
||||
12,
|
||||
'It has great multiplayer features',
|
||||
5,
|
||||
'2025-01-10 17:20:00.000000',
|
||||
'2025-01-10 17:20:00.000000'
|
||||
),
|
||||
(
|
||||
13,
|
||||
'Perfect timing for a weekend session',
|
||||
5,
|
||||
'2025-01-10 18:00:00.000000',
|
||||
'2025-01-10 18:00:00.000000'
|
||||
),
|
||||
-- User 6 (Fiona) - 2 messages
|
||||
(
|
||||
14,
|
||||
'Sounds like fun! I love arcade games.',
|
||||
6,
|
||||
'2025-01-10 18:05:00.000000',
|
||||
'2025-01-10 18:05:00.000000'
|
||||
),
|
||||
(
|
||||
15,
|
||||
'Should I bring snacks?',
|
||||
6,
|
||||
'2025-01-10 18:10:00.000000',
|
||||
'2025-01-10 18:10:00.000000'
|
||||
),
|
||||
-- User 7 (George) - 3 messages
|
||||
(
|
||||
16,
|
||||
'Snacks are always welcome!',
|
||||
7,
|
||||
'2025-01-10 18:15:00.000000',
|
||||
'2025-01-10 18:15:00.000000'
|
||||
),
|
||||
(
|
||||
17,
|
||||
'I can bring some drinks',
|
||||
7,
|
||||
'2025-01-10 18:20:00.000000',
|
||||
'2025-01-10 18:20:00.000000'
|
||||
),
|
||||
(
|
||||
18,
|
||||
'This is going to be awesome',
|
||||
7,
|
||||
'2025-01-10 19:00:00.000000',
|
||||
'2025-01-10 19:00:00.000000'
|
||||
),
|
||||
-- User 8 (Helen) - 2 messages
|
||||
(
|
||||
19,
|
||||
'I agree! Cannot wait for the game night.',
|
||||
8,
|
||||
'2025-01-10 19:05:00.000000',
|
||||
'2025-01-10 19:05:00.000000'
|
||||
),
|
||||
(
|
||||
20,
|
||||
'Should we set up a Discord call?',
|
||||
8,
|
||||
'2025-01-10 19:10:00.000000',
|
||||
'2025-01-10 19:10:00.000000'
|
||||
),
|
||||
-- User 9 (Ian) - 3 messages
|
||||
(
|
||||
21,
|
||||
'Discord would be perfect for voice chat',
|
||||
9,
|
||||
'2025-01-10 19:15:00.000000',
|
||||
'2025-01-10 19:15:00.000000'
|
||||
),
|
||||
(
|
||||
22,
|
||||
'I will create a server for us',
|
||||
9,
|
||||
'2025-01-10 19:20:00.000000',
|
||||
'2025-01-10 19:20:00.000000'
|
||||
),
|
||||
(
|
||||
23,
|
||||
'Link will be shared in a few minutes',
|
||||
9,
|
||||
'2025-01-10 19:25:00.000000',
|
||||
'2025-01-10 19:25:00.000000'
|
||||
),
|
||||
-- User 10 (Julia) - 2 messages
|
||||
(
|
||||
24,
|
||||
'Thanks Ian! You are the best.',
|
||||
10,
|
||||
'2025-01-10 19:30:00.000000',
|
||||
'2025-01-10 19:30:00.000000'
|
||||
),
|
||||
(
|
||||
25,
|
||||
'See you all at 8 PM!',
|
||||
10,
|
||||
'2025-01-10 19:35:00.000000',
|
||||
'2025-01-10 19:35:00.000000'
|
||||
),
|
||||
-- Additional messages for Evan (user_id 5) - 10 more messages
|
||||
(
|
||||
26,
|
||||
'Just finished setting up the game server!',
|
||||
5,
|
||||
'2025-01-10 20:00:00.000000',
|
||||
'2025-01-10 20:00:00.000000'
|
||||
),
|
||||
(
|
||||
27,
|
||||
'Everyone should be able to connect now',
|
||||
5,
|
||||
'2025-01-10 20:05:00.000000',
|
||||
'2025-01-10 20:05:00.000000'
|
||||
),
|
||||
(
|
||||
28,
|
||||
'I added some custom maps too',
|
||||
5,
|
||||
'2025-01-10 20:10:00.000000',
|
||||
'2025-01-10 20:10:00.000000'
|
||||
),
|
||||
(
|
||||
29,
|
||||
'The graphics look amazing on this new version',
|
||||
5,
|
||||
'2025-01-10 20:15:00.000000',
|
||||
'2025-01-10 20:15:00.000000'
|
||||
),
|
||||
(
|
||||
30,
|
||||
'Hope you all enjoy the new features',
|
||||
5,
|
||||
'2025-01-10 20:20:00.000000',
|
||||
'2025-01-10 20:20:00.000000'
|
||||
),
|
||||
(
|
||||
31,
|
||||
'I also set up a leaderboard system',
|
||||
5,
|
||||
'2025-01-10 20:25:00.000000',
|
||||
'2025-01-10 20:25:00.000000'
|
||||
),
|
||||
(
|
||||
32,
|
||||
'We can track high scores now',
|
||||
5,
|
||||
'2025-01-10 20:30:00.000000',
|
||||
'2025-01-10 20:30:00.000000'
|
||||
),
|
||||
(
|
||||
33,
|
||||
'The game supports up to 8 players simultaneously',
|
||||
5,
|
||||
'2025-01-10 20:35:00.000000',
|
||||
'2025-01-10 20:35:00.000000'
|
||||
),
|
||||
(
|
||||
34,
|
||||
'I tested it earlier and it runs smoothly',
|
||||
5,
|
||||
'2025-01-10 20:40:00.000000',
|
||||
'2025-01-10 20:40:00.000000'
|
||||
),
|
||||
(
|
||||
35,
|
||||
'Cannot wait to see everyone online tonight!',
|
||||
5,
|
||||
'2025-01-10 20:45:00.000000',
|
||||
'2025-01-10 20:45:00.000000'
|
||||
);
|
||||
INSERT INTO "public"."users" (
|
||||
"id",
|
||||
"name",
|
||||
"email",
|
||||
"password_hash",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"status"
|
||||
)
|
||||
VALUES (
|
||||
1,
|
||||
'Alice',
|
||||
'alice@example.com',
|
||||
'$argon2id$v=19$m=65536,t=2,p=1$tMg1Rd3IEDnp3iFKrqsF4Dsbw6/Cbf6seRB/H5bhaPg$zZj5yn4x3D3O3mDHcW2aczQNiYfAs3cw21XMEIgkF0E',
|
||||
'2024-09-01 20:49:38.759432',
|
||||
'2024-09-02 03:49:39.927',
|
||||
'active'
|
||||
),
|
||||
(
|
||||
2,
|
||||
'Bob',
|
||||
'bob@example.com',
|
||||
'$argon2id$v=19$m=65536,t=2,p=1$CvOMK1WUd99R7kYXpiBPNYw4OQP53pYIgeMnwz92mrE$HPthId4phMoPT1TWuCRHHCr9BSQA8XoUkQuB1HZsqTY',
|
||||
'2024-09-02 17:49:23.377425',
|
||||
'2024-09-02 17:49:23.377425',
|
||||
'active'
|
||||
),
|
||||
(
|
||||
3,
|
||||
'Charlie',
|
||||
'charlie@example.com',
|
||||
'$argon2id$v=19$m=65536,t=2,p=1$paCAAD1HVZkncP/WvecuUO6zFXp2/8BISpgr5rXRxps$M5kBFc9JHHGNw9SXnPu2ggpJY0mFFCska7TXMrllndo',
|
||||
'2024-09-03 10:30:15.123456',
|
||||
'2024-09-03 10:30:15.123456',
|
||||
'active'
|
||||
),
|
||||
(
|
||||
4,
|
||||
'Diana',
|
||||
'diana@example.com',
|
||||
'$argon2id$v=19$m=65536,t=2,p=1$xyz123ABC456DEF789GHI$SampleHashForDiana123',
|
||||
'2024-09-04 14:20:30.654321',
|
||||
'2024-09-04 14:20:30.654321',
|
||||
'active'
|
||||
),
|
||||
(
|
||||
5,
|
||||
'Evan',
|
||||
'evan@example.com',
|
||||
'$argon2id$v=19$m=65536,t=2,p=1$evanHash123$EvanPasswordHash456',
|
||||
'2024-09-05 09:15:45.987654',
|
||||
'2024-09-05 09:15:45.987654',
|
||||
'active'
|
||||
),
|
||||
(
|
||||
6,
|
||||
'Fiona',
|
||||
'fiona@example.com',
|
||||
'$argon2id$v=19$m=65536,t=2,p=1$fionaHash456$FionaPasswordHash789',
|
||||
'2024-09-06 16:45:12.345678',
|
||||
'2024-09-06 16:45:12.345678',
|
||||
'active'
|
||||
),
|
||||
(
|
||||
7,
|
||||
'George',
|
||||
'george@example.com',
|
||||
'$argon2id$v=19$m=65536,t=2,p=1$georgeHash789$GeorgePasswordHash012',
|
||||
'2024-09-07 11:30:25.876543',
|
||||
'2024-09-07 11:30:25.876543',
|
||||
'active'
|
||||
),
|
||||
(
|
||||
8,
|
||||
'Helen',
|
||||
'helen@example.com',
|
||||
'$argon2id$v=19$m=65536,t=2,p=1$helenHash012$HelenPasswordHash345',
|
||||
'2024-09-08 13:25:40.234567',
|
||||
'2024-09-08 13:25:40.234567',
|
||||
'active'
|
||||
),
|
||||
(
|
||||
9,
|
||||
'Ian',
|
||||
'ian@example.com',
|
||||
'$argon2id$v=19$m=65536,t=2,p=1$ianHash345$IanPasswordHash678',
|
||||
'2024-09-09 08:40:55.765432',
|
||||
'2024-09-09 08:40:55.765432',
|
||||
'active'
|
||||
),
|
||||
(
|
||||
10,
|
||||
'Julia',
|
||||
'julia@example.com',
|
||||
'$argon2id$v=19$m=65536,t=2,p=1$juliaHash678$JuliaPasswordHash901',
|
||||
'2024-09-10 15:55:18.123456',
|
||||
'2024-09-10 15:55:18.123456',
|
||||
'active'
|
||||
);
|
||||
ALTER TABLE "public"."messages"
ADD FOREIGN KEY ("user_id") REFERENCES "public"."users"("id");
-- The seed rows above use explicit ids (users 1 to 10, messages 1 to 35).
-- Move both sequences past them so the next default id does not collide.
-- Keep comments in this file free of semicolons because the test loader splits on them.
ALTER SEQUENCE users_id_seq RESTART WITH 11;
ALTER SEQUENCE messages_id_seq RESTART WITH 36;
-- Indices
CREATE UNIQUE INDEX name_idx ON public.users USING btree (name);
CREATE UNIQUE INDEX email_idx ON public.users USING btree (email);
-- NOTE(review): users_email_unique covers the same column as email_idx, confirm both are wanted
DROP INDEX IF EXISTS users_email_unique;
CREATE UNIQUE INDEX users_email_unique ON public.users USING btree (email);
|
||||
|
|
@ -1,188 +0,0 @@
|
|||
import os
|
||||
from os import environ
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from arcade_mcp_server import Context
|
||||
from arcade_mcp_server.exceptions import RetryableToolError
|
||||
from arcade_postgres.tools.postgres import (
|
||||
DatabaseEngine,
|
||||
discover_schemas,
|
||||
discover_tables,
|
||||
execute_select_query,
|
||||
get_table_schema,
|
||||
)
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
POSTGRES_DATABASE_CONNECTION_STRING = (
|
||||
environ.get("TEST_POSTGRES_DATABASE_CONNECTION_STRING")
|
||||
or "postgresql://postgres@localhost:5432/postgres"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_context():
    """Provide a ``Context`` double whose ``get_secret`` returns the test connection string."""
    fake_context = MagicMock(spec=Context)
    # Every secret lookup resolves to the same connection string, whatever key is asked for.
    fake_context.get_secret = MagicMock(return_value=POSTGRES_DATABASE_CONNECTION_STRING)
    return fake_context
|
||||
|
||||
|
||||
# before the tests, restore the database from the dump
|
||||
@pytest_asyncio.fixture(autouse=True)
async def restore_database():
    """Reload ``dump.sql`` (located next to this file) into the test database.

    The dump is split on ``;`` and each non-blank statement is executed inside a single
    transaction, which is committed only if every statement succeeds. The temporary
    engine is disposed even when the restore fails.
    """
    dump_path = os.path.join(os.path.dirname(__file__), "dump.sql")
    with open(dump_path) as f:
        statements = [chunk for chunk in f.read().split(";") if chunk.strip()]

    # Drop any query string, then switch only the leading scheme to the asyncpg driver.
    # count=1 keeps a later "postgresql" (e.g. in a database name) untouched.
    base_url = POSTGRES_DATABASE_CONNECTION_STRING.split("?")[0]
    engine = create_async_engine(base_url.replace("postgresql", "postgresql+asyncpg", 1))
    try:
        # engine.begin() opens the transaction and commits on success / rolls back on
        # error, so no manual BEGIN statement is needed.
        async with engine.begin() as c:
            for statement in statements:
                await c.execute(text(statement))
    finally:
        await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(autouse=True)
async def cleanup_engines():
    """Call ``DatabaseEngine.cleanup()`` once each test has finished."""
    # Nothing to set up; control returns here after the test body has run.
    yield
    await DatabaseEngine.cleanup()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_discover_schemas(mock_context) -> None:
    """The restored database exposes only the ``public`` schema."""
    schemas = await discover_schemas(mock_context)
    assert schemas == ["public"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_discover_tables(mock_context) -> None:
    """With the default schema, the two tables from the dump are listed in sorted order."""
    tables = await discover_tables(mock_context)
    assert tables == ["messages", "users"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_table_schema(mock_context) -> None:
    """Column descriptions carry the Python type name plus primary-key / index markers."""
    expected_users = [
        "id: int (PRIMARY KEY)",
        "name: str (INDEXED)",
        "email: str (INDEXED)",
        "password_hash: str",
        "created_at: datetime",
        "updated_at: datetime",
        "status: str",
    ]
    expected_messages = [
        "id: int (PRIMARY KEY)",
        "body: str",
        "user_id: int",
        "created_at: datetime",
        "updated_at: datetime",
    ]

    users_schema = await get_table_schema(mock_context, "public", "users")
    assert users_schema == expected_users

    messages_schema = await get_table_schema(mock_context, "public", "messages")
    assert messages_schema == expected_messages
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_execute_select_query(mock_context) -> None:
    """A filtered select and an ordered, paginated select each return the expected row."""
    # Filter by primary key.
    by_id = await execute_select_query(
        mock_context,
        select_clause="id, name, email",
        from_clause="users",
        where_clause="id = 1",
    )
    assert by_id == ["(1, 'Alice', 'alice@example.com')"]

    # Second row when ordered by id: limit 1, offset 1.
    second_row = await execute_select_query(
        mock_context,
        select_clause="id, name, email",
        from_clause="users",
        order_by_clause="id",
        limit=1,
        offset=1,
    )
    assert second_row == ["(2, 'Bob', 'bob@example.com')"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_execute_select_query_with_keywords(mock_context) -> None:
    """Clauses that already include their SELECT / FROM keyword are still accepted."""
    rows = await execute_select_query(
        mock_context,
        select_clause="SELECT id, name, email",
        from_clause="FROM users",
        # Without an ORDER BY, LIMIT 1 may return any row; ordering by id pins Alice.
        order_by_clause="id",
        limit=1,
    )
    assert rows == ["(1, 'Alice', 'alice@example.com')"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_execute_select_query_with_join(mock_context) -> None:
    """A join between users and messages returns columns from both tables."""
    rows = await execute_select_query(
        mock_context,
        select_clause="u.id, u.name, u.email, m.id, m.body",
        from_clause="users u",
        join_clause="messages m ON u.id = m.user_id",
        # Without an ORDER BY, LIMIT 1 may return any joined row; order by both ids so
        # the first row is always user 1 with message 1.
        order_by_clause="u.id, m.id",
        limit=1,
    )
    assert rows == ["(1, 'Alice', 'alice@example.com', 1, 'Hello everyone!')"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_execute_select_query_with_group_by(mock_context) -> None:
    """Grouping messages per user returns the two highest message counts."""
    rows = await execute_select_query(
        mock_context,
        select_clause="u.name, COUNT(m.id) AS message_count",
        from_clause="messages m",
        join_clause="users u ON m.user_id = u.id",
        group_by_clause="u.name",
        # Several users in the dump have exactly 3 messages, so the name is added as a
        # tie-breaker; ordering by count alone leaves the second row undefined.
        order_by_clause="message_count DESC, u.name",
        limit=2,
    )
    assert rows == [
        "('Evan', 13)",
        "('Alice', 3)",
    ]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_execute_select_query_with_no_results(mock_context) -> None:
    """A query that matches nothing returns an empty list instead of raising."""
    rows = await execute_select_query(
        mock_context,
        select_clause="id, name, email",
        from_clause="users",
        where_clause="id = 9999999999",
    )
    assert rows == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_execute_select_query_with_problem(mock_context) -> None:
    """A ``*`` select clause is reported as a retryable error."""
    # The assertion below checks the "select *" rejection message; the where clause
    # also compares id to a non-numeric string, but that is not what is asserted.
    with pytest.raises(RetryableToolError) as e:
        await execute_select_query(
            mock_context,
            select_clause="*",
            from_clause="users",
            where_clause="id = 'foo'",
        )
    # Checked after the block: statements following the raising call inside it never run.
    assert "Do not use * in the select clause" in str(e.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_execute_select_query_rejects_non_select(mock_context) -> None:
    """An INSERT statement passed as the select clause is refused with a retryable error."""
    insert_statement = "INSERT INTO users (name, email, password_hash) VALUES ('Luigi', 'luigi@example.com', 'password')"
    with pytest.raises(RetryableToolError) as e:
        await execute_select_query(
            mock_context,
            select_clause=insert_statement,
            from_clause="users",
        )
    # Checked after the block: statements following the raising call inside it never run.
    assert "Only SELECT queries are allowed" in str(e.value)
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Run PostgreSQL container
|
||||
docker run -d --name some-postgres-server -e POSTGRES_HOST_AUTH_METHOD=trust -p 5432:5432 postgres:latest
|
||||
|
||||
# Wait for PostgreSQL to be ready
|
||||
echo "Waiting for PostgreSQL to be ready..."
|
||||
for i in {1..30}; do
|
||||
if docker exec some-postgres-server pg_isready -U postgres > /dev/null 2>&1; then
|
||||
echo "PostgreSQL is ready!"
|
||||
break
|
||||
fi
|
||||
echo "Waiting... ($i/30)"
|
||||
sleep 1
|
||||
done
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
files: ^.*/zendesk/.*
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: "v4.4.0"
|
||||
hooks:
|
||||
- id: check-case-conflict
|
||||
- id: check-merge-conflict
|
||||
- id: check-toml
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.6.7
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
|
|
@ -1,46 +0,0 @@
|
|||
target-version = "py310"
|
||||
line-length = 100
|
||||
fix = true
|
||||
|
||||
[lint]
|
||||
select = [
|
||||
# flake8-2020
|
||||
"YTT",
|
||||
# flake8-bandit
|
||||
"S",
|
||||
# flake8-bugbear
|
||||
"B",
|
||||
# flake8-builtins
|
||||
"A",
|
||||
# flake8-comprehensions
|
||||
"C4",
|
||||
# flake8-debugger
|
||||
"T10",
|
||||
# flake8-simplify
|
||||
"SIM",
|
||||
# isort
|
||||
"I",
|
||||
# mccabe
|
||||
"C90",
|
||||
# pycodestyle
|
||||
"E", "W",
|
||||
# pyflakes
|
||||
"F",
|
||||
# pygrep-hooks
|
||||
"PGH",
|
||||
# pyupgrade
|
||||
"UP",
|
||||
# ruff
|
||||
"RUF",
|
||||
# tryceratops
|
||||
"TRY",
|
||||
]
|
||||
|
||||
ignore = ["C901"]
|
||||
|
||||
[lint.per-file-ignores]
|
||||
"**/tests/*" = ["S101"]
|
||||
|
||||
[format]
|
||||
preview = true
|
||||
skip-magic-trailing-comma = false
|
||||
|
|
@ -1,55 +0,0 @@
|
|||
.PHONY: help
|
||||
|
||||
help:
|
||||
@echo "🛠️ github Commands:\n"
|
||||
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
||||
|
||||
.PHONY: install
|
||||
install: ## Install the uv environment and install all packages with dependencies
|
||||
@echo "🚀 Creating virtual environment and installing all packages using uv"
|
||||
@uv sync --active --all-extras --no-sources
|
||||
@if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi
|
||||
@echo "✅ All packages and dependencies installed via uv"
|
||||
|
||||
.PHONY: install-local
|
||||
install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources
|
||||
@echo "🚀 Creating virtual environment and installing all packages using uv"
|
||||
@uv sync --active --all-extras
|
||||
@if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi
|
||||
@echo "✅ All packages and dependencies installed via uv"
|
||||
|
||||
.PHONY: build
|
||||
build: clean-build ## Build wheel file using uv
|
||||
@echo "🚀 Creating wheel file"
|
||||
uv build
|
||||
|
||||
.PHONY: clean-build
|
||||
clean-build: ## clean build artifacts
|
||||
@echo "🗑️ Cleaning dist directory"
|
||||
rm -rf dist
|
||||
|
||||
.PHONY: test
|
||||
test: ## Test the code with pytest
|
||||
@echo "🚀 Testing code: Running pytest"
|
||||
@uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml
|
||||
|
||||
.PHONY: coverage
|
||||
coverage: ## Generate coverage report
|
||||
@echo "coverage report"
|
||||
@uv run --no-sources coverage report
|
||||
@echo "Generating coverage report"
|
||||
@uv run --no-sources coverage html
|
||||
|
||||
.PHONY: bump-version
|
||||
bump-version: ## Bump the version in the pyproject.toml file by a patch version
|
||||
@echo "🚀 Bumping version in pyproject.toml"
|
||||
uv version --no-sources --bump patch
|
||||
|
||||
.PHONY: check
|
||||
check: ## Run code quality tools.
|
||||
@if [ -f .pre-commit-config.yaml ]; then\
|
||||
echo "🚀 Linting code: Running pre-commit";\
|
||||
uv run --no-sources pre-commit run -a;\
|
||||
fi
|
||||
@echo "🚀 Static type checking: Running mypy"
|
||||
@uv run --no-sources mypy --config-file=pyproject.toml
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
from arcade_zendesk.tools import (
|
||||
add_ticket_comment,
|
||||
get_ticket_comments,
|
||||
list_tickets,
|
||||
mark_ticket_solved,
|
||||
search_articles,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"add_ticket_comment",
|
||||
"get_ticket_comments",
|
||||
"list_tickets",
|
||||
"mark_ticket_solved",
|
||||
"search_articles",
|
||||
]
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue