import base64 import io import os import re import secrets import tarfile import time from pathlib import Path from typing import Any import toml from arcadepy import Arcade, NotFoundError from httpx import Client, ConnectError, HTTPStatusError, TimeoutException from packaging.requirements import Requirement from pydantic import BaseModel, field_serializer, field_validator, model_validator from rich.console import Console from rich.table import Table console = Console() # Base class for versioned packages class Package(BaseModel): name: str specifier: str | None = None @classmethod def from_requirement(cls, requirement_str: str) -> "Package": req = Requirement(requirement_str) return cls(name=req.name, specifier=str(req.specifier) if req.specifier else None) # Base class for a list of packages class Packages(BaseModel): packages: list[Package] # Convert string package i.e. "arcade>1.0.0" to a name and specifier # Specifiers are currently unused @field_validator("packages", mode="before") @classmethod def parse_package_requirements(cls, packages: list[str]) -> list[Package]: """Convert package requirement strings to Package objects.""" return [Package.from_requirement(pkg) for pkg in packages] # Base class for a local package class LocalPackage(BaseModel): name: str content: str # Base class for a list of local packages class LocalPackages(BaseModel): packages: list[str] # Custom repository configurations class PackageRepository(Packages): index: str index_url: str trusted_host: str # Pypi is a special case of a package repository class Pypi(PackageRepository): index: str = "pypi" index_url: str = "https://pypi.org/simple" trusted_host: str = "pypi.org" class Secret(BaseModel): value: str pattern: str | None = None class Config(BaseModel): id: str enabled: bool = True timeout: int = 30 retries: int = 3 secret: Secret | None = None # Validate and parse the secret if required @field_validator("secret", mode="before") @classmethod def valid_secret(cls, v: str | Secret | None) -> Secret: # If the secret is a string, attempt to parse it as an environment variable or return the secret if isinstance(v, str): secret = get_env_secret(v) # If the secret has been manually set, return it elif isinstance(v, Secret): secret = v else: raise TypeError("Secret must be a string or a Secret object") # Check that the secret is not the default dev secret or empty if secret.value.strip() == "" or secret.value == "dev": raise ValueError("Secret must be a non-empty string and not 'dev'") return secret @field_serializer("secret") def serialize_secret(self, secret: Secret) -> str: if secret.pattern: return f"$env:{secret.pattern}" else: return secret.value # Cloud request for deploying a worker class Request(BaseModel): name: str secret: Secret enabled: bool timeout: int retries: int pypi: Pypi | None = None custom_repositories: list[PackageRepository] | None = None local_packages: list[LocalPackage] | None = None wait: bool = False @field_serializer("secret") def serialize_secret(self, secret: Secret) -> str: return secret.value def poll_worker_status(self, cloud_client: Client, worker_name: str) -> Any: while True: try: worker_resp = cloud_client.get( f"{cloud_client.base_url}/api/v1/workers/{worker_name}?wait_for_completion=true", timeout=10, ) worker_resp.raise_for_status() except TimeoutException: time.sleep(1) continue except ConnectError as e: raise ValueError(f"Failed to connect to cloud: {e}") except HTTPStatusError as e: raise ValueError(f"Failed to start worker: {e.response.json()}") except Exception as e: raise ValueError(f"Failed to start worker: {e}") status = worker_resp.json()["data"]["status"] if status == "Running": return worker_resp.json()["data"] if status == "Failed": raise ValueError(f"Worker failed to start: {worker_resp.json()['data']['error']}") def execute(self, cloud_client: Client, engine_client: Arcade) -> Any: # Attempt to deploy worker to the cloud try: cloud_response = cloud_client.put( str(cloud_client.base_url) + "/api/v1/workers", json=self.model_dump(mode="json"), timeout=360, ) cloud_response.raise_for_status() except ConnectError as e: raise ValueError(f"Failed to connect to cloud: {e}") except HTTPStatusError as e: raise ValueError(f"Failed to start worker: {e.response.json()}") except Exception as e: # change this so it handles errors that aren't just from cloud raise ValueError(f"Failed to start worker: {e}") parse_deployment_response(cloud_response.json()) worker_data = self.poll_worker_status(cloud_client, self.name) try: # Check if worker already exists engine_client.workers.get(self.name) engine_client.workers.update( id=self.name, enabled=self.enabled, http={ "uri": worker_data["endpoint"], "secret": self.secret.value, "timeout": self.timeout, "retry": self.retries, }, ) # If worker does not exist, create it except NotFoundError: engine_client.workers.create( id=self.name, enabled=self.enabled, http={ "uri": worker_data["endpoint"], "secret": self.secret.value, "timeout": self.timeout, "retry": self.retries, }, ) except Exception as e: raise ValueError(f"Failed to add worker to engine: {e}") return cloud_response.json() class Worker(BaseModel): toml_path: Path config: Config pypi_source: Pypi | None = None custom_source: list[PackageRepository] | None = None local_source: LocalPackages | None = None def request(self) -> Request: """Convert Deployment to a Request object.""" self.validate_packages() self.compress_local_packages() if self.config.secret is None: raise ValueError("Secret is required") return Request( name=self.config.id, secret=self.config.secret, enabled=self.config.enabled, timeout=self.config.timeout, retries=self.config.retries, pypi=self.pypi_source, custom_repositories=self.custom_source, local_packages=self.compress_local_packages(), ) # Search for local packages and compress the source code to send def compress_local_packages(self) -> list[LocalPackage] | None: """Compress local packages into a list of LocalPackage objects.""" if self.local_source is None: return None def exclude_filter(tarinfo: tarfile.TarInfo) -> tarfile.TarInfo | None: """Filter for files/directories to exclude from the compressed package""" basename = os.path.basename(tarinfo.name) # Exclude all hidden directories/files if basename.startswith("."): return None # Exclude specific directories/files if basename in {"dist", "build", "__pycache__", "venv", "coverage.xml"}: return None # Exclude lock files if basename.endswith(".lock"): return None return tarinfo # Compress local packages into a list of LocalPackage objects def process_package(package_path_str: str) -> LocalPackage: package_path = self.toml_path.parent / package_path_str if not package_path.exists(): raise FileNotFoundError(f"Local package not found: {package_path}") if not package_path.is_dir(): raise FileNotFoundError(f"Local package is not a directory: {package_path}") # Check that the package is a valid python package if ( not (package_path / "pyproject.toml").is_file() and not (package_path / "setup.py").is_file() ): raise ValueError( f"package '{package_path}' must contain a pyproject.toml or setup.py file" ) # Compress the package into a byte stream and tar byte_stream = io.BytesIO() with tarfile.open(fileobj=byte_stream, mode="w:gz") as tar: tar.add(package_path, arcname=package_path.name, filter=exclude_filter) byte_stream.seek(0) package_bytes = byte_stream.read() package_bytes_b64 = base64.b64encode(package_bytes).decode("utf-8") return LocalPackage(name=package_path.name, content=package_bytes_b64) return list(map(process_package, self.local_source.packages)) # Validate that there are no duplicate packages for each worker def validate_packages(self) -> None: """Validate packages.""" packages: list[str] = [] if self.pypi_source: for pypi_package in self.pypi_source.packages: packages.append(pypi_package.name) if self.custom_source: for repository in self.custom_source: for package in repository.packages: packages.append(package.name) if self.local_source: for local_package in self.local_source.packages: packages.append(os.path.normpath(Path(local_package))) dupes = [x for n, x in enumerate(packages) if x in packages[:n]] if dupes: raise ValueError(f"Duplicate packages: {dupes}") class Deployment(BaseModel): toml_path: Path worker: list[Worker] # Validate that there are no duplicate worker names @model_validator(mode="after") def validate_workers(self) -> "Deployment": for worker in self.worker: if sum(worker.config.id == w.config.id for w in self.worker) > 1: raise ValueError(f"Duplicate worker name: {worker.config.id}") return self # Load a deployment from a toml file @classmethod def from_toml(cls, toml_path: Path) -> "Deployment": try: with open(toml_path) as f: toml_data = toml.load(f) if not toml_data: raise ValueError(f"Empty TOML file: {toml_path}") # Add the toml path to each worker so relative packages can be found if "worker" in toml_data: for worker in toml_data["worker"]: worker["toml_path"] = toml_path return cls(**toml_data, toml_path=toml_path) except toml.TomlDecodeError as e: raise ValueError(f"Invalid TOML format in {toml_path}: {e!s}") except FileNotFoundError: raise FileNotFoundError(f"Config file not found: {toml_path}") # Save the deployment to a toml file def save(self) -> None: print("writing deployment file", self.toml_path) with open(self.toml_path, "w") as f: data = self.model_dump() # Remove the toml_path from the deployment file del data["toml_path"] for worker in data["worker"]: del worker["toml_path"] toml.dump(data, f) # Create a demo deployment file with one worker def create_demo_deployment(toml_path: Path, toolkit_name: str) -> None: """Create a deployment from a toml file.""" deployment = Deployment( toml_path=toml_path, worker=[ Worker( toml_path=toml_path, config=Config( id="demo-worker", enabled=True, timeout=30, retries=3, secret=Secret(value=secrets.token_hex(16), pattern=None), ), local_source=LocalPackages(packages=[f"./{toolkit_name}"]), ) ], ) deployment.save() # Get a currently existing deployment and add an additional local package def update_deployment_with_local_packages(toml_path: Path, toolkit_name: str) -> None: """Update a deployment from a toml file.""" deployment = Deployment.from_toml(toml_path) if deployment.worker[0].local_source is None: deployment.worker[0].local_source = LocalPackages(packages=[f"./{toolkit_name}"]) else: deployment.worker[0].local_source.packages.append(f"./{toolkit_name}") deployment.save() def get_env_secret(secret: str) -> Secret: """Parse a secret from an environment variable.""" # Check if the secret contains the "${env:}" syntax pattern = r"\${env:([^}]+)}" matches = re.findall(pattern, secret) # Only allow a single match if matches and len(matches) == 1: match = matches[0].strip() # Attempt to lookup and create the secret print(f"Looking up secret: {match}") value = os.getenv(match) if value: return Secret(value=value, pattern=match) else: raise ValueError(f"Environment variable not found: {match}") elif matches and len(matches) > 1: raise ValueError(f"Multiple environment variables found in secret: {secret}") # If no matches are found, return the secret as is return Secret(value=secret, pattern=None) def parse_deployment_response(response: dict) -> None: # Check what changes were made to the worker and display changes = response["data"]["changes"] additions = changes.get("additions", []) removals = changes.get("removals", []) updates = changes.get("updates", []) no_changes = changes.get("no_changes", []) print_deployment_table(additions, removals, updates, no_changes) def print_deployment_table( additions: list, removals: list, updates: list, no_changes: list ) -> None: table = Table(title="Changed Packages") table.add_column("Adding", justify="right", style="green") table.add_column("Removing", justify="right", style="red") table.add_column("Updating", justify="right", style="yellow") table.add_column("No Changes", justify="right", style="dim") max_rows = max(len(additions), len(removals), len(updates), len(no_changes)) # Add each row of worker package changes to the table for i in range(max_rows): addition = additions[i] if i < len(additions) else "" removal = removals[i] if i < len(removals) else "" update = updates[i] if i < len(updates) else "" no_change = no_changes[i] if i < len(no_changes) else "" table.add_row(addition, removal, update, no_change) console.print(table)