diff --git a/toolkits/stripe/.pre-commit-config.yaml b/toolkits/stripe/.pre-commit-config.yaml new file mode 100644 index 00000000..3953e996 --- /dev/null +++ b/toolkits/stripe/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +files: ^./ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.4.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/toolkits/stripe/.ruff.toml b/toolkits/stripe/.ruff.toml new file mode 100644 index 00000000..9519fe6c --- /dev/null +++ b/toolkits/stripe/.ruff.toml @@ -0,0 +1,44 @@ +target-version = "py310" +line-length = 100 +fix = true + +[lint] +select = [ + # flake8-2020 + "YTT", + # flake8-bandit + "S", + # flake8-bugbear + "B", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-simplify + "SIM", + # isort + "I", + # mccabe + "C90", + # pycodestyle + "E", "W", + # pyflakes + "F", + # pygrep-hooks + "PGH", + # pyupgrade + "UP", + # ruff + "RUF", + # tryceratops + "TRY", +] + +[lint.per-file-ignores] +"**/tests/*" = ["S101"] + +[format] +preview = true +skip-magic-trailing-comma = false diff --git a/toolkits/stripe/LICENSE b/toolkits/stripe/LICENSE new file mode 100644 index 00000000..8c2d4f37 --- /dev/null +++ b/toolkits/stripe/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025, Arcade + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/toolkits/stripe/Makefile b/toolkits/stripe/Makefile new file mode 100644 index 00000000..c74cb193 --- /dev/null +++ b/toolkits/stripe/Makefile @@ -0,0 +1,58 @@ +.PHONY: help + +help: + @echo "🛠️ stripe Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the poetry environment and install the pre-commit hooks + @echo "📦 Checking if Poetry is installed" + @if ! command -v poetry > /dev/null 2>&1; then \ + echo "📦 Poetry not found. Checking if pip is available"; \ + if ! command -v pip >/dev/null 2>&1; then \ + echo "❌ pip is not installed. Please install pip first."; \ + exit 1; \ + fi; \ + echo "📦 Installing Poetry with pip"; \ + pip install poetry==1.8.5; \ + else \ + echo "📦 Poetry is already installed"; \ + fi + @echo "🚀 Installing package in development mode with all extras" + poetry install --all-extras + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + poetry build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @poetry run pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + coverage report + @echo "Generating coverage report" + coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file + @echo "🚀 Bumping version in pyproject.toml" + poetry version patch + +.PHONY: check +check: ## Run code quality tools. + @echo "🚀 Checking Poetry lock file consistency with 'pyproject.toml': Running poetry check" + @poetry check + @echo "🚀 Linting code: Running pre-commit" + @poetry run pre-commit run -a + @echo "🚀 Static type checking: Running mypy" + @poetry run mypy --config-file=pyproject.toml diff --git a/toolkits/stripe/_generate.py b/toolkits/stripe/_generate.py new file mode 100644 index 00000000..cc1d3311 --- /dev/null +++ b/toolkits/stripe/_generate.py @@ -0,0 +1,110 @@ +import logging +from pathlib import Path +from typing import Union, get_args + +from stripe_agent_toolkit.functions import * # noqa: F403 +from stripe_agent_toolkit.prompts import * # noqa: F403 +from stripe_agent_toolkit.schema import * # noqa: F403 +from stripe_agent_toolkit.tools import tools + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def get_type_str(arg_type): + """Extract type name, handling Optional/Union types.""" + if hasattr(arg_type, "__origin__") and arg_type.__origin__ is Union: + non_none = [a for a in get_args(arg_type) if a is not type(None)] + if len(non_none) == 1: + return non_none[0].__name__ + return arg_type.__name__ if hasattr(arg_type, "__name__") else str(arg_type) + + +def generate_stripe_tools( + output_file: Path = Path("arcade_stripe") / "tools" / "stripe.py", +) -> None: + """ + Generate the Arcade AI Stripe Toolkit file from the stripe agent toolkit definitions. + """ + logger.info("Generating stripe tools file at %s", output_file) + try: + output_file.touch(exist_ok=True) + with output_file.open("w") as f: + f.write("""import os +from typing import Annotated, Optional + +from stripe_agent_toolkit.api import StripeAPI + +from arcade.sdk import ToolContext, tool + +def run_stripe_tool(context: ToolContext, method_name: str, params: dict) -> str: + \"\"\" + Helper function that retrieves the Stripe secret key, initializes the API, + and executes the specified method with the provided parameters. + \"\"\" + api_key = context.get_secret("STRIPE_SECRET_KEY") + stripe_api = StripeAPI(secret_key=api_key, context=None) + params = {k: v for k, v in params.items() if v is not None} + return stripe_api.run(method_name, **params) # type: ignore[no-any-return] + +""") + # Generate each tool function from the stripe agent toolkit + for tool_info in tools: + method_name = tool_info["method"] + method = globals().get(method_name) + if not method: + logger.warning("Method %s not found.", method_name) + continue + + args_schema = tool_info["args_schema"] + description = tool_info["description"].strip() + + arg_names = list(args_schema.__annotations__.keys()) + arg_types = [args_schema.__annotations__[field] for field in arg_names] + + params_list = [] + for name, arg_type in zip(arg_names, arg_types, strict=False): + field = args_schema.model_fields[name] + # Check if the type annotation already includes Optional (i.e. Union[..., None]) + is_optional_type = ( + hasattr(arg_type, "__origin__") + and arg_type.__origin__ is Union + and type(None) in get_args(arg_type) + ) + if field.is_required: + if is_optional_type: + params_list.append( + f"{name}: Annotated[{get_type_str(arg_type)} | None, " + f'"{field.description}"] = None' + ) + else: + params_list.append( + f"{name}: Annotated[{get_type_str(arg_type)}, " + f'"{field.description}"]' + ) + else: + default_repr = "None" if field.default is None else repr(field.default) + params_list.append( + f"{name}: Annotated[Optional[{get_type_str(arg_type)}], " + f'"{field.description}"] = {default_repr}' + ) + params_str = ", ".join(params_list) + dict_items = ", ".join([f'"{name}": {name}' for name in arg_names]) + arcade_tool_code = ( + f'@tool(requires_secrets=["STRIPE_SECRET_KEY"])\n' + f"def {method_name}(context: ToolContext, {params_str}) -> " + f'Annotated[str, "{description.splitlines()[0]}"]:\n' + f' """{description.splitlines()[0]}"""\n' + f' return run_stripe_tool(context, "{method_name}", ' + + "{" + + dict_items + + "})\n\n" + ) + f.write(arcade_tool_code) + except Exception: + logger.exception("An error occurred while generating stripe tools") + raise + + +if __name__ == "__main__": + generate_stripe_tools() diff --git a/toolkits/stripe/arcade_stripe/__init__.py b/toolkits/stripe/arcade_stripe/__init__.py new file mode 100644 index 00000000..d9fc19a6 --- /dev/null +++ b/toolkits/stripe/arcade_stripe/__init__.py @@ -0,0 +1,35 @@ +from arcade_stripe.tools.stripe import ( + create_billing_portal_session, + create_customer, + create_invoice, + create_invoice_item, + create_payment_link, + create_price, + create_product, + create_refund, + finalize_invoice, + list_customers, + list_invoices, + list_payment_intents, + list_prices, + list_products, + retrieve_balance, +) + +__all__ = [ + "create_billing_portal_session", + "create_customer", + "create_invoice", + "create_invoice_item", + "create_payment_link", + "create_price", + "create_product", + "create_refund", + "finalize_invoice", + "list_customers", + "list_invoices", + "list_payment_intents", + "list_prices", + "list_products", + "retrieve_balance", +] diff --git a/toolkits/stripe/arcade_stripe/evals/eval_stripe.py b/toolkits/stripe/arcade_stripe/evals/eval_stripe.py new file mode 100644 index 00000000..d78d1609 --- /dev/null +++ b/toolkits/stripe/arcade_stripe/evals/eval_stripe.py @@ -0,0 +1,319 @@ +from arcade.sdk import ToolCatalog +from arcade.sdk.eval import BinaryCritic, EvalRubric, EvalSuite, ExpectedToolCall, tool_eval +from arcade.sdk.eval.critic import SimilarityCritic + +import arcade_stripe +from arcade_stripe.tools.stripe import ( + create_billing_portal_session, + create_customer, + create_invoice, + create_invoice_item, + create_payment_link, + create_price, + create_product, + create_refund, + finalize_invoice, + list_customers, + list_invoices, + list_payment_intents, + list_prices, + list_products, + retrieve_balance, +) + +rubric = EvalRubric( + fail_threshold=0.9, + warn_threshold=0.95, +) + +catalog = ToolCatalog() +catalog.add_module(arcade_stripe) + + +@tool_eval() +def stripe_eval_suite() -> EvalSuite: + """Evaluation suite for Stripe Tools.""" + suite = EvalSuite( + name="Stripe Tools Evaluation Suite", + system_message=( + "You are an AI assistant that helps users " + "interact with Stripe using the provided tools." + ), + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="Create a customer", + user_message=( + "add 'Alice Jenner' to my customers. she has a gmail that is just her first name" + ), + expected_tool_calls=[ + ExpectedToolCall( + func=create_customer, + args={"name": "Alice Jenner", "email": "alice@gmail.com"}, + ) + ], + critics=[ + BinaryCritic(critic_field="name", weight=0.5), + BinaryCritic(critic_field="email", weight=0.5), + ], + ) + + suite.add_case( + name="List customers with limit", + user_message="get 5 customers", + expected_tool_calls=[ + ExpectedToolCall( + func=list_customers, + args={ + "limit": 5, + "email": None, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="limit", weight=1.0), + ], + ) + + suite.add_case( + name="Create a product", + user_message=( + "Create a product named 'Pro Subscription' that provides: " + "- Higher rate limits" + "- Priority support" + "- Early access to new features" + ), + expected_tool_calls=[ + ExpectedToolCall( + func=create_product, + args={ + "name": "Pro Subscription", + "description": ( + "Provides higher rate limits, priority support, " + "and early access to new features." + ), + }, + ) + ], + critics=[ + BinaryCritic(critic_field="name", weight=0.6), + SimilarityCritic( + critic_field="description", + weight=0.4, + similarity_threshold=0.75, + ), + ], + ) + + suite.add_case( + name="List products", + user_message="List 10 of my products.", + expected_tool_calls=[ + ExpectedToolCall( + func=list_products, + args={ + "limit": 10, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="limit", weight=1.0), + ], + ) + + suite.add_case( + name="Create a price", + user_message="Create a price of $1298.99 for product 'prod_ABC123' in us currency.", + expected_tool_calls=[ + ExpectedToolCall( + func=create_price, + args={ + "product": "prod_ABC123", + "unit_amount": 129899, + "currency": "usd", + }, + ) + ], + critics=[ + BinaryCritic(critic_field="product", weight=0.4), + BinaryCritic(critic_field="unit_amount", weight=0.3), + SimilarityCritic( + critic_field="currency", + weight=0.3, + similarity_threshold=0.95, + ), + ], + ) + + suite.add_case( + name="Create a payment link", + user_message=( + "Joe needs a link to pay for my product. price is 'price_XYZ789'. create it please" + ), + expected_tool_calls=[ + ExpectedToolCall( + func=create_payment_link, + args={ + "price": "price_XYZ789", + "quantity": 1, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="price", weight=0.5), + BinaryCritic(critic_field="quantity", weight=0.5), + ], + ) + + suite.add_case( + name="Retrieve balance", + user_message="How much money do i have", + expected_tool_calls=[ + ExpectedToolCall( + func=retrieve_balance, + args={}, + ) + ], + critics=[], + ) + + suite.add_case( + name="Create a refund", + user_message="Refund the payment intent 'pi_789XYZ' for 5 bucks.", + expected_tool_calls=[ + ExpectedToolCall( + func=create_refund, + args={ + "payment_intent": "pi_789XYZ", + "amount": 500, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="payment_intent", weight=0.5), + BinaryCritic(critic_field="amount", weight=0.5), + ], + ) + + suite.add_case( + name="Create billing portal session", + user_message="Create a billing portal session for customer 'cus_test123' with return URL 'https://example.com/return'.", + expected_tool_calls=[ + ExpectedToolCall( + func=create_billing_portal_session, + args={ + "customer": "cus_test123", + "return_url": "https://example.com/return", + }, + ) + ], + critics=[ + BinaryCritic(critic_field="customer", weight=0.6), + BinaryCritic(critic_field="return_url", weight=0.4), + ], + ) + + suite.add_case( + name="List prices for a product", + user_message="what are the prices for my product 'prod_ABC123'", + expected_tool_calls=[ + ExpectedToolCall( + func=list_prices, + args={ + "product": "prod_ABC123", + }, + ) + ], + critics=[ + BinaryCritic(critic_field="product", weight=1.0), + ], + ) + + suite.add_case( + name="List invoices for a customer", + user_message="get invoices for my customer 'cus_456def'", + expected_tool_calls=[ + ExpectedToolCall( + func=list_invoices, + args={ + "customer": "cus_456def", + }, + ) + ], + critics=[ + BinaryCritic(critic_field="customer", weight=1.0), + ], + ) + + suite.add_case( + name="Create an invoice", + user_message="Create an invoice for my customer 'cus_456def' with 15 days until due.", + expected_tool_calls=[ + ExpectedToolCall( + func=create_invoice, + args={ + "customer": "cus_456def", + "days_until_due": 15, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="customer", weight=0.5), + BinaryCritic(critic_field="days_until_due", weight=0.5), + ], + ) + + suite.add_case( + name="Create an invoice item", + user_message=( + "Create an invoice item for my customer 'cus_456def' " + "for price 'price_789ghi' on invoice 'in_123test'." + ), + expected_tool_calls=[ + ExpectedToolCall( + func=create_invoice_item, + args={ + "customer": "cus_456def", + "price": "price_789ghi", + "invoice": "in_123test", + }, + ) + ], + critics=[ + BinaryCritic(critic_field="customer", weight=0.33), + BinaryCritic(critic_field="price", weight=0.33), + BinaryCritic(critic_field="invoice", weight=0.34), + ], + ) + + suite.add_case( + name="Finalize an invoice", + user_message="Make 'in_123test' finalized.", + expected_tool_calls=[ + ExpectedToolCall( + func=finalize_invoice, + args={"invoice": "in_123test"}, + ) + ], + critics=[ + BinaryCritic(critic_field="invoice", weight=1.0), + ], + ) + + suite.add_case( + name="List payment intents for a customer", + user_message="get payment intents for my customer 'cus_456def'", + expected_tool_calls=[ + ExpectedToolCall( + func=list_payment_intents, + args={"customer": "cus_456def"}, + ) + ], + critics=[ + BinaryCritic(critic_field="customer", weight=1.0), + ], + ) + + return suite diff --git a/toolkits/stripe/arcade_stripe/tools/__init__.py b/toolkits/stripe/arcade_stripe/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/toolkits/stripe/arcade_stripe/tools/stripe.py b/toolkits/stripe/arcade_stripe/tools/stripe.py new file mode 100644 index 00000000..b4f8ca10 --- /dev/null +++ b/toolkits/stripe/arcade_stripe/tools/stripe.py @@ -0,0 +1,196 @@ +from typing import Annotated + +from arcade.sdk import ToolContext, tool +from stripe_agent_toolkit.api import StripeAPI + + +def run_stripe_tool(context: ToolContext, method_name: str, params: dict) -> str: + """ + Helper function that retrieves the Stripe secret key, initializes the API, + and executes the specified method with the provided parameters. + """ + api_key = context.get_secret("STRIPE_SECRET_KEY") + stripe_api = StripeAPI(secret_key=api_key, context=None) + params = {k: v for k, v in params.items() if v is not None} + return stripe_api.run(method_name, **params) # type: ignore[no-any-return] + + +@tool(requires_secrets=["STRIPE_SECRET_KEY"]) +def create_customer( + context: ToolContext, + name: Annotated[str, "The name of the customer."], + email: Annotated[str | None, "The email of the customer."] = None, +) -> Annotated[str, "This tool will create a customer in Stripe."]: + """This tool will create a customer in Stripe.""" + return run_stripe_tool(context, "create_customer", {"name": name, "email": email}) + + +@tool(requires_secrets=["STRIPE_SECRET_KEY"]) +def list_customers( + context: ToolContext, + limit: Annotated[ + int | None, + "A limit on the number of objects to be returned. Limit can range between 1 and 100.", + ] = None, + email: Annotated[ + str | None, + "A case-sensitive filter on the list based on the customer's email field. " + "The value must be a string.", + ] = None, +) -> Annotated[str, "This tool will fetch a list of Customers from Stripe."]: + """This tool will fetch a list of Customers from Stripe.""" + return run_stripe_tool(context, "list_customers", {"limit": limit, "email": email}) + + +@tool(requires_secrets=["STRIPE_SECRET_KEY"]) +def create_product( + context: ToolContext, + name: Annotated[str, "The name of the product."], + description: Annotated[str | None, "The description of the product."] = None, +) -> Annotated[str, "This tool will create a product in Stripe."]: + """This tool will create a product in Stripe.""" + return run_stripe_tool(context, "create_product", {"name": name, "description": description}) + + +@tool(requires_secrets=["STRIPE_SECRET_KEY"]) +def list_products( + context: ToolContext, + limit: Annotated[ + int | None, + "A limit on the number of objects to be returned. Limit can range between 1 and 100, " + "and the default is 10.", + ] = None, +) -> Annotated[str, "This tool will fetch a list of Products from Stripe."]: + """This tool will fetch a list of Products from Stripe.""" + return run_stripe_tool(context, "list_products", {"limit": limit}) + + +@tool(requires_secrets=["STRIPE_SECRET_KEY"]) +def create_price( + context: ToolContext, + product: Annotated[str, "The ID of the product to create the price for."], + unit_amount: Annotated[int, "The unit amount of the price in cents."], + currency: Annotated[str, "The currency of the price."], +) -> Annotated[str, "This tool will create a price in Stripe. If a product has not already been"]: + """This tool will create a price in Stripe. If a product has not already been""" + return run_stripe_tool( + context, + "create_price", + {"product": product, "unit_amount": unit_amount, "currency": currency}, + ) + + +@tool(requires_secrets=["STRIPE_SECRET_KEY"]) +def list_prices( + context: ToolContext, + product: Annotated[str | None, "The ID of the product to list prices for."] = None, + limit: Annotated[ + int | None, + "A limit on the number of objects to be returned. Limit can range between 1 and 100, " + "and the default is 10.", + ] = None, +) -> Annotated[str, "This tool will fetch a list of Prices from Stripe."]: + """This tool will fetch a list of Prices from Stripe.""" + return run_stripe_tool(context, "list_prices", {"product": product, "limit": limit}) + + +@tool(requires_secrets=["STRIPE_SECRET_KEY"]) +def create_payment_link( + context: ToolContext, + price: Annotated[str, "The ID of the price to create the payment link for."], + quantity: Annotated[int, "The quantity of the product to include."], +) -> Annotated[str, "This tool will create a payment link in Stripe."]: + """This tool will create a payment link in Stripe.""" + return run_stripe_tool(context, "create_payment_link", {"price": price, "quantity": quantity}) + + +@tool(requires_secrets=["STRIPE_SECRET_KEY"]) +def list_invoices( + context: ToolContext, + customer: Annotated[str | None, "The ID of the customer to list invoices for."] = None, + limit: Annotated[ + int | None, + "A limit on the number of objects to be returned. Limit can range between 1 and 100, " + "and the default is 10.", + ] = None, +) -> Annotated[str, "This tool will list invoices in Stripe."]: + """This tool will list invoices in Stripe.""" + return run_stripe_tool(context, "list_invoices", {"customer": customer, "limit": limit}) + + +@tool(requires_secrets=["STRIPE_SECRET_KEY"]) +def create_invoice( + context: ToolContext, + customer: Annotated[str, "The ID of the customer to create the invoice for."], + days_until_due: Annotated[int | None, "The number of days until the invoice is due."] = None, +) -> Annotated[str, "This tool will create an invoice in Stripe."]: + """This tool will create an invoice in Stripe.""" + return run_stripe_tool( + context, "create_invoice", {"customer": customer, "days_until_due": days_until_due} + ) + + +@tool(requires_secrets=["STRIPE_SECRET_KEY"]) +def create_invoice_item( + context: ToolContext, + customer: Annotated[str, "The ID of the customer to create the invoice item for."], + price: Annotated[str, "The ID of the price for the item."], + invoice: Annotated[str, "The ID of the invoice to create the item for."], +) -> Annotated[str, "This tool will create an invoice item in Stripe."]: + """This tool will create an invoice item in Stripe.""" + return run_stripe_tool( + context, "create_invoice_item", {"customer": customer, "price": price, "invoice": invoice} + ) + + +@tool(requires_secrets=["STRIPE_SECRET_KEY"]) +def finalize_invoice( + context: ToolContext, invoice: Annotated[str, "The ID of the invoice to finalize."] +) -> Annotated[str, "This tool will finalize an invoice in Stripe."]: + """This tool will finalize an invoice in Stripe.""" + return run_stripe_tool(context, "finalize_invoice", {"invoice": invoice}) + + +@tool(requires_secrets=["STRIPE_SECRET_KEY"]) +def retrieve_balance( + context: ToolContext, +) -> Annotated[str, "This tool will retrieve the balance from Stripe. It takes no input."]: + """This tool will retrieve the balance from Stripe. It takes no input.""" + return run_stripe_tool(context, "retrieve_balance", {}) + + +@tool(requires_secrets=["STRIPE_SECRET_KEY"]) +def create_refund( + context: ToolContext, + payment_intent: Annotated[str, "The ID of the PaymentIntent to refund."], + amount: Annotated[int | None, "The amount to refund in cents."] = None, +) -> Annotated[str, "This tool will refund a payment intent in Stripe."]: + """This tool will refund a payment intent in Stripe.""" + return run_stripe_tool( + context, "create_refund", {"payment_intent": payment_intent, "amount": amount} + ) + + +@tool(requires_secrets=["STRIPE_SECRET_KEY"]) +def list_payment_intents( + context: ToolContext, + customer: Annotated[str | None, "The ID of the customer to list payment intents for."] = None, + limit: Annotated[ + int | None, + "A limit on the number of objects to be returned. Limit can range between 1 and 100.", + ] = None, +) -> Annotated[str, "This tool will list payment intents in Stripe."]: + """This tool will list payment intents in Stripe.""" + return run_stripe_tool(context, "list_payment_intents", {"customer": customer, "limit": limit}) + + +@tool(requires_secrets=["STRIPE_SECRET_KEY"]) +def create_billing_portal_session( + context: ToolContext, + customer: Annotated[str, "The ID of the customer to create the billing portal session for."], + return_url: Annotated[str | None, "The default URL to return to afterwards."] = None, +) -> Annotated[str, "This tool will create a billing portal session."]: + """This tool will create a billing portal session.""" + return run_stripe_tool( + context, "create_billing_portal_session", {"customer": customer, "return_url": return_url} + ) diff --git a/toolkits/stripe/pyproject.toml b/toolkits/stripe/pyproject.toml new file mode 100644 index 00000000..43d32cde --- /dev/null +++ b/toolkits/stripe/pyproject.toml @@ -0,0 +1,41 @@ +[tool.poetry] +name = "arcade_stripe" +version = "0.0.1" +description = "Arcade.dev LLM tools for Stripe" +authors = ["Arcade "] + +[tool.poetry.dependencies] +python = "^3.11" +arcade-ai = "^1.0.5" +stripe-agent-toolkit = "^0.6.1" +stripe = "^11.0.0" + +[tool.poetry.dev-dependencies] +pytest = "^8.3.0" +pytest-cov = "^4.0.0" +mypy = "^1.5.1" +pre-commit = "^3.4.0" +tox = "^4.11.1" +ruff = "^0.7.4" + +[build-system] +requires = ["poetry-core>=1.0.0,<2.0.0"] +build-backend = "poetry.core.masonry.api" + +[tool.mypy] +files = ["arcade_stripe/**/*.py"] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes = "True" +ignore_missing_imports = "True" + +[tool.pytest.ini_options] +testpaths = ["tests"] + +[tool.coverage.report] +skip_empty = true diff --git a/toolkits/stripe/tests/__init__.py b/toolkits/stripe/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/toolkits/stripe/tests/test_stripe.py b/toolkits/stripe/tests/test_stripe.py new file mode 100644 index 00000000..1eff476f --- /dev/null +++ b/toolkits/stripe/tests/test_stripe.py @@ -0,0 +1,87 @@ +import pytest + +from arcade_stripe.tools.stripe import ( + create_billing_portal_session, + create_customer, + create_invoice, + create_invoice_item, + create_payment_link, + create_price, + create_product, + create_refund, + finalize_invoice, + list_customers, + list_invoices, + list_payment_intents, + list_prices, + list_products, + retrieve_balance, +) + + +class DummyContext: + def get_secret(self, key: str): + return "test_secret_key" + + +class DummyStripeAPI: + def __init__(self, secret_key, context): + self.secret_key = secret_key + + def run(self, method_name, **params): + return {"method": method_name, "params": params} + + +@pytest.mark.parametrize( + ("current_tool", "params"), + [ + (create_customer, {"name": "John Doe"}), + (create_customer, {"name": "John Doe", "email": "john.doe@example.com"}), + (list_customers, {}), + (list_customers, {"limit": 10}), + (list_customers, {"email": "john.doe@example.com"}), + (list_customers, {"limit": 10, "email": "john.doe@example.com"}), + (create_product, {"name": "Product 1"}), + (create_product, {"name": "Product 1", "description": "Description 1"}), + (list_products, {}), + (list_products, {"limit": 10}), + (create_price, {"product": "product_123", "unit_amount": 1000, "currency": "usd"}), + (list_prices, {}), + (list_prices, {"product": "product_123"}), + (list_prices, {"limit": 10}), + (list_prices, {"product": "product_123", "limit": 10}), + (create_payment_link, {"price": "price_123", "quantity": 100}), + (list_invoices, {}), + (list_invoices, {"customer": "customer_123"}), + (list_invoices, {"limit": 10}), + (list_invoices, {"customer": "customer_123", "limit": 10}), + (create_invoice, {"customer": "customer_123"}), + (create_invoice, {"customer": "customer_123", "days_until_due": 30}), + ( + create_invoice_item, + {"customer": "customer_123", "price": "price_123", "invoice": "invoice_123"}, + ), + (finalize_invoice, {"invoice": "invoice_123"}), + (retrieve_balance, {}), + (create_refund, {"payment_intent": "payment_intent_123"}), + (create_refund, {"payment_intent": "payment_intent_123", "amount": 100}), + (list_payment_intents, {}), + (list_payment_intents, {"customer": "customer_123"}), + (list_payment_intents, {"limit": 10}), + (list_payment_intents, {"customer": "customer_123", "limit": 10}), + (create_billing_portal_session, {"customer": "customer_123"}), + ( + create_billing_portal_session, + {"customer": "customer_123", "return_url": "https://example.com"}, + ), + ], +) +def test_stripe_tools(monkeypatch, current_tool, params): + monkeypatch.setattr("arcade_stripe.tools.stripe.StripeAPI", DummyStripeAPI) + + context = DummyContext() + + result = current_tool(context, **params) + expected = {"method": current_tool.__name__, "params": params} + + assert result == expected