arcade chat: allow overriding host, port, TLS (#31)

Adds:
- New options to `arcade chat`: `-h/--host`, `-p/--port`, and
`--tls/--notls`. This allows us to point `arcade chat` at a different
server than what's configured in `arcade.toml` which is very helpful for
debugging.
- Special case: if you do `-h localhost`, it will automatically use port
9099 and no TLS unless otherwise specified.
- Adds a non-fatal engine health check to `arcade chat` startup:

<img width="499" alt="image"
src="https://github.com/user-attachments/assets/b7fae29e-2f8d-4004-a27b-645b4cd997a8">
This commit is contained in:
Nate Barbettini 2024-09-10 09:25:05 -07:00 committed by GitHub
parent d12542db55
commit 75c6a2becf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 657 additions and 250 deletions

View file

@ -16,12 +16,14 @@ from rich.text import Text
from arcade.cli.authn import LocalAuthCallbackServer, check_existing_login
from arcade.cli.utils import (
OrderCommands,
apply_config_overrides,
create_cli_catalog,
display_streamed_markdown,
markdownify_urls,
validate_and_get_config,
)
from arcade.client import Arcade
from arcade.client.errors import EngineNotHealthyError, EngineOfflineError
cli = typer.Typer(
cls=OrderCommands,
@ -30,7 +32,14 @@ console = Console()
@cli.command(help="Log in to Arcade Cloud")
def login() -> None:
def login(
host: str = typer.Option(
"https://cloud.arcade-ai.com",
"-h",
"--host",
help="The Arcade Cloud host to log in to.",
),
) -> None:
"""
Logs the user into Arcade Cloud.
"""
@ -48,8 +57,7 @@ def login() -> None:
# Open the browser for user login
callback_uri = "http://localhost:9905/callback"
params = urlencode({"callback_uri": callback_uri, "state": state})
# TODO: make this configurable
login_url = f"http://localhost:8001/api/v1/auth/cli_login?{params}"
login_url = f"https://{host}/api/v1/auth/cli_login?{params}"
console.print("Opening a browser to log you in...")
webbrowser.open(login_url)
@ -131,12 +139,42 @@ def chat(
stream: bool = typer.Option(
False, "-s", "--stream", is_flag=True, help="Stream the tool output."
),
host: str = typer.Option(
None,
"-h",
"--host",
help="The Arcade Engine address to send chat requests to.",
),
port: int = typer.Option(
None,
"-p",
"--port",
help="The port of the Arcade Engine.",
),
force_tls: bool = typer.Option(
False,
"--tls",
help="Whether to force TLS for the connection to the Arcade Engine. If not specified, the connection will use TLS if the engine URL uses a 'https' scheme.",
),
force_no_tls: bool = typer.Option(
False,
"--no-tls",
help="Whether to disable TLS for the connection to the Arcade Engine.",
),
) -> None:
"""
Chat with a language model.
"""
config = validate_and_get_config()
if not force_tls and not force_no_tls:
tls_input = None
elif force_no_tls:
tls_input = False
else:
tls_input = True
apply_config_overrides(config, host, port, tls_input)
client = Arcade(api_key=config.api.key, base_url=config.engine_url)
user_email = config.user.email if config.user else None
user_attribution = f"({user_email})" if user_email else ""
@ -153,10 +191,17 @@ def chat(
),
"\n",
"\n",
"Chatting with Arcade Engine at " + config.engine_url,
"Chatting with Arcade Engine at ",
(
config.engine_url,
"bold blue",
),
)
console.print(chat_header)
# Try to hit /health endpoint on engine and warn if it is down
log_engine_health(client)
while True:
console.print(f"\n[magenta][bold]User[/bold] {user_attribution}:[/magenta] ")
@ -277,6 +322,28 @@ def config(
raise typer.Exit(code=1)
def log_engine_health(client: Arcade) -> None:
try:
client.health.check()
except EngineNotHealthyError as e:
console.print(
"[bold][yellow]⚠️ Warning: "
+ str(e)
+ " ("
+ "[/yellow]"
+ "[red]"
+ str(e.status_code)
+ "[/red]"
+ "[yellow])[/yellow][/bold]"
)
except EngineOfflineError:
console.print(
"⚠️ Warning: Arcade Engine was unreachable. (Is it running?)",
style="bold yellow",
)
def display_config_as_table(config) -> None: # type: ignore[no-untyped-def]
"""
Display the configuration details as a table using Rich library.

View file

@ -1,5 +1,3 @@
from typing import TYPE_CHECKING
import typer
from openai.resources.chat.completions import ChatCompletionChunk, Stream
from rich.console import Console
@ -8,11 +6,9 @@ from typer.core import TyperGroup
from typer.models import Context
from arcade.core.catalog import ToolCatalog
from arcade.core.config_model import Config
from arcade.core.toolkit import Toolkit
if TYPE_CHECKING:
from arcade.core.config import Config
console = Console()
@ -101,7 +97,7 @@ def validate_and_get_config(
validate_engine: bool = True,
validate_api: bool = True,
validate_user: bool = True,
) -> "Config":
) -> Config:
"""
Validates the configuration, user, and returns the Config object
"""
@ -125,3 +121,32 @@ def validate_and_get_config(
raise typer.Exit(code=1)
return config
def apply_config_overrides(
config: Config, host_input: str | None, port_input: int | None, tls_input: bool | None
) -> None:
"""
Apply optional config overrides (passed by the user) to the config object.
"""
if not config.engine:
# Should not happen, validate_and_get_config ensures that `engine` is set
raise ValueError("Engine configuration not found in config.")
# Special case for "localhost" and nothing else specified:
# default to dev port and no TLS for convenience
if host_input == "localhost":
if port_input is None:
port_input = 9099
if tls_input is None:
tls_input = False
if host_input:
config.engine.host = host_input
if port_input is not None:
config.engine.port = port_input
if tls_input is not None:
config.engine.tls = tls_input

View file

@ -10,11 +10,13 @@ from arcade.client.base import (
BaseResource,
SyncArcadeClient,
)
from arcade.client.errors import APIStatusError, EngineNotHealthyError, EngineOfflineError
from arcade.client.schema import (
AuthProvider,
AuthRequest,
AuthResponse,
ExecuteToolResponse,
HealthCheckResponse,
)
from arcade.core.schema import ToolDefinition
@ -143,6 +145,41 @@ class ToolResource(BaseResource[ClientT]):
return AuthResponse(**data)
class HealthResource(BaseResource[ClientT]):
"""Health check resource."""
def check(self) -> None:
"""
Check the health of the Arcade Engine.
Raises an error if the health check fails.
"""
try:
data = self._client._execute_request( # type: ignore[attr-defined]
"GET",
f"/{API_VERSION}/health",
timeout=5,
)
except APIStatusError as e:
raise EngineNotHealthyError(
"Arcade Engine health check returned an unhealthy status code",
status_code=e.status_code,
)
except Exception as e:
# Catches everything else including httpx.ConnectError (most common)
raise EngineOfflineError(f"Arcade Engine was unreachable: {e}")
health_check_response = HealthCheckResponse(**data)
# Raise an error if the health payload is not `healthy: true`
if health_check_response.healthy is not True:
raise EngineNotHealthyError(
"Arcade Engine health check was not healthy",
status_code=200,
)
class AsyncAuthResource(BaseResource[AsyncArcadeClient]):
"""Asynchronous Authentication resource."""
@ -234,6 +271,41 @@ class AsyncToolResource(BaseResource[AsyncArcadeClient]):
return AuthResponse(**data)
class AsyncHealthResource(BaseResource[AsyncArcadeClient]):
"""Asynchronous Health check resource."""
async def check(self) -> None:
"""
Check the health of the Arcade Engine.
Raises an error if the health check fails.
"""
try:
data = await self._client._execute_request( # type: ignore[attr-defined]
"GET",
f"/{API_VERSION}/health",
timeout=5,
)
except APIStatusError as e:
raise EngineNotHealthyError(
"Arcade Engine health check returned an unhealthy status code",
status_code=e.status_code,
)
except Exception as e:
# Catches everything else including httpx.ConnectError (most common)
raise EngineOfflineError(f"Arcade Engine was unreachable: {e}")
health_check_response = HealthCheckResponse(**data)
# Raise an error if the health payload is not `healthy: true`
if health_check_response.healthy is not True:
raise EngineNotHealthyError(
"Arcade Engine health check was not healthy",
status_code=200,
)
class Arcade(SyncArcadeClient):
"""Synchronous Arcade client."""
@ -241,6 +313,7 @@ class Arcade(SyncArcadeClient):
super().__init__(*args, **kwargs)
self.auth: AuthResource = AuthResource(self)
self.tool: ToolResource = ToolResource(self)
self.health: HealthResource = HealthResource(self)
chat_url = self._chat_url(self._base_url)
self._openai_client = OpenAI(base_url=chat_url, api_key=self._api_key)
@ -266,6 +339,7 @@ class AsyncArcade(AsyncArcadeClient):
super().__init__(*args, **kwargs)
self.auth: AsyncAuthResource = AsyncAuthResource(self)
self.tool: AsyncToolResource = AsyncToolResource(self)
self.health: AsyncHealthResource = AsyncHealthResource(self)
chat_url = self._chat_url(self._base_url)
self._openai_client = AsyncOpenAI(base_url=chat_url, api_key=self._api_key)

View file

@ -9,6 +9,25 @@ class ArcadeError(Exception):
pass
class EngineOfflineError(ArcadeError):
"""Raised when the Arcade Engine is offline."""
def __init__(self, message: str):
super().__init__(message)
class EngineNotHealthyError(ArcadeError):
"""Raised when the Arcade Engine is not healthy."""
def __init__(
self,
message: str,
status_code: int,
):
super().__init__(message)
self.status_code = status_code
class APIError(ArcadeError):
"""Base class for API-related errors."""

View file

@ -42,6 +42,13 @@ class AuthStatus(str, Enum):
completed = "completed"
class HealthCheckResponse(BaseModel):
"""Response from a health check request."""
healthy: bool
"""Whether the health check was successful."""
class AuthResponse(BaseModel):
"""Response from an authorization request."""

View file

@ -1,233 +1,6 @@
import ipaddress
from functools import cached_property, lru_cache
from pathlib import Path
from urllib.parse import urlparse
from functools import lru_cache
import idna
import toml
from pydantic import BaseModel, ValidationError
from arcade.core.env import settings
class ApiConfig(BaseModel):
"""
Arcade API configuration.
"""
key: str
"""
Arcade API key.
"""
class UserConfig(BaseModel):
"""
Arcade user configuration.
"""
email: str | None = None
"""
User email.
"""
class EngineConfig(BaseModel):
"""
Arcade Engine configuration.
"""
host: str = "api.arcade-ai.com"
"""
Arcade Engine host.
"""
port: int | None = None
"""
Arcade Engine port.
"""
tls: bool = True
"""
Whether to use TLS for the connection to Arcade Engine.
"""
class Config(BaseModel):
"""
Configuration for Arcade.
"""
api: ApiConfig
"""
Arcade API configuration.
"""
user: UserConfig | None = None
"""
Arcade user configuration.
"""
engine: EngineConfig | None = EngineConfig()
"""
Arcade Engine configuration.
"""
@classmethod
def get_config_dir_path(cls) -> Path:
"""
Get the path to the Arcade configuration directory.
"""
return settings.WORK_DIR if settings.WORK_DIR else Path.home() / ".arcade"
@classmethod
def get_config_file_path(cls) -> Path:
"""
Get the path to the Arcade configuration file.
"""
return cls.get_config_dir_path() / "arcade.toml"
@cached_property
def engine_url(self) -> str:
"""
Get the cached URL of the Arcade Engine.
This property is cached after its first access to improve performance.
The cache is automatically invalidated if any of the underlying data changes.
The port is included in the URL unless the host is a fully qualified domain name
(excluding IP addresses) and no port is specified. Handles IPv4, IPv6, IDNs, and
hostnames with underscores.
This property exists to provide a consistent and correctly formatted URL for
connecting to the Arcade Engine, taking into account various configuration
options and edge cases. It ensures that:
1. The correct protocol (http/https) is used based on the TLS setting.
2. IPv4 and IPv6 addresses are properly formatted.
3. Internationalized Domain Names (IDNs) are correctly encoded.
4. Fully Qualified Domain Names (FQDNs) are identified and handled appropriately.
5. Ports are included when necessary, respecting common conventions for FQDNs.
6. Hostnames with underscores (common in development environments) are supported.
7. Pre-existing port specifications in the host are respected.
The resulting URL is always suffixed with '/v1' to specify the API version.
Returns:
str: The fully constructed URL for the Arcade Engine.
Raises:
ValueError: If the engine configuration is missing or incomplete.
"""
if self.engine is None:
raise ValueError("Configuration for Engine is not set in arcade.toml")
if not self.engine.host:
raise ValueError("Configuration for Engine host is not set in arcade.toml")
protocol = "https" if self.engine.tls else "http"
# Handle potential IDNs
try:
encoded_host = idna.encode(self.engine.host).decode("ascii")
except idna.IDNAError:
encoded_host = self.engine.host
# Check if the host is a valid IP address (IPv4 or IPv6)
try:
ipaddress.ip_address(encoded_host)
is_ip = True
except ValueError:
is_ip = False
# Parse the host, handling potential IPv6 addresses
host_for_parsing = f"[{encoded_host}]" if is_ip and ":" in encoded_host else encoded_host
parsed_host = urlparse(f"//{host_for_parsing}")
# Check if the host is a fully qualified domain name (excluding IP addresses)
is_fqdn = "." in parsed_host.netloc and not is_ip and "_" not in parsed_host.netloc
# Handle hosts that might already include a port
if ":" in parsed_host.netloc and not is_ip:
host, existing_port = parsed_host.netloc.rsplit(":", 1)
if existing_port.isdigit():
return f"{protocol}://{parsed_host.netloc}/v1"
if is_fqdn and self.engine.port is None:
return f"{protocol}://{encoded_host}/v1"
elif self.engine.port is not None:
return f"{protocol}://{encoded_host}:{self.engine.port}/v1"
else:
return f"{protocol}://{encoded_host}/v1"
@classmethod
def ensure_config_dir_exists(cls) -> None:
"""
Create the configuration directory if it does not exist.
"""
config_dir = Config.get_config_dir_path()
if not config_dir.exists():
config_dir.mkdir(parents=True, exist_ok=True)
@classmethod
def load_from_file(cls) -> "Config":
"""
Load the configuration from the TOML file in the configuration directory.
If no configuration file exists, this method will create a new one with default values.
The default configuration includes:
- An empty API configuration
- A default Engine configuration (host: "api.arcade-ai.com", port: None, tls: True)
- No user configuration
This behavior ensures that the application always has a valid configuration to work with,
but it may not be suitable for all use cases. If a specific configuration is required,
ensure that the configuration file exists before calling this method.
Returns:
Config: The loaded or newly created configuration.
Raises:
ValueError: If the existing configuration file is invalid.
"""
cls.ensure_config_dir_exists()
config_file_path = cls.get_config_file_path()
if not config_file_path.exists():
# Create a file using the default configuration
default_config = cls.model_construct(
api=ApiConfig.model_construct(), engine=EngineConfig()
)
default_config.save_to_file()
config_data = toml.loads(config_file_path.read_text())
try:
return cls(**config_data)
except ValidationError as e:
# Get only the errors with {type:missing} and combine them
# into a nicely-formatted string message.
# Any other errors without {type:missing} should just be str()ed
missing_field_errors = [
".".join(map(str, error["loc"]))
for error in e.errors()
if error["type"] == "missing"
]
other_errors = [str(error) for error in e.errors() if error["type"] != "missing"]
missing_field_errors_str = ", ".join(missing_field_errors)
other_errors_str = "\n".join(other_errors)
pretty_str: str = "Invalid Arcade configuration."
if missing_field_errors_str:
pretty_str += f"\nMissing fields: {missing_field_errors_str}\n"
if other_errors_str:
pretty_str += f"\nOther errors:\n{other_errors_str}"
raise ValueError(pretty_str) from e
def save_to_file(self) -> None:
"""
Save the configuration to the TOML file in the configuration directory.
"""
Config.ensure_config_dir_exists()
config_file_path = Config.get_config_file_path()
config_file_path.write_text(toml.dumps(self.model_dump()))
from arcade.core.config_model import Config
@lru_cache(maxsize=1)

View file

@ -0,0 +1,251 @@
import ipaddress
from pathlib import Path
from typing import Any
from urllib.parse import urlparse
import idna
import toml
from pydantic import BaseModel, ValidationError
from arcade.core.env import settings
class ApiConfig(BaseModel):
"""
Arcade API configuration.
"""
key: str
"""
Arcade API key.
"""
class UserConfig(BaseModel):
"""
Arcade user configuration.
"""
email: str | None = None
"""
User email.
"""
class EngineConfig(BaseModel):
"""
Arcade Engine configuration.
"""
host: str = "api.arcade-ai.com"
"""
Arcade Engine host.
"""
port: int | None = None
"""
Arcade Engine port.
"""
tls: bool = True
"""
Whether to use TLS for the connection to Arcade Engine.
"""
class Config(BaseModel):
"""
Configuration for Arcade.
"""
api: ApiConfig
"""
Arcade API configuration.
"""
user: UserConfig | None = None
"""
Arcade user configuration.
"""
engine: EngineConfig | None = EngineConfig()
"""
Arcade Engine configuration.
"""
def __init__(self, **data: Any):
super().__init__(**data)
self._engine_url_cache: str | None = None
self._engine_url_cache_key: str | None = None
@classmethod
def get_config_dir_path(cls) -> Path:
"""
Get the path to the Arcade configuration directory.
"""
return settings.WORK_DIR if settings.WORK_DIR else Path.home() / ".arcade"
@classmethod
def get_config_file_path(cls) -> Path:
"""
Get the path to the Arcade configuration file.
"""
return cls.get_config_dir_path() / "arcade.toml"
def _generate_engine_url_cache_key(self) -> str:
"""
Generate a cache key for the engine_url property, based on its underlying data.
"""
if self.engine is None:
return ""
return f"{self.engine.host}:{self.engine.port}:{self.engine.tls}"
@property
def engine_url(self) -> str:
"""
Get the cached URL of the Arcade Engine.
This property is cached after its first access to improve performance.
The cache is automatically invalidated if any of the underlying data changes.
The port is included in the URL unless the host is a fully qualified domain name
(excluding IP addresses) and no port is specified. Handles IPv4, IPv6, IDNs, and
hostnames with underscores.
This property exists to provide a consistent and correctly formatted URL for
connecting to the Arcade Engine, taking into account various configuration
options and edge cases. It ensures that:
1. The correct protocol (http/https) is used based on the TLS setting.
2. IPv4 and IPv6 addresses are properly formatted.
3. Internationalized Domain Names (IDNs) are correctly encoded.
4. Fully Qualified Domain Names (FQDNs) are identified and handled appropriately.
5. Ports are included when necessary, respecting common conventions for FQDNs.
6. Hostnames with underscores (common in development environments) are supported.
7. Pre-existing port specifications in the host are respected.
The resulting URL is always suffixed with '/v1' to specify the API version.
Returns:
str: The fully constructed URL for the Arcade Engine.
Raises:
ValueError: If the engine configuration is missing or incomplete.
"""
current_cache_key = self._generate_engine_url_cache_key()
if self._engine_url_cache is None or self._engine_url_cache_key != current_cache_key:
self._engine_url_cache = self._compute_engine_url()
self._engine_url_cache_key = current_cache_key
return self._engine_url_cache
def _compute_engine_url(self) -> str:
if self.engine is None:
raise ValueError("Configuration for Engine is not set in arcade.toml")
if not self.engine.host:
raise ValueError("Configuration for Engine host is not set in arcade.toml")
protocol = "https" if self.engine.tls else "http"
# Handle potential IDNs
try:
encoded_host = idna.encode(self.engine.host).decode("ascii")
except idna.IDNAError:
encoded_host = self.engine.host
# Check if the host is a valid IP address (IPv4 or IPv6)
try:
ipaddress.ip_address(encoded_host)
is_ip = True
except ValueError:
is_ip = False
# Parse the host, handling potential IPv6 addresses
host_for_parsing = f"[{encoded_host}]" if is_ip and ":" in encoded_host else encoded_host
parsed_host = urlparse(f"//{host_for_parsing}")
# Check if the host is a fully qualified domain name (excluding IP addresses)
is_fqdn = "." in parsed_host.netloc and not is_ip and "_" not in parsed_host.netloc
# Handle hosts that might already include a port
if ":" in parsed_host.netloc and not is_ip:
host, existing_port = parsed_host.netloc.rsplit(":", 1)
if existing_port.isdigit():
return f"{protocol}://{parsed_host.netloc}/v1"
if is_fqdn and self.engine.port is None:
return f"{protocol}://{encoded_host}/v1"
elif self.engine.port is not None:
return f"{protocol}://{encoded_host}:{self.engine.port}/v1"
else:
return f"{protocol}://{encoded_host}/v1"
@classmethod
def ensure_config_dir_exists(cls) -> None:
"""
Create the configuration directory if it does not exist.
"""
config_dir = Config.get_config_dir_path()
if not config_dir.exists():
config_dir.mkdir(parents=True, exist_ok=True)
@classmethod
def load_from_file(cls) -> "Config":
"""
Load the configuration from the TOML file in the configuration directory.
If no configuration file exists, this method will create a new one with default values.
The default configuration includes:
- An empty API configuration
- A default Engine configuration (host: "api.arcade-ai.com", port: None, tls: True)
- No user configuration
This behavior ensures that the application always has a valid configuration to work with,
but it may not be suitable for all use cases. If a specific configuration is required,
ensure that the configuration file exists before calling this method.
Returns:
Config: The loaded or newly created configuration.
Raises:
ValueError: If the existing configuration file is invalid.
"""
cls.ensure_config_dir_exists()
config_file_path = cls.get_config_file_path()
if not config_file_path.exists():
# Create a file using the default configuration
default_config = cls.model_construct(
api=ApiConfig.model_construct(), engine=EngineConfig()
)
default_config.save_to_file()
config_data = toml.loads(config_file_path.read_text())
try:
return cls(**config_data)
except ValidationError as e:
# Get only the errors with {type:missing} and combine them
# into a nicely-formatted string message.
# Any other errors without {type:missing} should just be str()ed
missing_field_errors = [
".".join(map(str, error["loc"]))
for error in e.errors()
if error["type"] == "missing"
]
other_errors = [str(error) for error in e.errors() if error["type"] != "missing"]
missing_field_errors_str = ", ".join(missing_field_errors)
other_errors_str = "\n".join(other_errors)
pretty_str: str = "Invalid Arcade configuration."
if missing_field_errors_str:
pretty_str += f"\nMissing fields: {missing_field_errors_str}\n"
if other_errors_str:
pretty_str += f"\nOther errors:\n{other_errors_str}"
raise ValueError(pretty_str) from e
def save_to_file(self) -> None:
"""
Save the configuration to the TOML file in the configuration directory.
"""
Config.ensure_config_dir_exists()
config_file_path = Config.get_config_file_path()
config_file_path.write_text(toml.dumps(self.model_dump()))

View file

@ -0,0 +1,135 @@
import pytest
from arcade.cli.utils import apply_config_overrides
from arcade.core.config_model import ApiConfig, Config, EngineConfig
DEFAULT_HOST = "api.arcade-ai.com"
DEFAULT_PORT = None
DEFAULT_TLS = True
@pytest.mark.parametrize(
"inputs, expected_outputs",
[
pytest.param(
{
"host_input": None,
"port_input": None,
"tls_input": None,
},
{
"host": DEFAULT_HOST,
"port": DEFAULT_PORT,
"tls": DEFAULT_TLS,
},
id="noop",
),
pytest.param(
{
"host_input": "api2.arcade-ai.com",
"port_input": None,
"tls_input": None,
},
{
"host": "api2.arcade-ai.com",
"port": DEFAULT_PORT,
"tls": DEFAULT_TLS,
},
id="set host",
),
pytest.param(
{
"host_input": None,
"port_input": 6789,
"tls_input": None,
},
{
"host": DEFAULT_HOST,
"port": 6789,
"tls": DEFAULT_TLS,
},
id="set port",
),
pytest.param(
{
"host_input": None,
"port_input": None,
"tls_input": False,
},
{
"host": DEFAULT_HOST,
"port": DEFAULT_PORT,
"tls": False,
},
id="set TLS to False",
),
pytest.param(
{
"host_input": None,
"port_input": None,
"tls_input": True,
},
{
"host": DEFAULT_HOST,
"port": DEFAULT_PORT,
"tls": True,
},
id="set TLS to True",
),
pytest.param(
{
"host_input": "localhost",
"port_input": None,
"tls_input": None,
},
{
"host": "localhost",
"port": 9099,
"tls": False,
},
id="localhost and no port or TLS specified",
),
pytest.param(
{
"host_input": "localhost",
"port_input": 1234,
"tls_input": None,
},
{
"host": "localhost",
"port": 1234,
"tls": False,
},
id="localhost and port specified",
),
pytest.param(
{
"host_input": "localhost",
"port_input": None,
"tls_input": True,
},
{
"host": "localhost",
"port": 9099,
"tls": True,
},
id="localhost and TLS specified",
),
],
)
def test_apply_config_overrides(inputs: dict, expected_outputs: dict):
# Set fake default values for testing
config = Config(
api=ApiConfig(key="fake_api_key"),
engine=EngineConfig(
host=DEFAULT_HOST,
port=DEFAULT_PORT,
tls=DEFAULT_TLS,
),
)
apply_config_overrides(config, inputs["host_input"], inputs["port_input"], inputs["tls_input"])
assert config.engine.host == expected_outputs["host"]
assert config.engine.port == expected_outputs["port"]
assert config.engine.tls == expected_outputs["tls"]

View file

@ -6,6 +6,7 @@ from httpx import HTTPStatusError, Response
from arcade.client import Arcade, AsyncArcade, AuthProvider
from arcade.client.errors import (
BadRequestError,
EngineNotHealthyError,
InternalServerError,
NotFoundError,
PermissionDeniedError,
@ -51,6 +52,15 @@ TOOL_AUTHORIZE_RESPONSE_DATA = {
"status": "pending",
}
HEALTH_CHECK_HEALTHY_RESPONSE_DATA = {
"healthy": True,
}
HEALTH_CHECK_UNHEALTHY_RESPONSE_DATA = {
"healthy": False,
"reason": "Cannot reticulate splines",
}
@pytest.fixture
def mock_response():
@ -87,7 +97,7 @@ def test_handle_http_error(error_code, expected_error, mock_response):
mock_http_error = Mock(spec=HTTPStatusError)
mock_http_error.response = mock_response
client = Arcade() # Create an instance of Arcade
client = Arcade(api_key="fake_api_key") # Create an instance of Arcade
with pytest.raises(expected_error):
client._handle_http_error(mock_http_error) # Call the method on the instance
@ -95,7 +105,7 @@ def test_handle_http_error(error_code, expected_error, mock_response):
def test_arcade_auth_authorize(mock_response, monkeypatch):
"""Test Arcade.auth.authorize method."""
monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: AUTH_RESPONSE_DATA)
client = Arcade()
client = Arcade(api_key="fake_api_key")
auth_response = client.auth.authorize(
provider=AuthProvider.google,
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
@ -107,7 +117,7 @@ def test_arcade_auth_authorize(mock_response, monkeypatch):
def test_arcade_auth_poll_authorization(mock_response, monkeypatch):
"""Test Arcade.auth.poll_authorization method."""
monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: AUTH_RESPONSE_DATA)
client = Arcade()
client = Arcade(api_key="fake_api_key")
auth_response = client.auth.status("auth_123")
assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA)
@ -115,7 +125,7 @@ def test_arcade_auth_poll_authorization(mock_response, monkeypatch):
def test_arcade_tool_run(mock_response, monkeypatch):
"""Test Arcade.tool.run method."""
monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: TOOL_RESPONSE_DATA)
client = Arcade()
client = Arcade(api_key="fake_api_key")
tool_response = client.tool.run(
tool_name="GetEmails",
user_id="sam@arcade-ai.com",
@ -128,7 +138,7 @@ def test_arcade_tool_run(mock_response, monkeypatch):
def test_arcade_tool_get(mock_response, monkeypatch):
"""Test Arcade.tool.get method."""
monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: TOOL_DEFINITION_DATA)
client = Arcade()
client = Arcade(api_key="fake_api_key")
tool_definition = client.tool.get(director_id="default", tool_id="GetEmails")
assert tool_definition == ToolDefinition(**TOOL_DEFINITION_DATA)
@ -138,11 +148,31 @@ def test_arcade_tool_authorize(mock_response, monkeypatch):
monkeypatch.setattr(
Arcade, "_execute_request", lambda *args, **kwargs: TOOL_AUTHORIZE_RESPONSE_DATA
)
client = Arcade()
client = Arcade(api_key="fake_api_key")
auth_response = client.tool.authorize(tool_name="GetEmails", user_id="sam@arcade-ai.com")
assert auth_response == AuthResponse(**TOOL_AUTHORIZE_RESPONSE_DATA)
def test_arcade_health_check(mock_response, monkeypatch):
"""Test Arcade.health.check method."""
monkeypatch.setattr(
Arcade, "_execute_request", lambda *args, **kwargs: HEALTH_CHECK_HEALTHY_RESPONSE_DATA
)
client = Arcade(api_key="fake_api_key")
client.health.check()
assert True # If no exception is raised, the test passes
def test_arcade_health_check_raises_error(mock_response, monkeypatch):
"""Test Arcade.health.check method."""
monkeypatch.setattr(
Arcade, "_execute_request", lambda *args, **kwargs: HEALTH_CHECK_UNHEALTHY_RESPONSE_DATA
)
client = Arcade(api_key="fake_api_key")
with pytest.raises(EngineNotHealthyError):
client.health.check()
@pytest.mark.asyncio
async def test_async_arcade_auth_authorize(mock_async_response, monkeypatch):
"""Test AsyncArcade.auth.authorize method."""
@ -151,7 +181,7 @@ async def test_async_arcade_auth_authorize(mock_async_response, monkeypatch):
return AUTH_RESPONSE_DATA
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
client = AsyncArcade()
client = AsyncArcade(api_key="fake_api_key")
auth_response = await client.auth.authorize(
provider=AuthProvider.google,
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
@ -168,7 +198,7 @@ async def test_async_arcade_auth_poll_authorization(mock_async_response, monkeyp
return AUTH_RESPONSE_DATA
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
client = AsyncArcade()
client = AsyncArcade(api_key="fake_api_key")
auth_response = await client.auth.status("auth_123")
assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA)
@ -181,7 +211,7 @@ async def test_async_arcade_tool_run(mock_async_response, monkeypatch):
return TOOL_RESPONSE_DATA
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
client = AsyncArcade()
client = AsyncArcade(api_key="fake_api_key")
tool_response = await client.tool.run(
tool_name="GetEmails",
user_id="sam@arcade-ai.com",
@ -199,7 +229,7 @@ async def test_async_arcade_tool_get(mock_async_response, monkeypatch):
return TOOL_DEFINITION_DATA
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
client = AsyncArcade()
client = AsyncArcade(api_key="fake_api_key")
tool_definition = await client.tool.get(director_id="default", tool_id="GetEmails")
assert tool_definition == ToolDefinition(**TOOL_DEFINITION_DATA)
@ -212,6 +242,32 @@ async def test_async_arcade_tool_authorize(mock_async_response, monkeypatch):
return TOOL_AUTHORIZE_RESPONSE_DATA
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
client = AsyncArcade()
client = AsyncArcade(api_key="fake_api_key")
auth_response = await client.tool.authorize(tool_name="GetEmails", user_id="sam@arcade-ai.com")
assert auth_response == AuthResponse(**TOOL_AUTHORIZE_RESPONSE_DATA)
@pytest.mark.asyncio
async def test_async_arcade_health_check(mock_async_response, monkeypatch):
"""Test AsyncArcade.health.check method."""
async def mock_execute_request(*args, **kwargs):
return HEALTH_CHECK_HEALTHY_RESPONSE_DATA
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
client = AsyncArcade(api_key="fake_api_key")
await client.health.check()
assert True # If no exception is raised, the test passes
@pytest.mark.asyncio
async def test_async_arcade_health_check_raises_error(mock_async_response, monkeypatch):
"""Test AsyncArcade.health.check method."""
async def mock_execute_request(*args, **kwargs):
return HEALTH_CHECK_UNHEALTHY_RESPONSE_DATA
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
client = AsyncArcade(api_key="fake_api_key")
with pytest.raises(EngineNotHealthyError):
await client.health.check()