import base64
import io
import os
import re
import secrets
import tarfile
import time
from collections import Counter
from pathlib import Path
from typing import Any

import toml
from arcade_core import Toolkit
from arcade_core.toolkit import Validate
from arcadepy import Arcade, NotFoundError
from httpx import Client, ConnectError, HTTPStatusError, TimeoutException
from packaging.requirements import Requirement
from pydantic import BaseModel, field_serializer, field_validator, model_validator
from rich.console import Console
from rich.table import Table

console = Console()


class Package(BaseModel):
    """Base class for versioned packages: a name plus an optional version specifier."""

    name: str
    specifier: str | None = None

    @classmethod
    def from_requirement(cls, requirement_str: str) -> "Package":
        """Build a Package from a requirement string such as ``"arcade>1.0.0"``."""
        req = Requirement(requirement_str)
        return cls(name=req.name, specifier=str(req.specifier) if req.specifier else None)


class Packages(BaseModel):
    """Base class for a list of packages."""

    packages: list[Package]

    # Convert string package i.e. "arcade>1.0.0" to a name and specifier
    # Specifiers are currently unused
    @field_validator("packages", mode="before")
    @classmethod
    def parse_package_requirements(cls, packages: list[Any]) -> list[Any]:
        """Convert package requirement strings to Package objects.

        Entries that are not strings (Package instances, or the dict form that
        ``model_dump`` emits and ``Deployment.save`` writes) are passed through
        unchanged so pydantic validates them as ``Package`` itself. Without this,
        a saved deployment could not be loaded again.
        """
        return [
            Package.from_requirement(pkg) if isinstance(pkg, str) else pkg for pkg in packages
        ]


class LocalPackage(BaseModel):
    """Base class for a local package: its name and its encoded content."""

    name: str
    content: str


class LocalPackages(BaseModel):
    """Base class for a list of local package paths."""

    packages: list[str]


class PackageRepository(Packages):
    """Custom repository configuration."""

    index: str
    index_url: str
    trusted_host: str


class Pypi(PackageRepository):
    """PyPI is a special case of a package repository with preset defaults."""

    index: str = "pypi"
    index_url: str = "https://pypi.org/simple"
    trusted_host: str = "pypi.org"


class Secret(BaseModel):
    """A secret value, plus the environment variable name it was read from (if any)."""

    value: str
    pattern: str | None = None


class Config(BaseModel):
    """Worker configuration."""

    id: str
    enabled: bool = True
    timeout: int = 30
    retries: int = 3
    secret: Secret | None = None

    # Validate and parse the secret if required
    @field_validator("secret", mode="before")
    @classmethod
    def valid_secret(cls, v: str | Secret | None) -> Secret:
        """Resolve and validate the secret.

        Raises:
            TypeError: If ``v`` is neither a string nor a Secret.
            ValueError: If the secret value is blank or equals ``"dev"``.
        """
        # If the secret is a string, attempt to parse it as an environment
        # variable reference or use it as the literal secret
        if isinstance(v, str):
            secret = get_env_secret(v)
        # If the secret has been manually set, use it as is
        elif isinstance(v, Secret):
            secret = v
        else:
            raise TypeError("Secret must be a string or a Secret object")

        # Check that the secret is not the default dev secret or empty
        if secret.value.strip() == "" or secret.value == "dev":
            raise ValueError("Secret must be a non-empty string and not 'dev'")
        return secret

    @field_serializer("secret")
    def serialize_secret(self, secret: Secret | None) -> str | None:
        """Serialize the secret back to the form it is read from.

        A secret that came from an environment variable is written as
        ``${env:NAME}``, the syntax ``get_env_secret`` parses, so the real value
        is not written out and the reference survives a save/load round trip.
        """
        # The field is optional, so the serializer can be handed None
        if secret is None:
            return None
        if secret.pattern:
            return f"${{env:{secret.pattern}}}"
        return secret.value


def _error_detail(error: HTTPStatusError) -> Any:
    """Return the JSON body of a failed response, or its raw text if it is not JSON."""
    try:
        return error.response.json()
    except ValueError:
        return error.response.text


class Request(BaseModel):
    """Cloud request for deploying a worker."""

    name: str
    secret: Secret
    enabled: bool
    timeout: int
    retries: int
    pypi: Pypi | None = None
    custom_repositories: list[PackageRepository] | None = None
    local_packages: list[LocalPackage] | None = None
    wait: bool = False

    @field_serializer("secret")
    def serialize_secret(self, secret: Secret) -> str:
        """Serialize the secret as its plain value."""
        return secret.value

    def poll_worker_status(self, cloud_client: Client, worker_name: str) -> Any:
        """Poll the cloud until the worker is running and return its data.

        Raises:
            ValueError: If the cloud cannot be reached, the request fails, or the
                worker reports the ``Failed`` status.
        """
        url = f"{cloud_client.base_url}/api/v1/workers/{worker_name}?wait_for_completion=true"
        while True:
            try:
                worker_resp = cloud_client.get(url, timeout=10)
                worker_resp.raise_for_status()
            except TimeoutException:
                time.sleep(1)
                continue
            except ConnectError as e:
                raise ValueError(f"Failed to connect to cloud: {e}") from e
            except HTTPStatusError as e:
                raise ValueError(f"Failed to start worker: {_error_detail(e)}") from e
            except Exception as e:
                raise ValueError(f"Failed to start worker: {e}") from e

            data = worker_resp.json()["data"]
            status = data["status"]
            if status == "Running":
                return data
            if status == "Failed":
                raise ValueError(f"Worker failed to start: {data.get('error')}")
            # Neither running nor failed yet: pause so an immediate response
            # does not turn this loop into a busy loop against the API
            time.sleep(1)

    def execute(self, cloud_client: Client, engine_client: Arcade) -> Any:
        """Deploy the worker to the cloud, then register it with the engine.

        Returns:
            The parsed JSON body of the cloud deployment response.

        Raises:
            ValueError: If the deployment, the status polling, or the engine
                registration fails.
        """
        # Attempt to deploy worker to the cloud
        try:
            cloud_response = cloud_client.put(
                str(cloud_client.base_url) + "/api/v1/workers",
                json=self.model_dump(mode="json"),
                timeout=360,
            )
            cloud_response.raise_for_status()
        except ConnectError as e:
            raise ValueError(f"Failed to connect to cloud: {e}") from e
        except HTTPStatusError as e:
            raise ValueError(f"Failed to start worker: {_error_detail(e)}") from e
        except Exception as e:
            # TODO: change this so it handles errors that aren't just from cloud
            raise ValueError(f"Failed to start worker: {e}") from e

        deployment_result = cloud_response.json()
        parse_deployment_response(deployment_result)
        worker_data = self.poll_worker_status(cloud_client, self.name)

        try:
            http_config = {
                "uri": worker_data["endpoint"],
                "secret": self.secret.value,
                "timeout": self.timeout,
                "retry": self.retries,
            }
            try:
                # Check if worker already exists
                engine_client.workers.get(self.name)
            except NotFoundError:
                # If worker does not exist, create it
                engine_client.workers.create(
                    id=self.name, enabled=self.enabled, http=http_config
                )
            else:
                engine_client.workers.update(
                    id=self.name, enabled=self.enabled, http=http_config
                )
        except Exception as e:
            # The outer try also covers create(), so a failed creation is
            # reported the same way as a failed update
            raise ValueError(f"Failed to add worker to engine: {e}") from e
        return deployment_result


class Worker(BaseModel):
    """A worker definition: its config and the package sources to install."""

    toml_path: Path
    config: Config
    pypi_source: Pypi | None = None
    custom_source: list[PackageRepository] | None = None
    local_source: LocalPackages | None = None

    def request(self) -> Request:
        """Convert this worker to a Request object.

        Raises:
            ValueError: If packages are duplicated or no secret is configured.
        """
        self.validate_packages()
        # Compress once and reuse the result; each call re-validates and
        # re-archives every local package
        local_packages = self.compress_local_packages()
        if self.config.secret is None:
            raise ValueError("Secret is required")
        return Request(
            name=self.config.id,
            secret=self.config.secret,
            enabled=self.config.enabled,
            timeout=self.config.timeout,
            retries=self.config.retries,
            pypi=self.pypi_source,
            custom_repositories=self.custom_source,
            local_packages=local_packages,
        )

    # Search for local packages and compress the source code to send
    def compress_local_packages(self) -> list[LocalPackage] | None:
        """Compress local packages into a list of LocalPackage objects.

        Package paths are resolved relative to the directory of ``toml_path``.
        Returns None when the worker has no local source.

        Raises:
            FileNotFoundError: If a package path is missing or is not a directory.
            ValueError: If a package has neither pyproject.toml nor setup.py.
        """
        if self.local_source is None:
            return None

        def exclude_filter(tarinfo: tarfile.TarInfo) -> tarfile.TarInfo | None:
            """Filter for files/directories to exclude from the compressed package"""
            if not Validate.path(tarinfo.name):
                return None
            return tarinfo

        def process_package(package_path_str: str) -> LocalPackage:
            """Validate one local package and return it as a base64 gzipped tarball."""
            package_path = self.toml_path.parent / package_path_str

            if not package_path.exists():
                raise FileNotFoundError(f"Local package not found: {package_path}")
            if not package_path.is_dir():
                raise FileNotFoundError(f"Local package is not a directory: {package_path}")

            # Check that the package is a valid python package
            has_pyproject = (package_path / "pyproject.toml").is_file()
            has_setup = (package_path / "setup.py").is_file()
            if not has_pyproject and not has_setup:
                raise ValueError(
                    f"package '{package_path}' must contain a pyproject.toml or setup.py file"
                )

            # Validate that we are able to load the package
            Toolkit.tools_from_directory(package_dir=package_path, package_name=package_path.name)

            # Compress the package into an in-memory gzipped tarball
            byte_stream = io.BytesIO()
            with tarfile.open(fileobj=byte_stream, mode="w:gz") as tar:
                tar.add(package_path, arcname=package_path.name, filter=exclude_filter)

            package_bytes_b64 = base64.b64encode(byte_stream.getvalue()).decode("utf-8")
            return LocalPackage(name=package_path.name, content=package_bytes_b64)

        return [process_package(path) for path in self.local_source.packages]

    # Validate that there are no duplicate packages for each worker
    def validate_packages(self) -> None:
        """Validate that no package appears more than once across all sources.

        Raises:
            ValueError: Listing every repeated occurrence of a package.
        """
        names: list[str] = []
        if self.pypi_source:
            names.extend(pkg.name for pkg in self.pypi_source.packages)
        if self.custom_source:
            names.extend(pkg.name for repo in self.custom_source for pkg in repo.packages)
        if self.local_source:
            # Normalize so "./pkg" and "pkg" count as the same package
            names.extend(os.path.normpath(Path(path)) for path in self.local_source.packages)

        seen: set[str] = set()
        dupes: list[str] = []
        for name in names:
            if name in seen:
                dupes.append(name)
            else:
                seen.add(name)
        if dupes:
            raise ValueError(f"Duplicate packages: {dupes}")


class Deployment(BaseModel):
    """A deployment file: its path and the workers it defines."""

    toml_path: Path
    worker: list[Worker]

    # Validate that there are no duplicate worker names
    @model_validator(mode="after")
    def validate_workers(self) -> "Deployment":
        """Reject deployments in which two workers share the same id."""
        id_counts = Counter(worker.config.id for worker in self.worker)
        for worker in self.worker:
            if id_counts[worker.config.id] > 1:
                raise ValueError(f"Duplicate worker name: {worker.config.id}")
        return self

    # Load a deployment from a toml file
    @classmethod
    def from_toml(cls, toml_path: Path) -> "Deployment":
        """Load a deployment from a TOML file.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the file is empty or is not valid TOML.
        """
        try:
            # TOML documents are UTF-8 by specification
            with open(toml_path, encoding="utf-8") as f:
                toml_data = toml.load(f)
        except toml.TomlDecodeError as e:
            raise ValueError(f"Invalid TOML format in {toml_path}: {e!s}") from e
        except FileNotFoundError as e:
            raise FileNotFoundError(f"Config file not found: {toml_path}") from e

        if not toml_data:
            raise ValueError(f"Empty TOML file: {toml_path}")

        # Add the toml path to each worker so relative packages can be found
        for worker in toml_data.get("worker", []):
            worker["toml_path"] = toml_path

        return cls(**toml_data, toml_path=toml_path)

    # Save the deployment to a toml file
    def save(self) -> None:
        """Write the deployment to ``self.toml_path``."""
        print("writing deployment file", self.toml_path)
        # Serialize before opening the file: opening in "w" mode truncates it,
        # so a serialization error must not wipe an existing deployment file
        data = self.model_dump()

        # Remove the toml_path from the deployment file
        data.pop("toml_path")
        for worker in data["worker"]:
            worker.pop("toml_path")

        with open(self.toml_path, "w", encoding="utf-8") as f:
            toml.dump(data, f)


# Create a demo deployment file with one worker
def create_demo_deployment(toml_path: Path, toolkit_name: str) -> None:
    """Create a demo deployment with a single worker and save it to ``toml_path``.

    The worker is named ``demo-worker``, gets a freshly generated random secret,
    and uses ``./<toolkit_name>`` as its only local package.
    """
    deployment = Deployment(
        toml_path=toml_path,
        worker=[
            Worker(
                toml_path=toml_path,
                config=Config(
                    id="demo-worker",
                    enabled=True,
                    timeout=30,
                    retries=3,
                    secret=Secret(value=secrets.token_hex(16), pattern=None),
                ),
                local_source=LocalPackages(packages=[f"./{toolkit_name}"]),
            )
        ],
    )
    deployment.save()


# Get a currently existing
# deployment and add an additional local package
def update_deployment_with_local_packages(toml_path: Path, toolkit_name: str) -> None:
    """Add ``./<toolkit_name>`` to the first worker's local packages and save the file.

    Raises:
        ValueError: If the deployment defines no workers.
    """
    deployment = Deployment.from_toml(toml_path)
    if not deployment.worker:
        raise ValueError(f"No workers defined in deployment file: {toml_path}")

    worker = deployment.worker[0]
    package_path = f"./{toolkit_name}"
    if worker.local_source is None:
        worker.local_source = LocalPackages(packages=[package_path])
    elif package_path not in worker.local_source.packages:
        # Skip entries that are already listed; a repeated path would be
        # rejected as a duplicate package when the worker is deployed
        worker.local_source.packages.append(package_path)
    deployment.save()


def get_env_secret(secret: str) -> Secret:
    """Parse a secret from an environment variable.

    A string containing exactly one ``${env:NAME}`` reference resolves to the
    value of the environment variable ``NAME``. A string with no reference is
    used as the literal secret.

    Raises:
        ValueError: If the string holds more than one reference, or the
            referenced variable is unset or empty.
    """
    # Check if the secret contains the "${env:}" syntax
    pattern = r"\${env:([^}]+)}"
    matches = re.findall(pattern, secret)

    # If no matches are found, return the secret as is
    if not matches:
        return Secret(value=secret, pattern=None)

    # Only allow a single match
    if len(matches) > 1:
        raise ValueError(f"Multiple environment variables found in secret: {secret}")

    # Attempt to lookup and create the secret
    env_name = matches[0].strip()
    print(f"Looking up secret: {env_name}")
    value = os.getenv(env_name)
    if not value:
        raise ValueError(f"Environment variable not found: {env_name}")
    return Secret(value=value, pattern=env_name)


def parse_deployment_response(response: dict) -> None:
    """Display which packages a deployment response reports as changed.

    Reads ``response["data"]["changes"]`` and prints its ``additions``,
    ``removals``, ``updates`` and ``no_changes`` lists as a table.
    """
    # Check what changes were made to the worker and display
    changes = response["data"]["changes"]
    # "or []" also covers a key that is present with a JSON null value,
    # which .get(key, []) would return as None
    additions = changes.get("additions") or []
    removals = changes.get("removals") or []
    updates = changes.get("updates") or []
    no_changes = changes.get("no_changes") or []
    print_deployment_table(additions, removals, updates, no_changes)


def print_deployment_table(
    additions: list, removals: list, updates: list, no_changes: list
) -> None:
    """Print a four-column table of package changes, one list per column."""
    from itertools import zip_longest

    table = Table(title="Changed Packages")
    table.add_column("Adding", justify="right", style="green")
    table.add_column("Removing", justify="right", style="red")
    table.add_column("Updating", justify="right", style="yellow")
    table.add_column("No Changes", justify="right", style="dim")

    # Add each row of worker package changes to the table; shorter columns
    # are padded with empty cells
    for row in zip_longest(additions, removals, updates, no_changes, fillvalue=""):
        # Stringify so a non-string entry is still renderable as a cell
        table.add_row(*(str(cell) for cell in row))
    console.print(table)