arcade chat: allow overriding host, port, TLS (#31)
Adds: - New options to `arcade chat`: `-h/--host`, `-p/--port`, and `--tls/--notls`. This allows us to point `arcade chat` at a different server than what's configured in `arcade.toml` which is very helpful for debugging. - Special case: if you do `-h localhost`, it will automatically use port 9099 and no TLS unless otherwise specified. - Adds a non-fatal engine health check to `arcade chat` startup: <img width="499" alt="image" src="https://github.com/user-attachments/assets/b7fae29e-2f8d-4004-a27b-645b4cd997a8">
This commit is contained in:
parent
d12542db55
commit
75c6a2becf
9 changed files with 657 additions and 250 deletions
|
|
@ -16,12 +16,14 @@ from rich.text import Text
|
|||
from arcade.cli.authn import LocalAuthCallbackServer, check_existing_login
|
||||
from arcade.cli.utils import (
|
||||
OrderCommands,
|
||||
apply_config_overrides,
|
||||
create_cli_catalog,
|
||||
display_streamed_markdown,
|
||||
markdownify_urls,
|
||||
validate_and_get_config,
|
||||
)
|
||||
from arcade.client import Arcade
|
||||
from arcade.client.errors import EngineNotHealthyError, EngineOfflineError
|
||||
|
||||
cli = typer.Typer(
|
||||
cls=OrderCommands,
|
||||
|
|
@ -30,7 +32,14 @@ console = Console()
|
|||
|
||||
|
||||
@cli.command(help="Log in to Arcade Cloud")
|
||||
def login() -> None:
|
||||
def login(
|
||||
host: str = typer.Option(
|
||||
"https://cloud.arcade-ai.com",
|
||||
"-h",
|
||||
"--host",
|
||||
help="The Arcade Cloud host to log in to.",
|
||||
),
|
||||
) -> None:
|
||||
"""
|
||||
Logs the user into Arcade Cloud.
|
||||
"""
|
||||
|
|
@ -48,8 +57,7 @@ def login() -> None:
|
|||
# Open the browser for user login
|
||||
callback_uri = "http://localhost:9905/callback"
|
||||
params = urlencode({"callback_uri": callback_uri, "state": state})
|
||||
# TODO: make this configurable
|
||||
login_url = f"http://localhost:8001/api/v1/auth/cli_login?{params}"
|
||||
login_url = f"https://{host}/api/v1/auth/cli_login?{params}"
|
||||
console.print("Opening a browser to log you in...")
|
||||
webbrowser.open(login_url)
|
||||
|
||||
|
|
@ -131,12 +139,42 @@ def chat(
|
|||
stream: bool = typer.Option(
|
||||
False, "-s", "--stream", is_flag=True, help="Stream the tool output."
|
||||
),
|
||||
host: str = typer.Option(
|
||||
None,
|
||||
"-h",
|
||||
"--host",
|
||||
help="The Arcade Engine address to send chat requests to.",
|
||||
),
|
||||
port: int = typer.Option(
|
||||
None,
|
||||
"-p",
|
||||
"--port",
|
||||
help="The port of the Arcade Engine.",
|
||||
),
|
||||
force_tls: bool = typer.Option(
|
||||
False,
|
||||
"--tls",
|
||||
help="Whether to force TLS for the connection to the Arcade Engine. If not specified, the connection will use TLS if the engine URL uses a 'https' scheme.",
|
||||
),
|
||||
force_no_tls: bool = typer.Option(
|
||||
False,
|
||||
"--no-tls",
|
||||
help="Whether to disable TLS for the connection to the Arcade Engine.",
|
||||
),
|
||||
) -> None:
|
||||
"""
|
||||
Chat with a language model.
|
||||
"""
|
||||
config = validate_and_get_config()
|
||||
|
||||
if not force_tls and not force_no_tls:
|
||||
tls_input = None
|
||||
elif force_no_tls:
|
||||
tls_input = False
|
||||
else:
|
||||
tls_input = True
|
||||
apply_config_overrides(config, host, port, tls_input)
|
||||
|
||||
client = Arcade(api_key=config.api.key, base_url=config.engine_url)
|
||||
user_email = config.user.email if config.user else None
|
||||
user_attribution = f"({user_email})" if user_email else ""
|
||||
|
|
@ -153,10 +191,17 @@ def chat(
|
|||
),
|
||||
"\n",
|
||||
"\n",
|
||||
"Chatting with Arcade Engine at " + config.engine_url,
|
||||
"Chatting with Arcade Engine at ",
|
||||
(
|
||||
config.engine_url,
|
||||
"bold blue",
|
||||
),
|
||||
)
|
||||
console.print(chat_header)
|
||||
|
||||
# Try to hit /health endpoint on engine and warn if it is down
|
||||
log_engine_health(client)
|
||||
|
||||
while True:
|
||||
console.print(f"\n[magenta][bold]User[/bold] {user_attribution}:[/magenta] ")
|
||||
|
||||
|
|
@ -277,6 +322,28 @@ def config(
|
|||
raise typer.Exit(code=1)
|
||||
|
||||
|
||||
def log_engine_health(client: Arcade) -> None:
|
||||
try:
|
||||
client.health.check()
|
||||
|
||||
except EngineNotHealthyError as e:
|
||||
console.print(
|
||||
"[bold][yellow]⚠️ Warning: "
|
||||
+ str(e)
|
||||
+ " ("
|
||||
+ "[/yellow]"
|
||||
+ "[red]"
|
||||
+ str(e.status_code)
|
||||
+ "[/red]"
|
||||
+ "[yellow])[/yellow][/bold]"
|
||||
)
|
||||
except EngineOfflineError:
|
||||
console.print(
|
||||
"⚠️ Warning: Arcade Engine was unreachable. (Is it running?)",
|
||||
style="bold yellow",
|
||||
)
|
||||
|
||||
|
||||
def display_config_as_table(config) -> None: # type: ignore[no-untyped-def]
|
||||
"""
|
||||
Display the configuration details as a table using Rich library.
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
import typer
|
||||
from openai.resources.chat.completions import ChatCompletionChunk, Stream
|
||||
from rich.console import Console
|
||||
|
|
@ -8,11 +6,9 @@ from typer.core import TyperGroup
|
|||
from typer.models import Context
|
||||
|
||||
from arcade.core.catalog import ToolCatalog
|
||||
from arcade.core.config_model import Config
|
||||
from arcade.core.toolkit import Toolkit
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from arcade.core.config import Config
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
|
|
@ -101,7 +97,7 @@ def validate_and_get_config(
|
|||
validate_engine: bool = True,
|
||||
validate_api: bool = True,
|
||||
validate_user: bool = True,
|
||||
) -> "Config":
|
||||
) -> Config:
|
||||
"""
|
||||
Validates the configuration, user, and returns the Config object
|
||||
"""
|
||||
|
|
@ -125,3 +121,32 @@ def validate_and_get_config(
|
|||
raise typer.Exit(code=1)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def apply_config_overrides(
|
||||
config: Config, host_input: str | None, port_input: int | None, tls_input: bool | None
|
||||
) -> None:
|
||||
"""
|
||||
Apply optional config overrides (passed by the user) to the config object.
|
||||
"""
|
||||
|
||||
if not config.engine:
|
||||
# Should not happen, validate_and_get_config ensures that `engine` is set
|
||||
raise ValueError("Engine configuration not found in config.")
|
||||
|
||||
# Special case for "localhost" and nothing else specified:
|
||||
# default to dev port and no TLS for convenience
|
||||
if host_input == "localhost":
|
||||
if port_input is None:
|
||||
port_input = 9099
|
||||
if tls_input is None:
|
||||
tls_input = False
|
||||
|
||||
if host_input:
|
||||
config.engine.host = host_input
|
||||
|
||||
if port_input is not None:
|
||||
config.engine.port = port_input
|
||||
|
||||
if tls_input is not None:
|
||||
config.engine.tls = tls_input
|
||||
|
|
|
|||
|
|
@ -10,11 +10,13 @@ from arcade.client.base import (
|
|||
BaseResource,
|
||||
SyncArcadeClient,
|
||||
)
|
||||
from arcade.client.errors import APIStatusError, EngineNotHealthyError, EngineOfflineError
|
||||
from arcade.client.schema import (
|
||||
AuthProvider,
|
||||
AuthRequest,
|
||||
AuthResponse,
|
||||
ExecuteToolResponse,
|
||||
HealthCheckResponse,
|
||||
)
|
||||
from arcade.core.schema import ToolDefinition
|
||||
|
||||
|
|
@ -143,6 +145,41 @@ class ToolResource(BaseResource[ClientT]):
|
|||
return AuthResponse(**data)
|
||||
|
||||
|
||||
class HealthResource(BaseResource[ClientT]):
|
||||
"""Health check resource."""
|
||||
|
||||
def check(self) -> None:
|
||||
"""
|
||||
Check the health of the Arcade Engine.
|
||||
Raises an error if the health check fails.
|
||||
"""
|
||||
|
||||
try:
|
||||
data = self._client._execute_request( # type: ignore[attr-defined]
|
||||
"GET",
|
||||
f"/{API_VERSION}/health",
|
||||
timeout=5,
|
||||
)
|
||||
|
||||
except APIStatusError as e:
|
||||
raise EngineNotHealthyError(
|
||||
"Arcade Engine health check returned an unhealthy status code",
|
||||
status_code=e.status_code,
|
||||
)
|
||||
except Exception as e:
|
||||
# Catches everything else including httpx.ConnectError (most common)
|
||||
raise EngineOfflineError(f"Arcade Engine was unreachable: {e}")
|
||||
|
||||
health_check_response = HealthCheckResponse(**data)
|
||||
|
||||
# Raise an error if the health payload is not `healthy: true`
|
||||
if health_check_response.healthy is not True:
|
||||
raise EngineNotHealthyError(
|
||||
"Arcade Engine health check was not healthy",
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
|
||||
class AsyncAuthResource(BaseResource[AsyncArcadeClient]):
|
||||
"""Asynchronous Authentication resource."""
|
||||
|
||||
|
|
@ -234,6 +271,41 @@ class AsyncToolResource(BaseResource[AsyncArcadeClient]):
|
|||
return AuthResponse(**data)
|
||||
|
||||
|
||||
class AsyncHealthResource(BaseResource[AsyncArcadeClient]):
|
||||
"""Asynchronous Health check resource."""
|
||||
|
||||
async def check(self) -> None:
|
||||
"""
|
||||
Check the health of the Arcade Engine.
|
||||
Raises an error if the health check fails.
|
||||
"""
|
||||
|
||||
try:
|
||||
data = await self._client._execute_request( # type: ignore[attr-defined]
|
||||
"GET",
|
||||
f"/{API_VERSION}/health",
|
||||
timeout=5,
|
||||
)
|
||||
|
||||
except APIStatusError as e:
|
||||
raise EngineNotHealthyError(
|
||||
"Arcade Engine health check returned an unhealthy status code",
|
||||
status_code=e.status_code,
|
||||
)
|
||||
except Exception as e:
|
||||
# Catches everything else including httpx.ConnectError (most common)
|
||||
raise EngineOfflineError(f"Arcade Engine was unreachable: {e}")
|
||||
|
||||
health_check_response = HealthCheckResponse(**data)
|
||||
|
||||
# Raise an error if the health payload is not `healthy: true`
|
||||
if health_check_response.healthy is not True:
|
||||
raise EngineNotHealthyError(
|
||||
"Arcade Engine health check was not healthy",
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
|
||||
class Arcade(SyncArcadeClient):
|
||||
"""Synchronous Arcade client."""
|
||||
|
||||
|
|
@ -241,6 +313,7 @@ class Arcade(SyncArcadeClient):
|
|||
super().__init__(*args, **kwargs)
|
||||
self.auth: AuthResource = AuthResource(self)
|
||||
self.tool: ToolResource = ToolResource(self)
|
||||
self.health: HealthResource = HealthResource(self)
|
||||
chat_url = self._chat_url(self._base_url)
|
||||
self._openai_client = OpenAI(base_url=chat_url, api_key=self._api_key)
|
||||
|
||||
|
|
@ -266,6 +339,7 @@ class AsyncArcade(AsyncArcadeClient):
|
|||
super().__init__(*args, **kwargs)
|
||||
self.auth: AsyncAuthResource = AsyncAuthResource(self)
|
||||
self.tool: AsyncToolResource = AsyncToolResource(self)
|
||||
self.health: AsyncHealthResource = AsyncHealthResource(self)
|
||||
chat_url = self._chat_url(self._base_url)
|
||||
self._openai_client = AsyncOpenAI(base_url=chat_url, api_key=self._api_key)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,25 @@ class ArcadeError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class EngineOfflineError(ArcadeError):
|
||||
"""Raised when the Arcade Engine is offline."""
|
||||
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class EngineNotHealthyError(ArcadeError):
|
||||
"""Raised when the Arcade Engine is not healthy."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
status_code: int,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class APIError(ArcadeError):
|
||||
"""Base class for API-related errors."""
|
||||
|
||||
|
|
|
|||
|
|
@ -42,6 +42,13 @@ class AuthStatus(str, Enum):
|
|||
completed = "completed"
|
||||
|
||||
|
||||
class HealthCheckResponse(BaseModel):
|
||||
"""Response from a health check request."""
|
||||
|
||||
healthy: bool
|
||||
"""Whether the health check was successful."""
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
"""Response from an authorization request."""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,233 +1,6 @@
|
|||
import ipaddress
|
||||
from functools import cached_property, lru_cache
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
from functools import lru_cache
|
||||
|
||||
import idna
|
||||
import toml
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from arcade.core.env import settings
|
||||
|
||||
|
||||
class ApiConfig(BaseModel):
|
||||
"""
|
||||
Arcade API configuration.
|
||||
"""
|
||||
|
||||
key: str
|
||||
"""
|
||||
Arcade API key.
|
||||
"""
|
||||
|
||||
|
||||
class UserConfig(BaseModel):
|
||||
"""
|
||||
Arcade user configuration.
|
||||
"""
|
||||
|
||||
email: str | None = None
|
||||
"""
|
||||
User email.
|
||||
"""
|
||||
|
||||
|
||||
class EngineConfig(BaseModel):
|
||||
"""
|
||||
Arcade Engine configuration.
|
||||
"""
|
||||
|
||||
host: str = "api.arcade-ai.com"
|
||||
"""
|
||||
Arcade Engine host.
|
||||
"""
|
||||
port: int | None = None
|
||||
"""
|
||||
Arcade Engine port.
|
||||
"""
|
||||
tls: bool = True
|
||||
"""
|
||||
Whether to use TLS for the connection to Arcade Engine.
|
||||
"""
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
"""
|
||||
Configuration for Arcade.
|
||||
"""
|
||||
|
||||
api: ApiConfig
|
||||
"""
|
||||
Arcade API configuration.
|
||||
"""
|
||||
user: UserConfig | None = None
|
||||
"""
|
||||
Arcade user configuration.
|
||||
"""
|
||||
engine: EngineConfig | None = EngineConfig()
|
||||
"""
|
||||
Arcade Engine configuration.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_config_dir_path(cls) -> Path:
|
||||
"""
|
||||
Get the path to the Arcade configuration directory.
|
||||
"""
|
||||
return settings.WORK_DIR if settings.WORK_DIR else Path.home() / ".arcade"
|
||||
|
||||
@classmethod
|
||||
def get_config_file_path(cls) -> Path:
|
||||
"""
|
||||
Get the path to the Arcade configuration file.
|
||||
"""
|
||||
return cls.get_config_dir_path() / "arcade.toml"
|
||||
|
||||
@cached_property
|
||||
def engine_url(self) -> str:
|
||||
"""
|
||||
Get the cached URL of the Arcade Engine.
|
||||
|
||||
This property is cached after its first access to improve performance.
|
||||
The cache is automatically invalidated if any of the underlying data changes.
|
||||
|
||||
The port is included in the URL unless the host is a fully qualified domain name
|
||||
(excluding IP addresses) and no port is specified. Handles IPv4, IPv6, IDNs, and
|
||||
hostnames with underscores.
|
||||
|
||||
This property exists to provide a consistent and correctly formatted URL for
|
||||
connecting to the Arcade Engine, taking into account various configuration
|
||||
options and edge cases. It ensures that:
|
||||
|
||||
1. The correct protocol (http/https) is used based on the TLS setting.
|
||||
2. IPv4 and IPv6 addresses are properly formatted.
|
||||
3. Internationalized Domain Names (IDNs) are correctly encoded.
|
||||
4. Fully Qualified Domain Names (FQDNs) are identified and handled appropriately.
|
||||
5. Ports are included when necessary, respecting common conventions for FQDNs.
|
||||
6. Hostnames with underscores (common in development environments) are supported.
|
||||
7. Pre-existing port specifications in the host are respected.
|
||||
|
||||
The resulting URL is always suffixed with '/v1' to specify the API version.
|
||||
|
||||
Returns:
|
||||
str: The fully constructed URL for the Arcade Engine.
|
||||
|
||||
Raises:
|
||||
ValueError: If the engine configuration is missing or incomplete.
|
||||
"""
|
||||
if self.engine is None:
|
||||
raise ValueError("Configuration for Engine is not set in arcade.toml")
|
||||
if not self.engine.host:
|
||||
raise ValueError("Configuration for Engine host is not set in arcade.toml")
|
||||
|
||||
protocol = "https" if self.engine.tls else "http"
|
||||
|
||||
# Handle potential IDNs
|
||||
try:
|
||||
encoded_host = idna.encode(self.engine.host).decode("ascii")
|
||||
except idna.IDNAError:
|
||||
encoded_host = self.engine.host
|
||||
|
||||
# Check if the host is a valid IP address (IPv4 or IPv6)
|
||||
try:
|
||||
ipaddress.ip_address(encoded_host)
|
||||
is_ip = True
|
||||
except ValueError:
|
||||
is_ip = False
|
||||
|
||||
# Parse the host, handling potential IPv6 addresses
|
||||
host_for_parsing = f"[{encoded_host}]" if is_ip and ":" in encoded_host else encoded_host
|
||||
parsed_host = urlparse(f"//{host_for_parsing}")
|
||||
|
||||
# Check if the host is a fully qualified domain name (excluding IP addresses)
|
||||
is_fqdn = "." in parsed_host.netloc and not is_ip and "_" not in parsed_host.netloc
|
||||
|
||||
# Handle hosts that might already include a port
|
||||
if ":" in parsed_host.netloc and not is_ip:
|
||||
host, existing_port = parsed_host.netloc.rsplit(":", 1)
|
||||
if existing_port.isdigit():
|
||||
return f"{protocol}://{parsed_host.netloc}/v1"
|
||||
|
||||
if is_fqdn and self.engine.port is None:
|
||||
return f"{protocol}://{encoded_host}/v1"
|
||||
elif self.engine.port is not None:
|
||||
return f"{protocol}://{encoded_host}:{self.engine.port}/v1"
|
||||
else:
|
||||
return f"{protocol}://{encoded_host}/v1"
|
||||
|
||||
@classmethod
|
||||
def ensure_config_dir_exists(cls) -> None:
|
||||
"""
|
||||
Create the configuration directory if it does not exist.
|
||||
"""
|
||||
config_dir = Config.get_config_dir_path()
|
||||
if not config_dir.exists():
|
||||
config_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@classmethod
|
||||
def load_from_file(cls) -> "Config":
|
||||
"""
|
||||
Load the configuration from the TOML file in the configuration directory.
|
||||
|
||||
If no configuration file exists, this method will create a new one with default values.
|
||||
The default configuration includes:
|
||||
- An empty API configuration
|
||||
- A default Engine configuration (host: "api.arcade-ai.com", port: None, tls: True)
|
||||
- No user configuration
|
||||
|
||||
This behavior ensures that the application always has a valid configuration to work with,
|
||||
but it may not be suitable for all use cases. If a specific configuration is required,
|
||||
ensure that the configuration file exists before calling this method.
|
||||
|
||||
Returns:
|
||||
Config: The loaded or newly created configuration.
|
||||
|
||||
Raises:
|
||||
ValueError: If the existing configuration file is invalid.
|
||||
"""
|
||||
cls.ensure_config_dir_exists()
|
||||
|
||||
config_file_path = cls.get_config_file_path()
|
||||
if not config_file_path.exists():
|
||||
# Create a file using the default configuration
|
||||
default_config = cls.model_construct(
|
||||
api=ApiConfig.model_construct(), engine=EngineConfig()
|
||||
)
|
||||
default_config.save_to_file()
|
||||
|
||||
config_data = toml.loads(config_file_path.read_text())
|
||||
|
||||
try:
|
||||
return cls(**config_data)
|
||||
except ValidationError as e:
|
||||
# Get only the errors with {type:missing} and combine them
|
||||
# into a nicely-formatted string message.
|
||||
# Any other errors without {type:missing} should just be str()ed
|
||||
missing_field_errors = [
|
||||
".".join(map(str, error["loc"]))
|
||||
for error in e.errors()
|
||||
if error["type"] == "missing"
|
||||
]
|
||||
other_errors = [str(error) for error in e.errors() if error["type"] != "missing"]
|
||||
|
||||
missing_field_errors_str = ", ".join(missing_field_errors)
|
||||
other_errors_str = "\n".join(other_errors)
|
||||
|
||||
pretty_str: str = "Invalid Arcade configuration."
|
||||
if missing_field_errors_str:
|
||||
pretty_str += f"\nMissing fields: {missing_field_errors_str}\n"
|
||||
if other_errors_str:
|
||||
pretty_str += f"\nOther errors:\n{other_errors_str}"
|
||||
|
||||
raise ValueError(pretty_str) from e
|
||||
|
||||
def save_to_file(self) -> None:
|
||||
"""
|
||||
Save the configuration to the TOML file in the configuration directory.
|
||||
"""
|
||||
Config.ensure_config_dir_exists()
|
||||
config_file_path = Config.get_config_file_path()
|
||||
config_file_path.write_text(toml.dumps(self.model_dump()))
|
||||
from arcade.core.config_model import Config
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
|
|
|
|||
251
arcade/arcade/core/config_model.py
Normal file
251
arcade/arcade/core/config_model.py
Normal file
|
|
@ -0,0 +1,251 @@
|
|||
import ipaddress
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import idna
|
||||
import toml
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from arcade.core.env import settings
|
||||
|
||||
|
||||
class ApiConfig(BaseModel):
|
||||
"""
|
||||
Arcade API configuration.
|
||||
"""
|
||||
|
||||
key: str
|
||||
"""
|
||||
Arcade API key.
|
||||
"""
|
||||
|
||||
|
||||
class UserConfig(BaseModel):
|
||||
"""
|
||||
Arcade user configuration.
|
||||
"""
|
||||
|
||||
email: str | None = None
|
||||
"""
|
||||
User email.
|
||||
"""
|
||||
|
||||
|
||||
class EngineConfig(BaseModel):
|
||||
"""
|
||||
Arcade Engine configuration.
|
||||
"""
|
||||
|
||||
host: str = "api.arcade-ai.com"
|
||||
"""
|
||||
Arcade Engine host.
|
||||
"""
|
||||
port: int | None = None
|
||||
"""
|
||||
Arcade Engine port.
|
||||
"""
|
||||
tls: bool = True
|
||||
"""
|
||||
Whether to use TLS for the connection to Arcade Engine.
|
||||
"""
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
"""
|
||||
Configuration for Arcade.
|
||||
"""
|
||||
|
||||
api: ApiConfig
|
||||
"""
|
||||
Arcade API configuration.
|
||||
"""
|
||||
user: UserConfig | None = None
|
||||
"""
|
||||
Arcade user configuration.
|
||||
"""
|
||||
engine: EngineConfig | None = EngineConfig()
|
||||
"""
|
||||
Arcade Engine configuration.
|
||||
"""
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
self._engine_url_cache: str | None = None
|
||||
self._engine_url_cache_key: str | None = None
|
||||
|
||||
@classmethod
|
||||
def get_config_dir_path(cls) -> Path:
|
||||
"""
|
||||
Get the path to the Arcade configuration directory.
|
||||
"""
|
||||
return settings.WORK_DIR if settings.WORK_DIR else Path.home() / ".arcade"
|
||||
|
||||
@classmethod
|
||||
def get_config_file_path(cls) -> Path:
|
||||
"""
|
||||
Get the path to the Arcade configuration file.
|
||||
"""
|
||||
return cls.get_config_dir_path() / "arcade.toml"
|
||||
|
||||
def _generate_engine_url_cache_key(self) -> str:
|
||||
"""
|
||||
Generate a cache key for the engine_url property, based on its underlying data.
|
||||
"""
|
||||
if self.engine is None:
|
||||
return ""
|
||||
|
||||
return f"{self.engine.host}:{self.engine.port}:{self.engine.tls}"
|
||||
|
||||
@property
|
||||
def engine_url(self) -> str:
|
||||
"""
|
||||
Get the cached URL of the Arcade Engine.
|
||||
|
||||
This property is cached after its first access to improve performance.
|
||||
The cache is automatically invalidated if any of the underlying data changes.
|
||||
|
||||
The port is included in the URL unless the host is a fully qualified domain name
|
||||
(excluding IP addresses) and no port is specified. Handles IPv4, IPv6, IDNs, and
|
||||
hostnames with underscores.
|
||||
|
||||
This property exists to provide a consistent and correctly formatted URL for
|
||||
connecting to the Arcade Engine, taking into account various configuration
|
||||
options and edge cases. It ensures that:
|
||||
|
||||
1. The correct protocol (http/https) is used based on the TLS setting.
|
||||
2. IPv4 and IPv6 addresses are properly formatted.
|
||||
3. Internationalized Domain Names (IDNs) are correctly encoded.
|
||||
4. Fully Qualified Domain Names (FQDNs) are identified and handled appropriately.
|
||||
5. Ports are included when necessary, respecting common conventions for FQDNs.
|
||||
6. Hostnames with underscores (common in development environments) are supported.
|
||||
7. Pre-existing port specifications in the host are respected.
|
||||
|
||||
The resulting URL is always suffixed with '/v1' to specify the API version.
|
||||
|
||||
Returns:
|
||||
str: The fully constructed URL for the Arcade Engine.
|
||||
|
||||
Raises:
|
||||
ValueError: If the engine configuration is missing or incomplete.
|
||||
"""
|
||||
current_cache_key = self._generate_engine_url_cache_key()
|
||||
if self._engine_url_cache is None or self._engine_url_cache_key != current_cache_key:
|
||||
self._engine_url_cache = self._compute_engine_url()
|
||||
self._engine_url_cache_key = current_cache_key
|
||||
return self._engine_url_cache
|
||||
|
||||
def _compute_engine_url(self) -> str:
|
||||
if self.engine is None:
|
||||
raise ValueError("Configuration for Engine is not set in arcade.toml")
|
||||
if not self.engine.host:
|
||||
raise ValueError("Configuration for Engine host is not set in arcade.toml")
|
||||
|
||||
protocol = "https" if self.engine.tls else "http"
|
||||
|
||||
# Handle potential IDNs
|
||||
try:
|
||||
encoded_host = idna.encode(self.engine.host).decode("ascii")
|
||||
except idna.IDNAError:
|
||||
encoded_host = self.engine.host
|
||||
|
||||
# Check if the host is a valid IP address (IPv4 or IPv6)
|
||||
try:
|
||||
ipaddress.ip_address(encoded_host)
|
||||
is_ip = True
|
||||
except ValueError:
|
||||
is_ip = False
|
||||
|
||||
# Parse the host, handling potential IPv6 addresses
|
||||
host_for_parsing = f"[{encoded_host}]" if is_ip and ":" in encoded_host else encoded_host
|
||||
parsed_host = urlparse(f"//{host_for_parsing}")
|
||||
|
||||
# Check if the host is a fully qualified domain name (excluding IP addresses)
|
||||
is_fqdn = "." in parsed_host.netloc and not is_ip and "_" not in parsed_host.netloc
|
||||
|
||||
# Handle hosts that might already include a port
|
||||
if ":" in parsed_host.netloc and not is_ip:
|
||||
host, existing_port = parsed_host.netloc.rsplit(":", 1)
|
||||
if existing_port.isdigit():
|
||||
return f"{protocol}://{parsed_host.netloc}/v1"
|
||||
|
||||
if is_fqdn and self.engine.port is None:
|
||||
return f"{protocol}://{encoded_host}/v1"
|
||||
elif self.engine.port is not None:
|
||||
return f"{protocol}://{encoded_host}:{self.engine.port}/v1"
|
||||
else:
|
||||
return f"{protocol}://{encoded_host}/v1"
|
||||
|
||||
@classmethod
|
||||
def ensure_config_dir_exists(cls) -> None:
|
||||
"""
|
||||
Create the configuration directory if it does not exist.
|
||||
"""
|
||||
config_dir = Config.get_config_dir_path()
|
||||
if not config_dir.exists():
|
||||
config_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@classmethod
|
||||
def load_from_file(cls) -> "Config":
|
||||
"""
|
||||
Load the configuration from the TOML file in the configuration directory.
|
||||
|
||||
If no configuration file exists, this method will create a new one with default values.
|
||||
The default configuration includes:
|
||||
- An empty API configuration
|
||||
- A default Engine configuration (host: "api.arcade-ai.com", port: None, tls: True)
|
||||
- No user configuration
|
||||
|
||||
This behavior ensures that the application always has a valid configuration to work with,
|
||||
but it may not be suitable for all use cases. If a specific configuration is required,
|
||||
ensure that the configuration file exists before calling this method.
|
||||
|
||||
Returns:
|
||||
Config: The loaded or newly created configuration.
|
||||
|
||||
Raises:
|
||||
ValueError: If the existing configuration file is invalid.
|
||||
"""
|
||||
cls.ensure_config_dir_exists()
|
||||
|
||||
config_file_path = cls.get_config_file_path()
|
||||
if not config_file_path.exists():
|
||||
# Create a file using the default configuration
|
||||
default_config = cls.model_construct(
|
||||
api=ApiConfig.model_construct(), engine=EngineConfig()
|
||||
)
|
||||
default_config.save_to_file()
|
||||
|
||||
config_data = toml.loads(config_file_path.read_text())
|
||||
|
||||
try:
|
||||
return cls(**config_data)
|
||||
except ValidationError as e:
|
||||
# Get only the errors with {type:missing} and combine them
|
||||
# into a nicely-formatted string message.
|
||||
# Any other errors without {type:missing} should just be str()ed
|
||||
missing_field_errors = [
|
||||
".".join(map(str, error["loc"]))
|
||||
for error in e.errors()
|
||||
if error["type"] == "missing"
|
||||
]
|
||||
other_errors = [str(error) for error in e.errors() if error["type"] != "missing"]
|
||||
|
||||
missing_field_errors_str = ", ".join(missing_field_errors)
|
||||
other_errors_str = "\n".join(other_errors)
|
||||
|
||||
pretty_str: str = "Invalid Arcade configuration."
|
||||
if missing_field_errors_str:
|
||||
pretty_str += f"\nMissing fields: {missing_field_errors_str}\n"
|
||||
if other_errors_str:
|
||||
pretty_str += f"\nOther errors:\n{other_errors_str}"
|
||||
|
||||
raise ValueError(pretty_str) from e
|
||||
|
||||
def save_to_file(self) -> None:
|
||||
"""
|
||||
Save the configuration to the TOML file in the configuration directory.
|
||||
"""
|
||||
Config.ensure_config_dir_exists()
|
||||
config_file_path = Config.get_config_file_path()
|
||||
config_file_path.write_text(toml.dumps(self.model_dump()))
|
||||
135
arcade/tests/cli/test_utils.py
Normal file
135
arcade/tests/cli/test_utils.py
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
import pytest
|
||||
|
||||
from arcade.cli.utils import apply_config_overrides
|
||||
from arcade.core.config_model import ApiConfig, Config, EngineConfig
|
||||
|
||||
DEFAULT_HOST = "api.arcade-ai.com"
|
||||
DEFAULT_PORT = None
|
||||
DEFAULT_TLS = True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"inputs, expected_outputs",
|
||||
[
|
||||
pytest.param(
|
||||
{
|
||||
"host_input": None,
|
||||
"port_input": None,
|
||||
"tls_input": None,
|
||||
},
|
||||
{
|
||||
"host": DEFAULT_HOST,
|
||||
"port": DEFAULT_PORT,
|
||||
"tls": DEFAULT_TLS,
|
||||
},
|
||||
id="noop",
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"host_input": "api2.arcade-ai.com",
|
||||
"port_input": None,
|
||||
"tls_input": None,
|
||||
},
|
||||
{
|
||||
"host": "api2.arcade-ai.com",
|
||||
"port": DEFAULT_PORT,
|
||||
"tls": DEFAULT_TLS,
|
||||
},
|
||||
id="set host",
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"host_input": None,
|
||||
"port_input": 6789,
|
||||
"tls_input": None,
|
||||
},
|
||||
{
|
||||
"host": DEFAULT_HOST,
|
||||
"port": 6789,
|
||||
"tls": DEFAULT_TLS,
|
||||
},
|
||||
id="set port",
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"host_input": None,
|
||||
"port_input": None,
|
||||
"tls_input": False,
|
||||
},
|
||||
{
|
||||
"host": DEFAULT_HOST,
|
||||
"port": DEFAULT_PORT,
|
||||
"tls": False,
|
||||
},
|
||||
id="set TLS to False",
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"host_input": None,
|
||||
"port_input": None,
|
||||
"tls_input": True,
|
||||
},
|
||||
{
|
||||
"host": DEFAULT_HOST,
|
||||
"port": DEFAULT_PORT,
|
||||
"tls": True,
|
||||
},
|
||||
id="set TLS to True",
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"host_input": "localhost",
|
||||
"port_input": None,
|
||||
"tls_input": None,
|
||||
},
|
||||
{
|
||||
"host": "localhost",
|
||||
"port": 9099,
|
||||
"tls": False,
|
||||
},
|
||||
id="localhost and no port or TLS specified",
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"host_input": "localhost",
|
||||
"port_input": 1234,
|
||||
"tls_input": None,
|
||||
},
|
||||
{
|
||||
"host": "localhost",
|
||||
"port": 1234,
|
||||
"tls": False,
|
||||
},
|
||||
id="localhost and port specified",
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"host_input": "localhost",
|
||||
"port_input": None,
|
||||
"tls_input": True,
|
||||
},
|
||||
{
|
||||
"host": "localhost",
|
||||
"port": 9099,
|
||||
"tls": True,
|
||||
},
|
||||
id="localhost and TLS specified",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_apply_config_overrides(inputs: dict, expected_outputs: dict):
|
||||
# Set fake default values for testing
|
||||
config = Config(
|
||||
api=ApiConfig(key="fake_api_key"),
|
||||
engine=EngineConfig(
|
||||
host=DEFAULT_HOST,
|
||||
port=DEFAULT_PORT,
|
||||
tls=DEFAULT_TLS,
|
||||
),
|
||||
)
|
||||
|
||||
apply_config_overrides(config, inputs["host_input"], inputs["port_input"], inputs["tls_input"])
|
||||
|
||||
assert config.engine.host == expected_outputs["host"]
|
||||
assert config.engine.port == expected_outputs["port"]
|
||||
assert config.engine.tls == expected_outputs["tls"]
|
||||
|
|
@ -6,6 +6,7 @@ from httpx import HTTPStatusError, Response
|
|||
from arcade.client import Arcade, AsyncArcade, AuthProvider
|
||||
from arcade.client.errors import (
|
||||
BadRequestError,
|
||||
EngineNotHealthyError,
|
||||
InternalServerError,
|
||||
NotFoundError,
|
||||
PermissionDeniedError,
|
||||
|
|
@ -51,6 +52,15 @@ TOOL_AUTHORIZE_RESPONSE_DATA = {
|
|||
"status": "pending",
|
||||
}
|
||||
|
||||
HEALTH_CHECK_HEALTHY_RESPONSE_DATA = {
|
||||
"healthy": True,
|
||||
}
|
||||
|
||||
HEALTH_CHECK_UNHEALTHY_RESPONSE_DATA = {
|
||||
"healthy": False,
|
||||
"reason": "Cannot reticulate splines",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response():
|
||||
|
|
@ -87,7 +97,7 @@ def test_handle_http_error(error_code, expected_error, mock_response):
|
|||
mock_http_error = Mock(spec=HTTPStatusError)
|
||||
mock_http_error.response = mock_response
|
||||
|
||||
client = Arcade() # Create an instance of Arcade
|
||||
client = Arcade(api_key="fake_api_key") # Create an instance of Arcade
|
||||
with pytest.raises(expected_error):
|
||||
client._handle_http_error(mock_http_error) # Call the method on the instance
|
||||
|
||||
|
|
@ -95,7 +105,7 @@ def test_handle_http_error(error_code, expected_error, mock_response):
|
|||
def test_arcade_auth_authorize(mock_response, monkeypatch):
|
||||
"""Test Arcade.auth.authorize method."""
|
||||
monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: AUTH_RESPONSE_DATA)
|
||||
client = Arcade()
|
||||
client = Arcade(api_key="fake_api_key")
|
||||
auth_response = client.auth.authorize(
|
||||
provider=AuthProvider.google,
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
|
|
@ -107,7 +117,7 @@ def test_arcade_auth_authorize(mock_response, monkeypatch):
|
|||
def test_arcade_auth_poll_authorization(mock_response, monkeypatch):
|
||||
"""Test Arcade.auth.poll_authorization method."""
|
||||
monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: AUTH_RESPONSE_DATA)
|
||||
client = Arcade()
|
||||
client = Arcade(api_key="fake_api_key")
|
||||
auth_response = client.auth.status("auth_123")
|
||||
assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA)
|
||||
|
||||
|
|
@ -115,7 +125,7 @@ def test_arcade_auth_poll_authorization(mock_response, monkeypatch):
|
|||
def test_arcade_tool_run(mock_response, monkeypatch):
|
||||
"""Test Arcade.tool.run method."""
|
||||
monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: TOOL_RESPONSE_DATA)
|
||||
client = Arcade()
|
||||
client = Arcade(api_key="fake_api_key")
|
||||
tool_response = client.tool.run(
|
||||
tool_name="GetEmails",
|
||||
user_id="sam@arcade-ai.com",
|
||||
|
|
@ -128,7 +138,7 @@ def test_arcade_tool_run(mock_response, monkeypatch):
|
|||
def test_arcade_tool_get(mock_response, monkeypatch):
|
||||
"""Test Arcade.tool.get method."""
|
||||
monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: TOOL_DEFINITION_DATA)
|
||||
client = Arcade()
|
||||
client = Arcade(api_key="fake_api_key")
|
||||
tool_definition = client.tool.get(director_id="default", tool_id="GetEmails")
|
||||
assert tool_definition == ToolDefinition(**TOOL_DEFINITION_DATA)
|
||||
|
||||
|
|
@ -138,11 +148,31 @@ def test_arcade_tool_authorize(mock_response, monkeypatch):
|
|||
monkeypatch.setattr(
|
||||
Arcade, "_execute_request", lambda *args, **kwargs: TOOL_AUTHORIZE_RESPONSE_DATA
|
||||
)
|
||||
client = Arcade()
|
||||
client = Arcade(api_key="fake_api_key")
|
||||
auth_response = client.tool.authorize(tool_name="GetEmails", user_id="sam@arcade-ai.com")
|
||||
assert auth_response == AuthResponse(**TOOL_AUTHORIZE_RESPONSE_DATA)
|
||||
|
||||
|
||||
def test_arcade_health_check(mock_response, monkeypatch):
|
||||
"""Test Arcade.health.check method."""
|
||||
monkeypatch.setattr(
|
||||
Arcade, "_execute_request", lambda *args, **kwargs: HEALTH_CHECK_HEALTHY_RESPONSE_DATA
|
||||
)
|
||||
client = Arcade(api_key="fake_api_key")
|
||||
client.health.check()
|
||||
assert True # If no exception is raised, the test passes
|
||||
|
||||
|
||||
def test_arcade_health_check_raises_error(mock_response, monkeypatch):
|
||||
"""Test Arcade.health.check method."""
|
||||
monkeypatch.setattr(
|
||||
Arcade, "_execute_request", lambda *args, **kwargs: HEALTH_CHECK_UNHEALTHY_RESPONSE_DATA
|
||||
)
|
||||
client = Arcade(api_key="fake_api_key")
|
||||
with pytest.raises(EngineNotHealthyError):
|
||||
client.health.check()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_arcade_auth_authorize(mock_async_response, monkeypatch):
|
||||
"""Test AsyncArcade.auth.authorize method."""
|
||||
|
|
@ -151,7 +181,7 @@ async def test_async_arcade_auth_authorize(mock_async_response, monkeypatch):
|
|||
return AUTH_RESPONSE_DATA
|
||||
|
||||
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
|
||||
client = AsyncArcade()
|
||||
client = AsyncArcade(api_key="fake_api_key")
|
||||
auth_response = await client.auth.authorize(
|
||||
provider=AuthProvider.google,
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
|
|
@ -168,7 +198,7 @@ async def test_async_arcade_auth_poll_authorization(mock_async_response, monkeyp
|
|||
return AUTH_RESPONSE_DATA
|
||||
|
||||
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
|
||||
client = AsyncArcade()
|
||||
client = AsyncArcade(api_key="fake_api_key")
|
||||
auth_response = await client.auth.status("auth_123")
|
||||
assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA)
|
||||
|
||||
|
|
@ -181,7 +211,7 @@ async def test_async_arcade_tool_run(mock_async_response, monkeypatch):
|
|||
return TOOL_RESPONSE_DATA
|
||||
|
||||
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
|
||||
client = AsyncArcade()
|
||||
client = AsyncArcade(api_key="fake_api_key")
|
||||
tool_response = await client.tool.run(
|
||||
tool_name="GetEmails",
|
||||
user_id="sam@arcade-ai.com",
|
||||
|
|
@ -199,7 +229,7 @@ async def test_async_arcade_tool_get(mock_async_response, monkeypatch):
|
|||
return TOOL_DEFINITION_DATA
|
||||
|
||||
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
|
||||
client = AsyncArcade()
|
||||
client = AsyncArcade(api_key="fake_api_key")
|
||||
tool_definition = await client.tool.get(director_id="default", tool_id="GetEmails")
|
||||
assert tool_definition == ToolDefinition(**TOOL_DEFINITION_DATA)
|
||||
|
||||
|
|
@ -212,6 +242,32 @@ async def test_async_arcade_tool_authorize(mock_async_response, monkeypatch):
|
|||
return TOOL_AUTHORIZE_RESPONSE_DATA
|
||||
|
||||
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
|
||||
client = AsyncArcade()
|
||||
client = AsyncArcade(api_key="fake_api_key")
|
||||
auth_response = await client.tool.authorize(tool_name="GetEmails", user_id="sam@arcade-ai.com")
|
||||
assert auth_response == AuthResponse(**TOOL_AUTHORIZE_RESPONSE_DATA)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_arcade_health_check(mock_async_response, monkeypatch):
|
||||
"""Test AsyncArcade.health.check method."""
|
||||
|
||||
async def mock_execute_request(*args, **kwargs):
|
||||
return HEALTH_CHECK_HEALTHY_RESPONSE_DATA
|
||||
|
||||
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
|
||||
client = AsyncArcade(api_key="fake_api_key")
|
||||
await client.health.check()
|
||||
assert True # If no exception is raised, the test passes
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_arcade_health_check_raises_error(mock_async_response, monkeypatch):
|
||||
"""Test AsyncArcade.health.check method."""
|
||||
|
||||
async def mock_execute_request(*args, **kwargs):
|
||||
return HEALTH_CHECK_UNHEALTHY_RESPONSE_DATA
|
||||
|
||||
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
|
||||
client = AsyncArcade(api_key="fake_api_key")
|
||||
with pytest.raises(EngineNotHealthyError):
|
||||
await client.health.check()
|
||||
|
|
|
|||
Loading…
Reference in a new issue