Worker Deploy (#278)

This commit is contained in:
Sterling Dreyer 2025-03-13 09:02:36 -07:00 committed by GitHub
parent b296594863
commit a181dc5681
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 1110 additions and 11 deletions

View file

@ -3,8 +3,10 @@ import os
import threading
import uuid
import webbrowser
from pathlib import Path
from typing import Any, Optional
import httpx
import typer
from arcadepy import Arcade
from arcadepy.types import AuthorizationResponse
@ -14,6 +16,7 @@ from rich.markup import escape
from rich.text import Text
from tqdm import tqdm
import arcade.cli.worker as worker
from arcade.cli.authn import LocalAuthCallbackServer, check_existing_login
from arcade.cli.constants import (
CREDENTIALS_FILE_PATH,
@ -30,7 +33,7 @@ from arcade.cli.launcher import start_servers
from arcade.cli.show import show_logic
from arcade.cli.utils import (
OrderCommands,
compute_engine_base_url,
compute_base_url,
compute_login_url,
get_eval_files,
get_user_input,
@ -43,6 +46,8 @@ from arcade.cli.utils import (
validate_and_get_config,
version_callback,
)
from arcade.cli.worker import parse_deployment_response
from arcade.worker.config.deployment import Deployment
cli = typer.Typer(
cls=OrderCommands,
@ -52,6 +57,9 @@ cli = typer.Typer(
pretty_exceptions_show_locals=False,
pretty_exceptions_short=True,
)
cli.add_typer(worker.app, name="worker", help="Manage workers")
console = Console()
@ -225,7 +233,7 @@ def chat(
)
config = validate_and_get_config()
base_url = compute_engine_base_url(force_tls, force_no_tls, host, port)
base_url = compute_base_url(force_tls, force_no_tls, host, port)
client = Arcade(api_key=config.api.key, base_url=base_url)
user_email = config.user.email if config.user else None
@ -352,7 +360,7 @@ def evals(
config = validate_and_get_config()
host = PROD_ENGINE_HOST if cloud else host
base_url = compute_engine_base_url(force_tls, force_no_tls, host, port)
base_url = compute_base_url(force_tls, force_no_tls, host, port)
models_list = models.split(",") # Use 'models_list' to avoid shadowing
@ -515,6 +523,82 @@ def workerup(
typer.Exit(code=1)
@cli.command(help="Deploy worker to Arcade Cloud", rich_help_panel="Deployment")
def deploy(
deployment_file: str = typer.Option(
"worker.toml", "--deployment-file", "-d", help="The deployment file to deploy."
),
cloud_host: str = typer.Option(
PROD_CLOUD_HOST,
"--cloud-host",
"-c",
help="The Arcade Cloud host to deploy to.",
hidden=True,
),
cloud_port: int = typer.Option(
None,
"--cloud-port",
"-cp",
help="The port of the Arcade Cloud host.",
hidden=True,
),
host: str = typer.Option(
PROD_ENGINE_HOST,
"--host",
"-h",
help="The Arcade Engine host to register the worker to.",
),
port: int = typer.Option(
None,
"--port",
"-p",
help="The port of the Arcade Engine host.",
),
force_tls: bool = typer.Option(
False,
"--tls",
help="Whether to force TLS for the connection to the Arcade Engine. If not specified, the connection will use TLS if the engine URL uses a 'https' scheme.",
),
force_no_tls: bool = typer.Option(
False,
"--no-tls",
help="Whether to disable TLS for the connection to the Arcade Engine.",
),
) -> None:
"""
Deploy a worker to Arcade Cloud.
"""
config = validate_and_get_config()
engine_url = compute_base_url(force_tls, force_no_tls, host, port)
engine_client = Arcade(api_key=config.api.key, base_url=engine_url)
cloud_url = compute_base_url(force_tls, force_no_tls, cloud_host, cloud_port)
cloud_client = httpx.Client(
base_url=cloud_url, headers={"Authorization": f"Bearer {config.api.key}"}
)
# Fetch deployment configuration
try:
deployment = Deployment.from_toml(Path(deployment_file))
except Exception as e:
console.print(f"❌ Failed to parse deployment file: {e}", style="bold red")
raise typer.Exit(code=1)
with console.status(f"Deploying {len(deployment.worker)} workers"):
for worker in deployment.worker:
console.log(f"Deploying '{worker.config.id}...'", style="dim")
try:
# Attempt to deploy worker
response = worker.request().execute(cloud_client, engine_client)
parse_deployment_response(response)
console.log(f"✅ Worker '{worker.config.id}' deployed successfully.", style="dim")
except Exception as e:
console.log(
f"❌ Failed to deploy worker '{worker.config.id}': {e}", style="bold red"
)
raise typer.Exit(code=1)
@cli.callback()
def main_callback(
ctx: typer.Context,
@ -527,7 +611,7 @@ def main_callback(
help="Print version and exit.",
),
) -> None:
excluded_commands = {login.__name__, logout.__name__, workerup.__name__}
excluded_commands = {login.__name__, logout.__name__, serve.__name__}
if ctx.invoked_subcommand in excluded_commands:
return

View file

@ -9,6 +9,11 @@ import typer
from jinja2 import Environment, FileSystemLoader, select_autoescape
from rich.console import Console
from arcade.worker.config.deployment import (
create_demo_deployment,
update_deployment_with_local_packages,
)
console = Console()
# Retrieve the installed version of arcade-ai
@ -111,7 +116,7 @@ def create_new_toolkit(output_directory: str) -> None:
"toolkit_description": toolkit_description,
"toolkit_author_name": toolkit_author_name,
"toolkit_author_email": toolkit_author_email,
"arcade_version": f"{ARCADE_VERSION.rsplit('.', 1)[0]}.*",
"arcade_version": f"^{ARCADE_VERSION}",
"creation_year": datetime.now().year,
}
template_directory = Path(__file__).parent.parent / "templates" / "{{ toolkit_name }}"
@ -123,6 +128,15 @@ def create_new_toolkit(output_directory: str) -> None:
try:
create_package(env, template_directory, toolkit_directory, context)
create_deployment(toolkit_directory, toolkit_name)
except Exception:
remove_toolkit(toolkit_directory, toolkit_name)
raise
def create_deployment(toolkit_directory: Path, toolkit_name: str) -> None:
worker_toml = toolkit_directory / "worker.toml"
if not worker_toml.exists():
create_demo_deployment(worker_toml, toolkit_name)
else:
update_deployment_with_local_packages(worker_toml, toolkit_name)

View file

@ -80,7 +80,7 @@ def create_cli_catalog(
return catalog
def compute_engine_base_url(
def compute_base_url(
force_tls: bool,
force_no_tls: bool,
host: str,
@ -193,7 +193,7 @@ def get_tools_from_engine(
toolkit: str | None = None,
) -> list[ToolDefinition]:
config = validate_and_get_config()
base_url = compute_engine_base_url(force_tls, force_no_tls, host, port)
base_url = compute_base_url(force_tls, force_no_tls, host, port)
client = Arcade(api_key=config.api.key, base_url=base_url)
tools = []

379
arcade/arcade/cli/worker.py Normal file
View file

@ -0,0 +1,379 @@
import httpx
import typer
from arcadepy import Arcade, NotFoundError
from rich.console import Console
from rich.table import Table
from arcade.cli.constants import (
PROD_CLOUD_HOST,
PROD_ENGINE_HOST,
)
from arcade.cli.utils import (
OrderCommands,
compute_base_url,
validate_and_get_config,
)
console = Console()
app = typer.Typer(
cls=OrderCommands,
add_completion=False,
no_args_is_help=True,
pretty_exceptions_enable=False,
pretty_exceptions_show_locals=False,
pretty_exceptions_short=True,
)
state = {
"engine_url": compute_base_url(
host=PROD_ENGINE_HOST, port=None, force_tls=False, force_no_tls=False
)
}
@app.callback()
def main(
host: str = typer.Option(
PROD_ENGINE_HOST,
"--host",
"-h",
help="The Arcade Engine host.",
),
port: int = typer.Option(
None,
"--port",
"-p",
help="The port of the Arcade Engine host.",
),
force_tls: bool = typer.Option(
False,
"--tls",
help="Whether to force TLS for the connection to the Arcade Engine.",
),
force_no_tls: bool = typer.Option(
False,
"--no-tls",
help="Whether to disable TLS for the connection to the Arcade Engine.",
),
) -> None:
"""
Manage users in the system.
"""
engine_url = compute_base_url(force_tls, force_no_tls, host, port)
state["engine_url"] = engine_url
@app.command("list", help="List all workers")
def list_workers(
cloud_host: str = typer.Option(
PROD_CLOUD_HOST,
"--cloud-host",
"-c",
help="The Arcade Engine host.",
hidden=True,
),
cloud_port: int = typer.Option(
None,
"--cloud-port",
"-cp",
help="The port of the Arcade Engine host.",
hidden=True,
),
force_tls: bool = typer.Option(
False,
"--tls",
help="Whether to force TLS for the connection to the Arcade Engine.",
hidden=True,
),
force_no_tls: bool = typer.Option(
False,
"--no-tls",
help="Whether to disable TLS for the connection to the Arcade Engine.",
hidden=True,
),
) -> None:
config = validate_and_get_config()
engine_url = state["engine_url"]
client = Arcade(api_key=config.api.key, base_url=engine_url)
deployments = []
try:
cloud_url = compute_base_url(force_tls, force_no_tls, cloud_host, cloud_port)
cloud_client = httpx.Client(base_url=cloud_url)
response = cloud_client.get(
"/api/v1/workers", headers={"Authorization": f"Bearer {config.api.key}"}
)
response.raise_for_status()
deployments = response.json()["data"]["workers"]
except Exception as e:
console.log(f"Failed to get cloud deployments: {e}")
print_worker_table(client, deployments)
def print_worker_table(client: Arcade, deployments: list[dict]) -> None:
workers = client.workers.list()
if not workers.items:
console.print("No workers found", style="bold red")
return
# Create and print a table of worker information
table = Table(title="Workers")
table.add_column("ID")
table.add_column("Cloud Deployed")
table.add_column("Engine Registered")
table.add_column("Enabled")
table.add_column("Host")
table.add_column("Toolkits")
# Track workers that are registered in the engine
engine_workers = []
for worker in workers.items:
if worker.id is None:
continue
engine_workers.append(worker.id)
# Check if the worker is deployed in the cloud
is_deployed = is_cloud_deployment(worker.id, deployments)
# Get the toolkits for the worker
tools = get_toolkits(client, worker.id)
uri = worker.http.uri if worker.http and worker.http.uri else ""
table.add_row(
worker.id,
str(is_deployed),
str(True),
str(worker.enabled),
compare_endpoints(worker.id, uri, deployments),
"Could not fetch toolkits" if tools == "" else tools,
)
for deployment in deployments:
if deployment["name"] not in engine_workers:
table.add_row(deployment["name"], "True", "False", "False", deployment["endpoint"], "")
console.print(table)
# Check if the worker is in the list of cloud deployments
def is_cloud_deployment(name: str, deployments: list[dict]) -> bool:
return any(deployment["name"] == name for deployment in deployments)
# Compare the endpoint of the worker in the engine to the endpoint in the cloud
# Return a highlighted diff if the endpoint in the engine is different from the endpoint in the cloud
def compare_endpoints(worker_id: str, engine_endpoint: str, deployments: list[dict]) -> str:
if is_cloud_deployment(worker_id, deployments):
for deployment in deployments:
deployment_endpoint = deployment["endpoint"]
if deployment_endpoint == engine_endpoint:
return engine_endpoint
return f"[red]Endpoint Mismatch[/red]\n[yellow]Registered Endpoint: {engine_endpoint}[/yellow]\n[green]Actual Endpoint: {deployment_endpoint}[/green]"
return engine_endpoint
def parse_deployment_response(response: dict) -> None:
# Check what changes were made to the worker and display
changes = response["data"]["changes"]
additions = changes.get("additions", [])
removals = changes.get("removals", [])
updates = changes.get("updates", [])
no_changes = changes.get("no_changes", [])
print_deployment_table(additions, removals, updates, no_changes)
def print_deployment_table(
additions: list, removals: list, updates: list, no_changes: list
) -> None:
table = Table(title="Changed Packages")
table.add_column("Added", justify="right", style="green")
table.add_column("Removed", justify="right", style="red")
table.add_column("Updated", justify="right", style="yellow")
table.add_column("No Changes", justify="right", style="dim")
max_rows = max(len(additions), len(removals), len(updates), len(no_changes))
# Add each row of worker package changes to the table
for i in range(max_rows):
addition = additions[i] if i < len(additions) else ""
removal = removals[i] if i < len(removals) else ""
update = updates[i] if i < len(updates) else ""
no_change = no_changes[i] if i < len(no_changes) else ""
table.add_row(addition, removal, update, no_change)
console.print(table)
@app.command("enable", help="Enable a worker")
def enable_worker(
worker_id: str,
) -> None:
config = validate_and_get_config()
engine_url = state["engine_url"]
arcade = Arcade(api_key=config.api.key, base_url=engine_url)
try:
arcade.workers.update(worker_id, enabled=True)
except Exception as e:
console.print(f"Error enabling worker: {e}", style="bold red")
raise typer.Exit(code=1)
@app.command("disable", help="Disable a worker")
def disable_worker(
worker_id: str,
) -> None:
config = validate_and_get_config()
engine_url = state["engine_url"]
arcade = Arcade(api_key=config.api.key, base_url=engine_url)
try:
arcade.workers.update(worker_id, enabled=False)
except Exception as e:
console.print(f"Error disabling worker: {e}", style="bold red")
raise typer.Exit(code=1)
@app.command("rm", help="Remove a worker")
def rm_worker(
worker_id: str,
engine_only: bool = typer.Option(
False,
"--deregister",
"-d",
help="Deregister the worker from the engine",
),
cloud_host: str = typer.Option(
PROD_CLOUD_HOST,
"--cloud-host",
"-c",
help="The Arcade Engine host.",
hidden=True,
),
cloud_port: int = typer.Option(
None,
"--cloud-port",
"-cp",
help="The port of the Arcade Engine host.",
hidden=True,
),
force_tls: bool = typer.Option(
False,
"--tls",
help="Whether to force TLS for the connection to the Arcade Engine.",
hidden=True,
),
force_no_tls: bool = typer.Option(
False,
"--no-tls",
help="Whether to disable TLS for the connection to the Arcade Engine.",
hidden=True,
),
) -> None:
config = validate_and_get_config()
engine_url = state["engine_url"]
cloud_url = compute_base_url(force_tls, force_no_tls, cloud_host, cloud_port)
# First attempt to delete from the cloud
if not engine_only:
try:
client = httpx.Client()
response = client.delete(
f"{cloud_url}/api/v1/workers/{worker_id}",
headers={"Authorization": f"Bearer {config.api.key}"},
)
response.raise_for_status()
except httpx.HTTPStatusError as e:
if e.response.status_code == 404:
console.print(
"Deployment not found. To deregister the worker from the engine, use the --deregister flag.",
style="bold red",
)
raise typer.Exit(code=1)
else:
console.print(f"Error deleting deployment: {e}", style="bold red")
raise typer.Exit(code=1)
except Exception as e:
console.print(f"Error deleting deployment: {e}", style="bold red")
raise typer.Exit(code=1)
# Then try to delete from the engine
try:
arcade = Arcade(api_key=config.api.key, base_url=engine_url)
arcade.workers.delete(worker_id)
except NotFoundError:
console.print("Worker not found", style="bold red")
except Exception as e:
console.print(f"Error deleting worker from engine: {e}", style="bold red")
raise typer.Exit(code=1)
@app.command("logs", help="Get logs for a worker")
def worker_logs(
worker_id: str,
cloud_host: str = typer.Option(
PROD_CLOUD_HOST,
"--cloud-host",
"-c",
help="The Arcade Engine host.",
hidden=True,
),
cloud_port: int = typer.Option(
None,
"--cloud-port",
"-cp",
help="The port of the Arcade Engine host.",
hidden=True,
),
force_tls: bool = typer.Option(
False,
"--tls",
help="Whether to force TLS for the connection to the Arcade Engine.",
hidden=True,
),
force_no_tls: bool = typer.Option(
False,
"--no-tls",
help="Whether to disable TLS for the connection to the Arcade Engine.",
hidden=True,
),
) -> None:
config = validate_and_get_config()
cloud_url = compute_base_url(force_tls, force_no_tls, cloud_host, cloud_port)
try:
with httpx.stream(
"GET",
f"{cloud_url}/api/v1/workers/logs/{worker_id}",
headers={"Authorization": f"Bearer {config.api.key}", "Accept": "text/event-stream"},
# allow the connection to stay open indefinitely
timeout=None, # noqa: S113
) as s:
for line in s.iter_lines():
if not line or "[DONE]" in line: # Skip empty lines
continue
if "event: error" in line:
console.print("Could not stream logs", style="bold red")
if line.startswith("data:"):
# Extract just the data portion after 'data:'
data = line[5:].strip() # Remove 'data:' prefix and whitespace
console.print(data, markup=False)
except Exception as e:
console.print(f"Error connecting to log stream: {e}", style="bold red")
raise typer.Exit(code=1)
def get_toolkits(client: Arcade, worker_id: str | None) -> str:
if worker_id is None:
return ""
try:
# Get tools for the given worker
tools = client.workers.tools(worker_id)
toolkits: list[str] = []
if not tools.items:
return ""
# Get toolkit names
for page in tools.iter_pages():
for tool in page.items:
if tool.toolkit.name not in toolkits:
toolkits.append(tool.toolkit.name)
return ", ".join(toolkits)
except NotFoundError:
return ""
except Exception as e:
console.print(f"Error getting worker tools: {e}", style="bold red")
raise typer.Exit(code=1)

View file

@ -0,0 +1,294 @@
import base64
import io
import os
import secrets
import tarfile
from pathlib import Path
from typing import Any
import httpx
import toml
from arcadepy import Arcade, NotFoundError
from httpx import Client
from packaging.requirements import Requirement
from pydantic import BaseModel, field_validator, model_validator
# Base class for versioned packages
class Package(BaseModel):
name: str
specifier: str | None = None
@classmethod
def from_requirement(cls, requirement_str: str) -> "Package":
req = Requirement(requirement_str)
return cls(name=req.name, specifier=str(req.specifier) if req.specifier else None)
# Base class for a list of packages
class Packages(BaseModel):
packages: list[Package]
# Convert string package i.e. "arcade>1.0.0" to a name and specifier
# Specifiers are currently unused
@field_validator("packages", mode="before")
@classmethod
def parse_package_requirements(cls, packages: list[str]) -> list[Package]:
"""Convert package requirement strings to Package objects."""
return [Package.from_requirement(pkg) for pkg in packages]
# Base class for a local package
class LocalPackage(BaseModel):
name: str
content: str
# Base class for a list of local packages
class LocalPackages(BaseModel):
packages: list[str]
# Custom repository configurations
class PackageRepository(Packages):
index: str
index_url: str
trusted_host: str
# Pypi is a special case of a package repository
class Pypi(PackageRepository):
index: str = "pypi"
index_url: str = "https://pypi.org/simple"
trusted_host: str = "pypi.org"
class Config(BaseModel):
id: str
enabled: bool = True
timeout: int = 30
retries: int = 3
secret: str
# Validate that the secret is a non-empty string and not 'dev'
@field_validator("secret")
@classmethod
def valid_secret(cls, v: str) -> str:
if v.strip("") == "" or v == "dev":
raise ValueError("Secret must be a non-empty string and not 'dev'")
return v
# Cloud request for deploying a worker
class Request(BaseModel):
name: str
secret: str
enabled: bool
timeout: int
retries: int
pypi: Pypi | None = None
custom_repositories: list[PackageRepository] | None = None
local_packages: list[LocalPackage] | None = None
def execute(self, cloud_client: Client, engine_client: Arcade) -> Any:
# Attempt to deploy worker to the cloud
try:
cloud_response = cloud_client.put(
str(cloud_client.base_url) + "/api/v1/workers",
json=self.model_dump(mode="json"),
timeout=120,
)
cloud_response.raise_for_status()
except httpx.ConnectError as e:
raise ValueError(f"Failed to connect to cloud: {e}")
except Exception:
msg = cloud_response.json().get("msg", f"{cloud_response.status_code}: Unknown error")
raise ValueError(f"Failed to start worker: {msg}")
try:
# Check if worker already exists
engine_client.workers.get(self.name)
engine_client.workers.update(
id=self.name,
enabled=self.enabled,
http={
"uri": cloud_response.json()["data"]["worker_endpoint"],
"secret": self.secret,
"timeout": self.timeout,
"retry": self.retries,
},
)
# If worker does not exist, create it
except NotFoundError:
engine_client.workers.create(
id=self.name,
enabled=self.enabled,
http={
"uri": cloud_response.json()["data"]["worker_endpoint"],
"secret": self.secret,
"timeout": self.timeout,
"retry": self.retries,
},
)
except Exception as e:
raise ValueError(f"Failed to add worker to engine: {e}")
return cloud_response.json()
class Worker(BaseModel):
toml_path: Path
config: Config
pypi_source: Pypi | None = None
custom_source: list[PackageRepository] | None = None
local_source: LocalPackages | None = None
def request(self) -> Request:
"""Convert Deployment to a Request object."""
self.validate_packages()
self.compress_local_packages()
return Request(
name=self.config.id,
secret=self.config.secret,
enabled=self.config.enabled,
timeout=self.config.timeout,
retries=self.config.retries,
pypi=self.pypi_source,
custom_repositories=self.custom_source,
local_packages=self.compress_local_packages(),
)
# Search for local packages and compress the source code to send
def compress_local_packages(self) -> list[LocalPackage] | None:
"""Compress local packages into a list of LocalPackage objects."""
if self.local_source is None:
return None
# Compress local packages into a list of LocalPackage objects
def process_package(package_path_str: str) -> LocalPackage:
package_path = self.toml_path.parent / package_path_str
if not package_path.exists():
raise FileNotFoundError(f"Local package not found: {package_path}")
if not package_path.is_dir():
raise FileNotFoundError(f"Local package is not a directory: {package_path}")
# Check that the package is a valid python package
if (
not (package_path / "pyproject.toml").is_file()
and not (package_path / "setup.py").is_file()
):
raise ValueError(
f"package '{package_path}' must contain a pyproject.toml or setup.py file"
)
# Compress the package into a byte stream and tar
byte_stream = io.BytesIO()
with tarfile.open(fileobj=byte_stream, mode="w:gz") as tar:
tar.add(package_path, arcname=package_path.name)
byte_stream.seek(0)
package_bytes = byte_stream.read()
package_bytes_b64 = base64.b64encode(package_bytes).decode("utf-8")
return LocalPackage(name=package_path.name, content=package_bytes_b64)
return list(map(process_package, self.local_source.packages))
# Validate that there are no duplicate packages for each worker
def validate_packages(self) -> None:
"""Validate packages."""
packages: list[str] = []
if self.pypi_source:
for pypi_package in self.pypi_source.packages:
packages.append(pypi_package.name)
if self.custom_source:
for repository in self.custom_source:
for package in repository.packages:
packages.append(package.name)
if self.local_source:
for local_package in self.local_source.packages:
packages.append(os.path.normpath(Path(local_package)))
dupes = [x for n, x in enumerate(packages) if x in packages[:n]]
if dupes:
raise ValueError(f"Duplicate packages: {dupes}")
class Deployment(BaseModel):
toml_path: Path
worker: list[Worker]
# Validate that there are no duplicate worker names
@model_validator(mode="after")
def validate_workers(self) -> "Deployment":
for worker in self.worker:
if sum(worker.config.id == w.config.id for w in self.worker) > 1:
raise ValueError(f"Duplicate worker name: {worker.config.id}")
return self
# Load a deployment from a toml file
@classmethod
def from_toml(cls, toml_path: Path) -> "Deployment":
try:
with open(toml_path) as f:
toml_data = toml.load(f)
if not toml_data:
raise ValueError(f"Empty TOML file: {toml_path}")
# Add the toml path to each worker so relative packages can be found
if "worker" in toml_data:
for worker in toml_data["worker"]:
worker["toml_path"] = toml_path
return cls(**toml_data, toml_path=toml_path)
except toml.TomlDecodeError as e:
raise ValueError(f"Invalid TOML format in {toml_path}: {e!s}")
except FileNotFoundError:
raise FileNotFoundError(f"Config file not found: {toml_path}")
# Save the deployment to a toml file
def save(self) -> None:
print("writing deployment file", self.toml_path)
with open(self.toml_path, "w") as f:
data = self.model_dump()
# Remove the toml_path from the deployment file
del data["toml_path"]
for worker in data["worker"]:
del worker["toml_path"]
toml.dump(data, f)
# Create a demo deployment file with one worker
def create_demo_deployment(toml_path: Path, toolkit_name: str) -> None:
"""Create a deployment from a toml file."""
deployment = Deployment(
toml_path=toml_path,
worker=[
Worker(
toml_path=toml_path,
config=Config(
id="demo-worker",
enabled=True,
timeout=30,
retries=3,
secret=secrets.token_hex(16),
),
local_source=LocalPackages(packages=[f"./{toolkit_name}"]),
)
],
)
deployment.save()
# Get a currently existing deployment and add an additional local package
def update_deployment_with_local_packages(toml_path: Path, toolkit_name: str) -> None:
"""Update a deployment from a toml file."""
deployment = Deployment.from_toml(toml_path)
if deployment.worker[0].local_source is None:
deployment.worker[0].local_source = LocalPackages(packages=[f"./{toolkit_name}"])
else:
deployment.worker[0].local_source.packages.append(f"./{toolkit_name}")
deployment.save()

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "arcade-ai"
version = "1.0.5"
version = "1.1.0"
description = "Arcade Python SDK and CLI"
readme = "README.md"
packages = [
@ -21,12 +21,14 @@ rich = "^13.7.1"
Jinja2 = ">=3.1.5,<4.0.0"
pyyaml = "^6.0"
openai = "^1.36.0" # TODO: relax to an earlier version that still has what we need
arcadepy = "^1.0.0"
arcadepy = "^1.3.1"
pyjwt = "^2.8.0"
loguru = "^0.7.0"
tqdm = "^4.1.0"
toml = "^0.10.2"
types-python-dateutil = "2.9.0.20241003"
types-pytz = "2024.2.0.20241003"
types-toml = "0.10.8.20240310"
opentelemetry-instrumentation-fastapi = "0.48b0"
opentelemetry-exporter-otlp-proto-http = "1.27.0"
opentelemetry-exporter-otlp-proto-common = "1.27.0"
@ -42,6 +44,7 @@ pyreadline3 = {version = "^3.5.4", platform = "win32"}
[tool.poetry.extras]
evals = ["scipy", "numpy", "scikit-learn", "pytz", "python-dateutil"]
[tool.poetry.group.dev.dependencies]
pytest = "^8.1.2"
pytest-cov = "^4.0.0"

View file

@ -1,6 +1,6 @@
import pytest
from arcade.cli.utils import compute_engine_base_url, compute_login_url
from arcade.cli.utils import compute_base_url, compute_login_url
DEFAULT_CLOUD_HOST = "cloud.arcade.dev"
DEFAULT_ENGINE_HOST = "api.arcade.dev"
@ -186,7 +186,7 @@ DEFAULT_FORCE_NO_TLS = False
],
)
def test_compute_base_url(inputs: dict, expected_output: str):
base_url = compute_engine_base_url(
base_url = compute_base_url(
inputs["force_tls"],
inputs["force_no_tls"],
inputs["host_input"],

View file

@ -0,0 +1,227 @@
# Ignore hardcoded secret linting
# ruff: noqa: S105
# ruff: noqa: S106
import json
from pathlib import Path
import pytest
from arcade.worker.config.deployment import (
Config,
Deployment,
LocalPackages,
Package,
PackageRepository,
Pypi,
Worker,
)
@pytest.fixture
def test_dir():
return Path(__file__).parent
def test_invalid_toml_path(test_dir):
with pytest.raises(FileNotFoundError):
Deployment.from_toml(test_dir / "test_files" / "invalid.toml")
def test_missing_fields(test_dir):
with pytest.raises(ValueError):
Deployment.from_toml(test_dir / "test_files" / "invalid.fields.worker.toml")
def test_deployment_parsing(test_dir):
config_path = test_dir / "test_files" / "full.worker.toml"
deployment = Deployment.from_toml(config_path)
# Test config section
assert deployment.worker[0].config.id == "test"
assert deployment.worker[0].config.enabled is True
assert deployment.worker[0].config.timeout == 10
assert deployment.worker[0].config.retries == 3
assert deployment.worker[0].config.secret == "test-secret"
# Test pypi section
assert deployment.worker[0].pypi_source.packages == [Package(name="arcade-x")]
# Test local_packages section
assert deployment.worker[0].local_source.packages == ["./mock_toolkit"]
# Test custom_repositories section
repo = deployment.worker[0].custom_source[0]
assert repo.index == "pypi"
assert repo.index_url == "https://pypi.org/simple"
assert repo.trusted_host == "pypi.org"
assert repo.packages == [Package(name="arcade-ai", specifier=">=1.0.0")]
repo = deployment.worker[0].custom_source[1]
assert repo.index == "pypi2"
assert repo.index_url == "https://pypi2.org/simple"
assert repo.trusted_host == "pypi2.org"
assert repo.packages == [Package(name="arcade-slack")]
def test_specifier():
from packaging.requirements import Requirement
req = Requirement("arcade-ai>=1.0.0")
assert req.name == "arcade-ai"
assert req.specifier == ">=1.0.0"
def test_deployment_dict(test_dir):
config_path = test_dir / "test_files" / "full.worker.toml"
deployment = Deployment.from_toml(config_path)
expected = json.loads("""{
"name": "test",
"secret": "test-secret",
"enabled": true,
"timeout": 10,
"retries": 3,
"pypi": {
"packages": [
{
"name": "arcade-x",
"specifier": null
}
],
"index": "pypi",
"index_url": "https://pypi.org/simple",
"trusted_host": "pypi.org"
},
"custom_repositories": [
{
"packages": [
{
"name": "arcade-ai",
"specifier": ">=1.0.0"
}
],
"index": "pypi",
"index_url": "https://pypi.org/simple",
"trusted_host": "pypi.org"
},
{
"packages": [
{
"name": "arcade-slack",
"specifier": null
}
],
"index": "pypi2",
"index_url": "https://pypi2.org/simple",
"trusted_host": "pypi2.org"
}
],
"local_packages": [
{
"name": "mock_toolkit",
"content": "H4sIAOgdymcC/+2XwWuDMBTGPftXZDltMNIkJtrCOrpbL4PdSxmiKXNVIzHt6n+/OAvtNrqbMur7Xd7j5YGH5Ps+JBMyWbzEh6WKU2W8XqAdlyqlgTj17ZxRzriHDt4A7GobG/d5b5zwKSpsVqg5iwRjs6kMBJEzMYtC7nvA1VPoZPtqtc63mZ14/ek/krKrYVcp/655JtyLY4wHNHL6D5iMPCSH1H+dGtX84YBubbO5vvsn4P/g/+f+LyihPBSUSfD/sfl/EWclqZo+9B8Kcdn/eXTyf+bmTAjp9E+H1P9I/b8yWWlv8VLlub5HH9rk6Q2+A+mPhf+R/8Hv/GeQ/4Pkf/Qj/3lEpAgCOQUPGF3+V01l9LtKLLG6yAfLf07F2f9fq/+QhhTyfwhW7d2TSitrmrVfxoVCc4TPXwX298rUmS7bA0oYodhPVZ2YrLLH6bNbR8d1tNEGPZnExQn2451906Z2OyvczdBDqvaL+Ksnrn3EazAaAAAAAAAAAAAAAAAAAOiJT7MTVu0AKAAA"
}
]
}""")
got = deployment.worker[0].request().model_dump(mode="json")
print(got)
# Remove encoding part that contains the content
got["local_packages"][0].pop("content")
expected["local_packages"][0].pop("content")
assert got == expected
def test_invalid_secret_parsing(test_dir):
config_path = test_dir / "test_files" / "invalid.secret.worker.toml"
with pytest.raises(ValueError):
Deployment.from_toml(config_path)
def test_missing_local_package(test_dir):
config_path = test_dir / "test_files" / "invalid.localfile.worker.toml"
deployment = Deployment.from_toml(config_path)
with pytest.raises(FileNotFoundError):
deployment.worker[0].request()
def test_invalid_local_package(test_dir):
config_path = test_dir / "test_files" / "invalid.localfile.worker.toml"
deployment = Deployment.from_toml(config_path)
with pytest.raises(FileNotFoundError):
deployment.worker[1].request()
def test_unconfigured_local_package(test_dir):
config_path = test_dir / "test_files" / "invalid.localfile.worker.toml"
deployment = Deployment.from_toml(config_path)
with pytest.raises(ValueError):
deployment.worker[2].request()
def test_duplicate_pypi_packages():
worker = Worker(
toml_path=Path(__file__),
config=Config(id="test", secret="test-secret"),
pypi_source=Pypi(packages=["arcade-slack", "arcade-slack"]),
)
with pytest.raises(ValueError):
worker.validate_packages()
def test_duplicate_custom_repository_packages():
worker = Worker(
toml_path=Path(__file__),
config=Config(id="test", secret="test-secret"),
custom_source=[
PackageRepository(
index="pypi",
index_url="https://pypi.org/simple",
trusted_host="pypi.org",
packages=["arcade-slack", "arcade-slack"],
)
],
)
with pytest.raises(ValueError):
worker.validate_packages()
def test_duplicate_local_packages():
worker = Worker(
toml_path=Path(__file__),
config=Config(id="test", secret="test-secret"),
local_source=LocalPackages(packages=["./mock_toolkit", "./mock_toolkit"]),
)
with pytest.raises(ValueError):
worker.validate_packages()
def test_duplicate_all_typed_packages():
worker = Worker(
toml_path=Path(__file__),
config=Config(id="test", secret="test-secret"),
pypi_source=Pypi(packages=["arcade-slack"]),
custom_source=[
PackageRepository(
index="pypi",
index_url="https://pypi.org/simple",
trusted_host="pypi.org",
packages=["arcade-slack", "arcade-x"],
)
],
local_source=LocalPackages(packages=["./arcade-x"]),
)
with pytest.raises(ValueError):
worker.validate_packages()
def test_duplicate_worker_names():
worker = Worker(
toml_path=Path(__file__),
config=Config(id="test", secret="test-secret"),
)
worker2 = Worker(
toml_path=Path(__file__),
config=Config(id="test", secret="test-secret"),
)
with pytest.raises(ValueError):
Deployment(workers=[worker, worker2])

View file

@ -0,0 +1,26 @@
[[worker]]
[worker.config]
id = "test"
enabled = true
timeout = 10
retries = 3
secret = "test-secret"
[worker.pypi_source]
packages = ["arcade-x"]
[worker.local_source]
packages = ["./mock_toolkit"]
[[worker.custom_source]]
index = "pypi"
index_url = "https://pypi.org/simple"
trusted_host = "pypi.org"
packages = ["arcade-ai>=1.0.0"]
[[worker.custom_source]]
index = "pypi2"
index_url = "https://pypi2.org/simple"
trusted_host = "pypi2.org"
packages = ["arcade-slack"]

View file

@ -0,0 +1,3 @@
[[worker]]
[worker.config]

View file

@ -0,0 +1,42 @@
[[worker]]
[worker.config]
id = "test"
enabled = true
timeout = 10
retries = 3
secret = "test-secret"
[worker.pypi_source]
packages = ["arcade-ai"]
[worker.local_source]
packages = ["./missing_toolkit"]
[[worker]]
[worker.config]
id = "test-2"
enabled = true
timeout = 10
retries = 3
secret = "test-secret"
[worker.pypi_source]
packages = ["arcade-ai"]
[worker.local_source]
packages = ["./invalid.localfile.worker.toml"]
[[worker]]
[worker.config]
id = "test-3"
enabled = true
timeout = 10
retries = 3
secret = "test-secret"
[worker.pypi_source]
packages = ["arcade-ai"]
[worker.local_source]
packages = ["./invalid_toolkit"]

View file

@ -0,0 +1,7 @@
[[worker]]
[worker.config]
id = "test"
enabled = true
timeout = 10
retries = 3
secret = "dev"

View file

@ -0,0 +1 @@
print("Hello, world!")

View file

@ -0,0 +1,5 @@
[tool.poetry]
name = "mock_toolkit"
version = "0.1.0"
description = "Mock toolkit for Arcade"
authors = ["Arcade <dev@arcade.dev>"]

14
worker.toml Normal file
View file

@ -0,0 +1,14 @@
### Worker 1
[[worker]]
[worker.config]
id = "worker-1"
enabled = true
timeout = 10
retries = 3
secret = "test-secret"
[worker.pypi_source]
packages = ["arcade-slack", "arcade-x", "arcade-github"]
[worker.local_source]
packages = ["./toolkits/spotify"]