diff --git a/arcade/arcade/cli/main.py b/arcade/arcade/cli/main.py index bd2bee30..8af96d73 100644 --- a/arcade/arcade/cli/main.py +++ b/arcade/arcade/cli/main.py @@ -3,8 +3,10 @@ import os import threading import uuid import webbrowser +from pathlib import Path from typing import Any, Optional +import httpx import typer from arcadepy import Arcade from arcadepy.types import AuthorizationResponse @@ -14,6 +16,7 @@ from rich.markup import escape from rich.text import Text from tqdm import tqdm +import arcade.cli.worker as worker from arcade.cli.authn import LocalAuthCallbackServer, check_existing_login from arcade.cli.constants import ( CREDENTIALS_FILE_PATH, @@ -30,7 +33,7 @@ from arcade.cli.launcher import start_servers from arcade.cli.show import show_logic from arcade.cli.utils import ( OrderCommands, - compute_engine_base_url, + compute_base_url, compute_login_url, get_eval_files, get_user_input, @@ -43,6 +46,8 @@ from arcade.cli.utils import ( validate_and_get_config, version_callback, ) +from arcade.cli.worker import parse_deployment_response +from arcade.worker.config.deployment import Deployment cli = typer.Typer( cls=OrderCommands, @@ -52,6 +57,9 @@ cli = typer.Typer( pretty_exceptions_show_locals=False, pretty_exceptions_short=True, ) + + +cli.add_typer(worker.app, name="worker", help="Manage workers") console = Console() @@ -225,7 +233,7 @@ def chat( ) config = validate_and_get_config() - base_url = compute_engine_base_url(force_tls, force_no_tls, host, port) + base_url = compute_base_url(force_tls, force_no_tls, host, port) client = Arcade(api_key=config.api.key, base_url=base_url) user_email = config.user.email if config.user else None @@ -352,7 +360,7 @@ def evals( config = validate_and_get_config() host = PROD_ENGINE_HOST if cloud else host - base_url = compute_engine_base_url(force_tls, force_no_tls, host, port) + base_url = compute_base_url(force_tls, force_no_tls, host, port) models_list = models.split(",") # Use 'models_list' to avoid shadowing @@ -515,6 +523,82 @@ def workerup( typer.Exit(code=1) +@cli.command(help="Deploy worker to Arcade Cloud", rich_help_panel="Deployment") +def deploy( + deployment_file: str = typer.Option( + "worker.toml", "--deployment-file", "-d", help="The deployment file to deploy." + ), + cloud_host: str = typer.Option( + PROD_CLOUD_HOST, + "--cloud-host", + "-c", + help="The Arcade Cloud host to deploy to.", + hidden=True, + ), + cloud_port: int = typer.Option( + None, + "--cloud-port", + "-cp", + help="The port of the Arcade Cloud host.", + hidden=True, + ), + host: str = typer.Option( + PROD_ENGINE_HOST, + "--host", + "-h", + help="The Arcade Engine host to register the worker to.", + ), + port: int = typer.Option( + None, + "--port", + "-p", + help="The port of the Arcade Engine host.", + ), + force_tls: bool = typer.Option( + False, + "--tls", + help="Whether to force TLS for the connection to the Arcade Engine. If not specified, the connection will use TLS if the engine URL uses a 'https' scheme.", + ), + force_no_tls: bool = typer.Option( + False, + "--no-tls", + help="Whether to disable TLS for the connection to the Arcade Engine.", + ), +) -> None: + """ + Deploy a worker to Arcade Cloud. + """ + + config = validate_and_get_config() + engine_url = compute_base_url(force_tls, force_no_tls, host, port) + engine_client = Arcade(api_key=config.api.key, base_url=engine_url) + cloud_url = compute_base_url(force_tls, force_no_tls, cloud_host, cloud_port) + cloud_client = httpx.Client( + base_url=cloud_url, headers={"Authorization": f"Bearer {config.api.key}"} + ) + + # Fetch deployment configuration + try: + deployment = Deployment.from_toml(Path(deployment_file)) + except Exception as e: + console.print(f"❌ Failed to parse deployment file: {e}", style="bold red") + raise typer.Exit(code=1) + + with console.status(f"Deploying {len(deployment.worker)} workers"): + for worker in deployment.worker: + console.log(f"Deploying '{worker.config.id}...'", style="dim") + try: + # Attempt to deploy worker + response = worker.request().execute(cloud_client, engine_client) + parse_deployment_response(response) + console.log(f"✅ Worker '{worker.config.id}' deployed successfully.", style="dim") + except Exception as e: + console.log( + f"❌ Failed to deploy worker '{worker.config.id}': {e}", style="bold red" + ) + raise typer.Exit(code=1) + + @cli.callback() def main_callback( ctx: typer.Context, @@ -527,7 +611,7 @@ def main_callback( help="Print version and exit.", ), ) -> None: - excluded_commands = {login.__name__, logout.__name__, workerup.__name__} + excluded_commands = {login.__name__, logout.__name__, serve.__name__} if ctx.invoked_subcommand in excluded_commands: return diff --git a/arcade/arcade/cli/new.py b/arcade/arcade/cli/new.py index 81170414..57e09211 100644 --- a/arcade/arcade/cli/new.py +++ b/arcade/arcade/cli/new.py @@ -9,6 +9,11 @@ import typer from jinja2 import Environment, FileSystemLoader, select_autoescape from rich.console import Console +from arcade.worker.config.deployment import ( + create_demo_deployment, + update_deployment_with_local_packages, +) + console = Console() # Retrieve the installed version of arcade-ai @@ -111,7 +116,7 @@ def create_new_toolkit(output_directory: str) -> None: "toolkit_description": toolkit_description, "toolkit_author_name": toolkit_author_name, "toolkit_author_email": toolkit_author_email, - "arcade_version": f"{ARCADE_VERSION.rsplit('.', 1)[0]}.*", + "arcade_version": f"^{ARCADE_VERSION}", "creation_year": datetime.now().year, } template_directory = Path(__file__).parent.parent / "templates" / "{{ toolkit_name }}" @@ -123,6 +128,15 @@ def create_new_toolkit(output_directory: str) -> None: try: create_package(env, template_directory, toolkit_directory, context) + create_deployment(toolkit_directory, toolkit_name) except Exception: remove_toolkit(toolkit_directory, toolkit_name) raise + + +def create_deployment(toolkit_directory: Path, toolkit_name: str) -> None: + worker_toml = toolkit_directory / "worker.toml" + if not worker_toml.exists(): + create_demo_deployment(worker_toml, toolkit_name) + else: + update_deployment_with_local_packages(worker_toml, toolkit_name) diff --git a/arcade/arcade/cli/utils.py b/arcade/arcade/cli/utils.py index 8e72c79e..3045843d 100644 --- a/arcade/arcade/cli/utils.py +++ b/arcade/arcade/cli/utils.py @@ -80,7 +80,7 @@ def create_cli_catalog( return catalog -def compute_engine_base_url( +def compute_base_url( force_tls: bool, force_no_tls: bool, host: str, @@ -193,7 +193,7 @@ def get_tools_from_engine( toolkit: str | None = None, ) -> list[ToolDefinition]: config = validate_and_get_config() - base_url = compute_engine_base_url(force_tls, force_no_tls, host, port) + base_url = compute_base_url(force_tls, force_no_tls, host, port) client = Arcade(api_key=config.api.key, base_url=base_url) tools = [] diff --git a/arcade/arcade/cli/worker.py b/arcade/arcade/cli/worker.py new file mode 100644 index 00000000..167f693b --- /dev/null +++ b/arcade/arcade/cli/worker.py @@ -0,0 +1,379 @@ +import httpx +import typer +from arcadepy import Arcade, NotFoundError +from rich.console import Console +from rich.table import Table + +from arcade.cli.constants import ( + PROD_CLOUD_HOST, + PROD_ENGINE_HOST, +) +from arcade.cli.utils import ( + OrderCommands, + compute_base_url, + validate_and_get_config, +) + +console = Console() + + +app = typer.Typer( + cls=OrderCommands, + add_completion=False, + no_args_is_help=True, + pretty_exceptions_enable=False, + pretty_exceptions_show_locals=False, + pretty_exceptions_short=True, +) + +state = { + "engine_url": compute_base_url( + host=PROD_ENGINE_HOST, port=None, force_tls=False, force_no_tls=False + ) +} + + +@app.callback() +def main( + host: str = typer.Option( + PROD_ENGINE_HOST, + "--host", + "-h", + help="The Arcade Engine host.", + ), + port: int = typer.Option( + None, + "--port", + "-p", + help="The port of the Arcade Engine host.", + ), + force_tls: bool = typer.Option( + False, + "--tls", + help="Whether to force TLS for the connection to the Arcade Engine.", + ), + force_no_tls: bool = typer.Option( + False, + "--no-tls", + help="Whether to disable TLS for the connection to the Arcade Engine.", + ), +) -> None: + """ + Manage users in the system. + """ + engine_url = compute_base_url(force_tls, force_no_tls, host, port) + state["engine_url"] = engine_url + + +@app.command("list", help="List all workers") +def list_workers( + cloud_host: str = typer.Option( + PROD_CLOUD_HOST, + "--cloud-host", + "-c", + help="The Arcade Engine host.", + hidden=True, + ), + cloud_port: int = typer.Option( + None, + "--cloud-port", + "-cp", + help="The port of the Arcade Engine host.", + hidden=True, + ), + force_tls: bool = typer.Option( + False, + "--tls", + help="Whether to force TLS for the connection to the Arcade Engine.", + hidden=True, + ), + force_no_tls: bool = typer.Option( + False, + "--no-tls", + help="Whether to disable TLS for the connection to the Arcade Engine.", + hidden=True, + ), +) -> None: + config = validate_and_get_config() + engine_url = state["engine_url"] + client = Arcade(api_key=config.api.key, base_url=engine_url) + deployments = [] + try: + cloud_url = compute_base_url(force_tls, force_no_tls, cloud_host, cloud_port) + cloud_client = httpx.Client(base_url=cloud_url) + response = cloud_client.get( + "/api/v1/workers", headers={"Authorization": f"Bearer {config.api.key}"} + ) + response.raise_for_status() + deployments = response.json()["data"]["workers"] + except Exception as e: + console.log(f"Failed to get cloud deployments: {e}") + + print_worker_table(client, deployments) + + +def print_worker_table(client: Arcade, deployments: list[dict]) -> None: + workers = client.workers.list() + if not workers.items: + console.print("No workers found", style="bold red") + return + + # Create and print a table of worker information + table = Table(title="Workers") + table.add_column("ID") + table.add_column("Cloud Deployed") + table.add_column("Engine Registered") + table.add_column("Enabled") + table.add_column("Host") + table.add_column("Toolkits") + + # Track workers that are registered in the engine + engine_workers = [] + for worker in workers.items: + if worker.id is None: + continue + engine_workers.append(worker.id) + # Check if the worker is deployed in the cloud + is_deployed = is_cloud_deployment(worker.id, deployments) + # Get the toolkits for the worker + + tools = get_toolkits(client, worker.id) + uri = worker.http.uri if worker.http and worker.http.uri else "" + table.add_row( + worker.id, + str(is_deployed), + str(True), + str(worker.enabled), + compare_endpoints(worker.id, uri, deployments), + "Could not fetch toolkits" if tools == "" else tools, + ) + for deployment in deployments: + if deployment["name"] not in engine_workers: + table.add_row(deployment["name"], "True", "False", "False", deployment["endpoint"], "") + console.print(table) + + +# Check if the worker is in the list of cloud deployments +def is_cloud_deployment(name: str, deployments: list[dict]) -> bool: + return any(deployment["name"] == name for deployment in deployments) + + +# Compare the endpoint of the worker in the engine to the endpoint in the cloud +# Return a highlighted diff if the endpoint in the engine is different from the endpoint in the cloud +def compare_endpoints(worker_id: str, engine_endpoint: str, deployments: list[dict]) -> str: + if is_cloud_deployment(worker_id, deployments): + for deployment in deployments: + deployment_endpoint = deployment["endpoint"] + if deployment_endpoint == engine_endpoint: + return engine_endpoint + return f"[red]Endpoint Mismatch[/red]\n[yellow]Registered Endpoint: {engine_endpoint}[/yellow]\n[green]Actual Endpoint: {deployment_endpoint}[/green]" + return engine_endpoint + + +def parse_deployment_response(response: dict) -> None: + # Check what changes were made to the worker and display + changes = response["data"]["changes"] + additions = changes.get("additions", []) + removals = changes.get("removals", []) + updates = changes.get("updates", []) + no_changes = changes.get("no_changes", []) + print_deployment_table(additions, removals, updates, no_changes) + + +def print_deployment_table( + additions: list, removals: list, updates: list, no_changes: list +) -> None: + table = Table(title="Changed Packages") + table.add_column("Added", justify="right", style="green") + table.add_column("Removed", justify="right", style="red") + table.add_column("Updated", justify="right", style="yellow") + table.add_column("No Changes", justify="right", style="dim") + max_rows = max(len(additions), len(removals), len(updates), len(no_changes)) + + # Add each row of worker package changes to the table + for i in range(max_rows): + addition = additions[i] if i < len(additions) else "" + removal = removals[i] if i < len(removals) else "" + update = updates[i] if i < len(updates) else "" + no_change = no_changes[i] if i < len(no_changes) else "" + table.add_row(addition, removal, update, no_change) + console.print(table) + + +@app.command("enable", help="Enable a worker") +def enable_worker( + worker_id: str, +) -> None: + config = validate_and_get_config() + engine_url = state["engine_url"] + arcade = Arcade(api_key=config.api.key, base_url=engine_url) + try: + arcade.workers.update(worker_id, enabled=True) + except Exception as e: + console.print(f"Error enabling worker: {e}", style="bold red") + raise typer.Exit(code=1) + + +@app.command("disable", help="Disable a worker") +def disable_worker( + worker_id: str, +) -> None: + config = validate_and_get_config() + engine_url = state["engine_url"] + arcade = Arcade(api_key=config.api.key, base_url=engine_url) + try: + arcade.workers.update(worker_id, enabled=False) + except Exception as e: + console.print(f"Error disabling worker: {e}", style="bold red") + raise typer.Exit(code=1) + + +@app.command("rm", help="Remove a worker") +def rm_worker( + worker_id: str, + engine_only: bool = typer.Option( + False, + "--deregister", + "-d", + help="Deregister the worker from the engine", + ), + cloud_host: str = typer.Option( + PROD_CLOUD_HOST, + "--cloud-host", + "-c", + help="The Arcade Engine host.", + hidden=True, + ), + cloud_port: int = typer.Option( + None, + "--cloud-port", + "-cp", + help="The port of the Arcade Engine host.", + hidden=True, + ), + force_tls: bool = typer.Option( + False, + "--tls", + help="Whether to force TLS for the connection to the Arcade Engine.", + hidden=True, + ), + force_no_tls: bool = typer.Option( + False, + "--no-tls", + help="Whether to disable TLS for the connection to the Arcade Engine.", + hidden=True, + ), +) -> None: + config = validate_and_get_config() + engine_url = state["engine_url"] + cloud_url = compute_base_url(force_tls, force_no_tls, cloud_host, cloud_port) + + # First attempt to delete from the cloud + if not engine_only: + try: + client = httpx.Client() + response = client.delete( + f"{cloud_url}/api/v1/workers/{worker_id}", + headers={"Authorization": f"Bearer {config.api.key}"}, + ) + response.raise_for_status() + except httpx.HTTPStatusError as e: + if e.response.status_code == 404: + console.print( + "Deployment not found. To deregister the worker from the engine, use the --deregister flag.", + style="bold red", + ) + raise typer.Exit(code=1) + else: + console.print(f"Error deleting deployment: {e}", style="bold red") + raise typer.Exit(code=1) + except Exception as e: + console.print(f"Error deleting deployment: {e}", style="bold red") + raise typer.Exit(code=1) + + # Then try to delete from the engine + try: + arcade = Arcade(api_key=config.api.key, base_url=engine_url) + arcade.workers.delete(worker_id) + except NotFoundError: + console.print("Worker not found", style="bold red") + except Exception as e: + console.print(f"Error deleting worker from engine: {e}", style="bold red") + raise typer.Exit(code=1) + + +@app.command("logs", help="Get logs for a worker") +def worker_logs( + worker_id: str, + cloud_host: str = typer.Option( + PROD_CLOUD_HOST, + "--cloud-host", + "-c", + help="The Arcade Engine host.", + hidden=True, + ), + cloud_port: int = typer.Option( + None, + "--cloud-port", + "-cp", + help="The port of the Arcade Engine host.", + hidden=True, + ), + force_tls: bool = typer.Option( + False, + "--tls", + help="Whether to force TLS for the connection to the Arcade Engine.", + hidden=True, + ), + force_no_tls: bool = typer.Option( + False, + "--no-tls", + help="Whether to disable TLS for the connection to the Arcade Engine.", + hidden=True, + ), +) -> None: + config = validate_and_get_config() + cloud_url = compute_base_url(force_tls, force_no_tls, cloud_host, cloud_port) + try: + with httpx.stream( + "GET", + f"{cloud_url}/api/v1/workers/logs/{worker_id}", + headers={"Authorization": f"Bearer {config.api.key}", "Accept": "text/event-stream"}, + # allow the connection to stay open indefinitely + timeout=None, # noqa: S113 + ) as s: + for line in s.iter_lines(): + if not line or "[DONE]" in line: # Skip empty lines + continue + if "event: error" in line: + console.print("Could not stream logs", style="bold red") + if line.startswith("data:"): + # Extract just the data portion after 'data:' + data = line[5:].strip() # Remove 'data:' prefix and whitespace + console.print(data, markup=False) + except Exception as e: + console.print(f"Error connecting to log stream: {e}", style="bold red") + raise typer.Exit(code=1) + + +def get_toolkits(client: Arcade, worker_id: str | None) -> str: + if worker_id is None: + return "" + try: + # Get tools for the given worker + tools = client.workers.tools(worker_id) + toolkits: list[str] = [] + if not tools.items: + return "" + + # Get toolkit names + for page in tools.iter_pages(): + for tool in page.items: + if tool.toolkit.name not in toolkits: + toolkits.append(tool.toolkit.name) + return ", ".join(toolkits) + except NotFoundError: + return "" + except Exception as e: + console.print(f"Error getting worker tools: {e}", style="bold red") + raise typer.Exit(code=1) diff --git a/arcade/arcade/worker/config/deployment.py b/arcade/arcade/worker/config/deployment.py new file mode 100644 index 00000000..38cebc6f --- /dev/null +++ b/arcade/arcade/worker/config/deployment.py @@ -0,0 +1,294 @@ +import base64 +import io +import os +import secrets +import tarfile +from pathlib import Path +from typing import Any + +import httpx +import toml +from arcadepy import Arcade, NotFoundError +from httpx import Client +from packaging.requirements import Requirement +from pydantic import BaseModel, field_validator, model_validator + + +# Base class for versioned packages +class Package(BaseModel): + name: str + specifier: str | None = None + + @classmethod + def from_requirement(cls, requirement_str: str) -> "Package": + req = Requirement(requirement_str) + return cls(name=req.name, specifier=str(req.specifier) if req.specifier else None) + + +# Base class for a list of packages +class Packages(BaseModel): + packages: list[Package] + + # Convert string package i.e. "arcade>1.0.0" to a name and specifier + # Specifiers are currently unused + @field_validator("packages", mode="before") + @classmethod + def parse_package_requirements(cls, packages: list[str]) -> list[Package]: + """Convert package requirement strings to Package objects.""" + return [Package.from_requirement(pkg) for pkg in packages] + + +# Base class for a local package +class LocalPackage(BaseModel): + name: str + content: str + + +# Base class for a list of local packages +class LocalPackages(BaseModel): + packages: list[str] + + +# Custom repository configurations +class PackageRepository(Packages): + index: str + index_url: str + trusted_host: str + + +# Pypi is a special case of a package repository +class Pypi(PackageRepository): + index: str = "pypi" + index_url: str = "https://pypi.org/simple" + trusted_host: str = "pypi.org" + + +class Config(BaseModel): + id: str + enabled: bool = True + timeout: int = 30 + retries: int = 3 + secret: str + + # Validate that the secret is a non-empty string and not 'dev' + @field_validator("secret") + @classmethod + def valid_secret(cls, v: str) -> str: + if v.strip("") == "" or v == "dev": + raise ValueError("Secret must be a non-empty string and not 'dev'") + return v + + +# Cloud request for deploying a worker +class Request(BaseModel): + name: str + secret: str + enabled: bool + timeout: int + retries: int + pypi: Pypi | None = None + custom_repositories: list[PackageRepository] | None = None + local_packages: list[LocalPackage] | None = None + + def execute(self, cloud_client: Client, engine_client: Arcade) -> Any: + # Attempt to deploy worker to the cloud + try: + cloud_response = cloud_client.put( + str(cloud_client.base_url) + "/api/v1/workers", + json=self.model_dump(mode="json"), + timeout=120, + ) + cloud_response.raise_for_status() + except httpx.ConnectError as e: + raise ValueError(f"Failed to connect to cloud: {e}") + except Exception: + msg = cloud_response.json().get("msg", f"{cloud_response.status_code}: Unknown error") + raise ValueError(f"Failed to start worker: {msg}") + + try: + # Check if worker already exists + engine_client.workers.get(self.name) + engine_client.workers.update( + id=self.name, + enabled=self.enabled, + http={ + "uri": cloud_response.json()["data"]["worker_endpoint"], + "secret": self.secret, + "timeout": self.timeout, + "retry": self.retries, + }, + ) + # If worker does not exist, create it + except NotFoundError: + engine_client.workers.create( + id=self.name, + enabled=self.enabled, + http={ + "uri": cloud_response.json()["data"]["worker_endpoint"], + "secret": self.secret, + "timeout": self.timeout, + "retry": self.retries, + }, + ) + + except Exception as e: + raise ValueError(f"Failed to add worker to engine: {e}") + + return cloud_response.json() + + +class Worker(BaseModel): + toml_path: Path + config: Config + pypi_source: Pypi | None = None + custom_source: list[PackageRepository] | None = None + local_source: LocalPackages | None = None + + def request(self) -> Request: + """Convert Deployment to a Request object.""" + self.validate_packages() + self.compress_local_packages() + return Request( + name=self.config.id, + secret=self.config.secret, + enabled=self.config.enabled, + timeout=self.config.timeout, + retries=self.config.retries, + pypi=self.pypi_source, + custom_repositories=self.custom_source, + local_packages=self.compress_local_packages(), + ) + + # Search for local packages and compress the source code to send + def compress_local_packages(self) -> list[LocalPackage] | None: + """Compress local packages into a list of LocalPackage objects.""" + if self.local_source is None: + return None + + # Compress local packages into a list of LocalPackage objects + def process_package(package_path_str: str) -> LocalPackage: + package_path = self.toml_path.parent / package_path_str + + if not package_path.exists(): + raise FileNotFoundError(f"Local package not found: {package_path}") + if not package_path.is_dir(): + raise FileNotFoundError(f"Local package is not a directory: {package_path}") + + # Check that the package is a valid python package + if ( + not (package_path / "pyproject.toml").is_file() + and not (package_path / "setup.py").is_file() + ): + raise ValueError( + f"package '{package_path}' must contain a pyproject.toml or setup.py file" + ) + + # Compress the package into a byte stream and tar + byte_stream = io.BytesIO() + with tarfile.open(fileobj=byte_stream, mode="w:gz") as tar: + tar.add(package_path, arcname=package_path.name) + + byte_stream.seek(0) + package_bytes = byte_stream.read() + package_bytes_b64 = base64.b64encode(package_bytes).decode("utf-8") + + return LocalPackage(name=package_path.name, content=package_bytes_b64) + + return list(map(process_package, self.local_source.packages)) + + # Validate that there are no duplicate packages for each worker + def validate_packages(self) -> None: + """Validate packages.""" + packages: list[str] = [] + if self.pypi_source: + for pypi_package in self.pypi_source.packages: + packages.append(pypi_package.name) + if self.custom_source: + for repository in self.custom_source: + for package in repository.packages: + packages.append(package.name) + if self.local_source: + for local_package in self.local_source.packages: + packages.append(os.path.normpath(Path(local_package))) + dupes = [x for n, x in enumerate(packages) if x in packages[:n]] + if dupes: + raise ValueError(f"Duplicate packages: {dupes}") + + +class Deployment(BaseModel): + toml_path: Path + worker: list[Worker] + + # Validate that there are no duplicate worker names + @model_validator(mode="after") + def validate_workers(self) -> "Deployment": + for worker in self.worker: + if sum(worker.config.id == w.config.id for w in self.worker) > 1: + raise ValueError(f"Duplicate worker name: {worker.config.id}") + return self + + # Load a deployment from a toml file + @classmethod + def from_toml(cls, toml_path: Path) -> "Deployment": + try: + with open(toml_path) as f: + toml_data = toml.load(f) + + if not toml_data: + raise ValueError(f"Empty TOML file: {toml_path}") + + # Add the toml path to each worker so relative packages can be found + if "worker" in toml_data: + for worker in toml_data["worker"]: + worker["toml_path"] = toml_path + + return cls(**toml_data, toml_path=toml_path) + + except toml.TomlDecodeError as e: + raise ValueError(f"Invalid TOML format in {toml_path}: {e!s}") + except FileNotFoundError: + raise FileNotFoundError(f"Config file not found: {toml_path}") + + # Save the deployment to a toml file + def save(self) -> None: + print("writing deployment file", self.toml_path) + with open(self.toml_path, "w") as f: + data = self.model_dump() + # Remove the toml_path from the deployment file + del data["toml_path"] + for worker in data["worker"]: + del worker["toml_path"] + toml.dump(data, f) + + +# Create a demo deployment file with one worker +def create_demo_deployment(toml_path: Path, toolkit_name: str) -> None: + """Create a deployment from a toml file.""" + deployment = Deployment( + toml_path=toml_path, + worker=[ + Worker( + toml_path=toml_path, + config=Config( + id="demo-worker", + enabled=True, + timeout=30, + retries=3, + secret=secrets.token_hex(16), + ), + local_source=LocalPackages(packages=[f"./{toolkit_name}"]), + ) + ], + ) + deployment.save() + + +# Get a currently existing deployment and add an additional local package +def update_deployment_with_local_packages(toml_path: Path, toolkit_name: str) -> None: + """Update a deployment from a toml file.""" + deployment = Deployment.from_toml(toml_path) + if deployment.worker[0].local_source is None: + deployment.worker[0].local_source = LocalPackages(packages=[f"./{toolkit_name}"]) + else: + deployment.worker[0].local_source.packages.append(f"./{toolkit_name}") + deployment.save() diff --git a/arcade/pyproject.toml b/arcade/pyproject.toml index fc0a58e3..c043c51d 100644 --- a/arcade/pyproject.toml +++ b/arcade/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "arcade-ai" -version = "1.0.5" +version = "1.1.0" description = "Arcade Python SDK and CLI" readme = "README.md" packages = [ @@ -21,12 +21,14 @@ rich = "^13.7.1" Jinja2 = ">=3.1.5,<4.0.0" pyyaml = "^6.0" openai = "^1.36.0" # TODO: relax to an earlier version that still has what we need -arcadepy = "^1.0.0" +arcadepy = "^1.3.1" pyjwt = "^2.8.0" loguru = "^0.7.0" tqdm = "^4.1.0" +toml = "^0.10.2" types-python-dateutil = "2.9.0.20241003" types-pytz = "2024.2.0.20241003" +types-toml = "0.10.8.20240310" opentelemetry-instrumentation-fastapi = "0.48b0" opentelemetry-exporter-otlp-proto-http = "1.27.0" opentelemetry-exporter-otlp-proto-common = "1.27.0" @@ -42,6 +44,7 @@ pyreadline3 = {version = "^3.5.4", platform = "win32"} [tool.poetry.extras] evals = ["scipy", "numpy", "scikit-learn", "pytz", "python-dateutil"] + [tool.poetry.group.dev.dependencies] pytest = "^8.1.2" pytest-cov = "^4.0.0" diff --git a/arcade/tests/cli/test_utils.py b/arcade/tests/cli/test_utils.py index 48fa0425..9328409a 100644 --- a/arcade/tests/cli/test_utils.py +++ b/arcade/tests/cli/test_utils.py @@ -1,6 +1,6 @@ import pytest -from arcade.cli.utils import compute_engine_base_url, compute_login_url +from arcade.cli.utils import compute_base_url, compute_login_url DEFAULT_CLOUD_HOST = "cloud.arcade.dev" DEFAULT_ENGINE_HOST = "api.arcade.dev" @@ -186,7 +186,7 @@ DEFAULT_FORCE_NO_TLS = False ], ) def test_compute_base_url(inputs: dict, expected_output: str): - base_url = compute_engine_base_url( + base_url = compute_base_url( inputs["force_tls"], inputs["force_no_tls"], inputs["host_input"], diff --git a/arcade/tests/deployment/test_config.py b/arcade/tests/deployment/test_config.py new file mode 100644 index 00000000..923626ff --- /dev/null +++ b/arcade/tests/deployment/test_config.py @@ -0,0 +1,227 @@ +# Ignore hardcoded secret linting +# ruff: noqa: S105 +# ruff: noqa: S106 +import json +from pathlib import Path + +import pytest + +from arcade.worker.config.deployment import ( + Config, + Deployment, + LocalPackages, + Package, + PackageRepository, + Pypi, + Worker, +) + + +@pytest.fixture +def test_dir(): + return Path(__file__).parent + + +def test_invalid_toml_path(test_dir): + with pytest.raises(FileNotFoundError): + Deployment.from_toml(test_dir / "test_files" / "invalid.toml") + + +def test_missing_fields(test_dir): + with pytest.raises(ValueError): + Deployment.from_toml(test_dir / "test_files" / "invalid.fields.worker.toml") + + +def test_deployment_parsing(test_dir): + config_path = test_dir / "test_files" / "full.worker.toml" + deployment = Deployment.from_toml(config_path) + + # Test config section + assert deployment.worker[0].config.id == "test" + assert deployment.worker[0].config.enabled is True + assert deployment.worker[0].config.timeout == 10 + assert deployment.worker[0].config.retries == 3 + assert deployment.worker[0].config.secret == "test-secret" + + # Test pypi section + assert deployment.worker[0].pypi_source.packages == [Package(name="arcade-x")] + + # Test local_packages section + assert deployment.worker[0].local_source.packages == ["./mock_toolkit"] + + # Test custom_repositories section + repo = deployment.worker[0].custom_source[0] + assert repo.index == "pypi" + assert repo.index_url == "https://pypi.org/simple" + assert repo.trusted_host == "pypi.org" + assert repo.packages == [Package(name="arcade-ai", specifier=">=1.0.0")] + + repo = deployment.worker[0].custom_source[1] + assert repo.index == "pypi2" + assert repo.index_url == "https://pypi2.org/simple" + assert repo.trusted_host == "pypi2.org" + assert repo.packages == [Package(name="arcade-slack")] + + +def test_specifier(): + from packaging.requirements import Requirement + + req = Requirement("arcade-ai>=1.0.0") + assert req.name == "arcade-ai" + assert req.specifier == ">=1.0.0" + + +def test_deployment_dict(test_dir): + config_path = test_dir / "test_files" / "full.worker.toml" + deployment = Deployment.from_toml(config_path) + expected = json.loads("""{ + "name": "test", + "secret": "test-secret", + "enabled": true, + "timeout": 10, + "retries": 3, + "pypi": { + "packages": [ + { + "name": "arcade-x", + "specifier": null + } + ], + "index": "pypi", + "index_url": "https://pypi.org/simple", + "trusted_host": "pypi.org" + }, + "custom_repositories": [ + { + "packages": [ + { + "name": "arcade-ai", + "specifier": ">=1.0.0" + } + ], + "index": "pypi", + "index_url": "https://pypi.org/simple", + "trusted_host": "pypi.org" + }, + { + "packages": [ + { + "name": "arcade-slack", + "specifier": null + } + ], + "index": "pypi2", + "index_url": "https://pypi2.org/simple", + "trusted_host": "pypi2.org" + } + ], + "local_packages": [ + { + "name": "mock_toolkit", + "content": "H4sIAOgdymcC/+2XwWuDMBTGPftXZDltMNIkJtrCOrpbL4PdSxmiKXNVIzHt6n+/OAvtNrqbMur7Xd7j5YGH5Ps+JBMyWbzEh6WKU2W8XqAdlyqlgTj17ZxRzriHDt4A7GobG/d5b5zwKSpsVqg5iwRjs6kMBJEzMYtC7nvA1VPoZPtqtc63mZ14/ek/krKrYVcp/655JtyLY4wHNHL6D5iMPCSH1H+dGtX84YBubbO5vvsn4P/g/+f+LyihPBSUSfD/sfl/EWclqZo+9B8Kcdn/eXTyf+bmTAjp9E+H1P9I/b8yWWlv8VLlub5HH9rk6Q2+A+mPhf+R/8Hv/GeQ/4Pkf/Qj/3lEpAgCOQUPGF3+V01l9LtKLLG6yAfLf07F2f9fq/+QhhTyfwhW7d2TSitrmrVfxoVCc4TPXwX298rUmS7bA0oYodhPVZ2YrLLH6bNbR8d1tNEGPZnExQn2451906Z2OyvczdBDqvaL+Ksnrn3EazAaAAAAAAAAAAAAAAAAAOiJT7MTVu0AKAAA" + } + ] +}""") + got = deployment.worker[0].request().model_dump(mode="json") + print(got) + # Remove encoding part that contains the content + got["local_packages"][0].pop("content") + expected["local_packages"][0].pop("content") + + assert got == expected + + +def test_invalid_secret_parsing(test_dir): + config_path = test_dir / "test_files" / "invalid.secret.worker.toml" + with pytest.raises(ValueError): + Deployment.from_toml(config_path) + + +def test_missing_local_package(test_dir): + config_path = test_dir / "test_files" / "invalid.localfile.worker.toml" + deployment = Deployment.from_toml(config_path) + with pytest.raises(FileNotFoundError): + deployment.worker[0].request() + + +def test_invalid_local_package(test_dir): + config_path = test_dir / "test_files" / "invalid.localfile.worker.toml" + deployment = Deployment.from_toml(config_path) + with pytest.raises(FileNotFoundError): + deployment.worker[1].request() + + +def test_unconfigured_local_package(test_dir): + config_path = test_dir / "test_files" / "invalid.localfile.worker.toml" + deployment = Deployment.from_toml(config_path) + with pytest.raises(ValueError): + deployment.worker[2].request() + + +def test_duplicate_pypi_packages(): + worker = Worker( + toml_path=Path(__file__), + config=Config(id="test", secret="test-secret"), + pypi_source=Pypi(packages=["arcade-slack", "arcade-slack"]), + ) + with pytest.raises(ValueError): + worker.validate_packages() + + +def test_duplicate_custom_repository_packages(): + worker = Worker( + toml_path=Path(__file__), + config=Config(id="test", secret="test-secret"), + custom_source=[ + PackageRepository( + index="pypi", + index_url="https://pypi.org/simple", + trusted_host="pypi.org", + packages=["arcade-slack", "arcade-slack"], + ) + ], + ) + with pytest.raises(ValueError): + worker.validate_packages() + + +def test_duplicate_local_packages(): + worker = Worker( + toml_path=Path(__file__), + config=Config(id="test", secret="test-secret"), + local_source=LocalPackages(packages=["./mock_toolkit", "./mock_toolkit"]), + ) + with pytest.raises(ValueError): + worker.validate_packages() + + +def test_duplicate_all_typed_packages(): + worker = Worker( + toml_path=Path(__file__), + config=Config(id="test", secret="test-secret"), + pypi_source=Pypi(packages=["arcade-slack"]), + custom_source=[ + PackageRepository( + index="pypi", + index_url="https://pypi.org/simple", + trusted_host="pypi.org", + packages=["arcade-slack", "arcade-x"], + ) + ], + local_source=LocalPackages(packages=["./arcade-x"]), + ) + with pytest.raises(ValueError): + worker.validate_packages() + + +def test_duplicate_worker_names(): + worker = Worker( + toml_path=Path(__file__), + config=Config(id="test", secret="test-secret"), + ) + worker2 = Worker( + toml_path=Path(__file__), + config=Config(id="test", secret="test-secret"), + ) + with pytest.raises(ValueError): + Deployment(workers=[worker, worker2]) diff --git a/arcade/tests/deployment/test_files/full.worker.toml b/arcade/tests/deployment/test_files/full.worker.toml new file mode 100644 index 00000000..095dbdee --- /dev/null +++ b/arcade/tests/deployment/test_files/full.worker.toml @@ -0,0 +1,26 @@ + +[[worker]] +[worker.config] +id = "test" +enabled = true +timeout = 10 +retries = 3 +secret = "test-secret" + +[worker.pypi_source] +packages = ["arcade-x"] + +[worker.local_source] +packages = ["./mock_toolkit"] + +[[worker.custom_source]] +index = "pypi" +index_url = "https://pypi.org/simple" +trusted_host = "pypi.org" +packages = ["arcade-ai>=1.0.0"] + +[[worker.custom_source]] +index = "pypi2" +index_url = "https://pypi2.org/simple" +trusted_host = "pypi2.org" +packages = ["arcade-slack"] diff --git a/arcade/tests/deployment/test_files/invalid.fields.worker.toml b/arcade/tests/deployment/test_files/invalid.fields.worker.toml new file mode 100644 index 00000000..bdecbcc0 --- /dev/null +++ b/arcade/tests/deployment/test_files/invalid.fields.worker.toml @@ -0,0 +1,3 @@ + +[[worker]] +[worker.config] diff --git a/arcade/tests/deployment/test_files/invalid.localfile.worker.toml b/arcade/tests/deployment/test_files/invalid.localfile.worker.toml new file mode 100644 index 00000000..1397eea3 --- /dev/null +++ b/arcade/tests/deployment/test_files/invalid.localfile.worker.toml @@ -0,0 +1,42 @@ + +[[worker]] +[worker.config] +id = "test" +enabled = true +timeout = 10 +retries = 3 +secret = "test-secret" + +[worker.pypi_source] +packages = ["arcade-ai"] + +[worker.local_source] +packages = ["./missing_toolkit"] + +[[worker]] +[worker.config] +id = "test-2" +enabled = true +timeout = 10 +retries = 3 +secret = "test-secret" + +[worker.pypi_source] +packages = ["arcade-ai"] + +[worker.local_source] +packages = ["./invalid.localfile.worker.toml"] + +[[worker]] +[worker.config] +id = "test-3" +enabled = true +timeout = 10 +retries = 3 +secret = "test-secret" + +[worker.pypi_source] +packages = ["arcade-ai"] + +[worker.local_source] +packages = ["./invalid_toolkit"] diff --git a/arcade/tests/deployment/test_files/invalid.secret.worker.toml b/arcade/tests/deployment/test_files/invalid.secret.worker.toml new file mode 100644 index 00000000..4b6697c4 --- /dev/null +++ b/arcade/tests/deployment/test_files/invalid.secret.worker.toml @@ -0,0 +1,7 @@ +[[worker]] +[worker.config] +id = "test" +enabled = true +timeout = 10 +retries = 3 +secret = "dev" diff --git a/arcade/tests/deployment/test_files/invalid_toolkit/invalid_main.py b/arcade/tests/deployment/test_files/invalid_toolkit/invalid_main.py new file mode 100644 index 00000000..e69de29b diff --git a/arcade/tests/deployment/test_files/mock_toolkit/mock_main.py b/arcade/tests/deployment/test_files/mock_toolkit/mock_main.py new file mode 100644 index 00000000..f7cf60e1 --- /dev/null +++ b/arcade/tests/deployment/test_files/mock_toolkit/mock_main.py @@ -0,0 +1 @@ +print("Hello, world!") diff --git a/arcade/tests/deployment/test_files/mock_toolkit/pyproject.toml b/arcade/tests/deployment/test_files/mock_toolkit/pyproject.toml new file mode 100644 index 00000000..db164e63 --- /dev/null +++ b/arcade/tests/deployment/test_files/mock_toolkit/pyproject.toml @@ -0,0 +1,5 @@ +[tool.poetry] +name = "mock_toolkit" +version = "0.1.0" +description = "Mock toolkit for Arcade" +authors = ["Arcade "] diff --git a/worker.toml b/worker.toml new file mode 100644 index 00000000..bbe742f7 --- /dev/null +++ b/worker.toml @@ -0,0 +1,14 @@ +### Worker 1 +[[worker]] +[worker.config] +id = "worker-1" +enabled = true +timeout = 10 +retries = 3 +secret = "test-secret" + +[worker.pypi_source] +packages = ["arcade-slack", "arcade-x", "arcade-github"] + +[worker.local_source] +packages = ["./toolkits/spotify"]