From 75c6a2becf203924b03b892930ac88165a5c1d1c Mon Sep 17 00:00:00 2001 From: Nate Barbettini Date: Tue, 10 Sep 2024 09:25:05 -0700 Subject: [PATCH] arcade chat: allow overriding host, port, TLS (#31) Adds: - New options to `arcade chat`: `-h/--host`, `-p/--port`, and `--tls/--notls`. This allows us to point `arcade chat` at a different server than what's configured in `arcade.toml` which is very helpful for debugging. - Special case: if you do `-h localhost`, it will automatically use port 9099 and no TLS unless otherwise specified. - Adds a non-fatal engine health check to `arcade chat` startup: image --- arcade/arcade/cli/main.py | 75 ++++++++- arcade/arcade/cli/utils.py | 37 ++++- arcade/arcade/client/client.py | 74 +++++++++ arcade/arcade/client/errors.py | 19 +++ arcade/arcade/client/schema.py | 7 + arcade/arcade/core/config.py | 231 +------------------------- arcade/arcade/core/config_model.py | 251 +++++++++++++++++++++++++++++ arcade/tests/cli/test_utils.py | 135 ++++++++++++++++ arcade/tests/client/test_client.py | 78 +++++++-- 9 files changed, 657 insertions(+), 250 deletions(-) create mode 100644 arcade/arcade/core/config_model.py create mode 100644 arcade/tests/cli/test_utils.py diff --git a/arcade/arcade/cli/main.py b/arcade/arcade/cli/main.py index 8b67efe2..ba2e8202 100644 --- a/arcade/arcade/cli/main.py +++ b/arcade/arcade/cli/main.py @@ -16,12 +16,14 @@ from rich.text import Text from arcade.cli.authn import LocalAuthCallbackServer, check_existing_login from arcade.cli.utils import ( OrderCommands, + apply_config_overrides, create_cli_catalog, display_streamed_markdown, markdownify_urls, validate_and_get_config, ) from arcade.client import Arcade +from arcade.client.errors import EngineNotHealthyError, EngineOfflineError cli = typer.Typer( cls=OrderCommands, @@ -30,7 +32,14 @@ console = Console() @cli.command(help="Log in to Arcade Cloud") -def login() -> None: +def login( + host: str = typer.Option( + "https://cloud.arcade-ai.com", + "-h", + "--host", + help="The Arcade Cloud host to log in to.", + ), +) -> None: """ Logs the user into Arcade Cloud. """ @@ -48,8 +57,7 @@ def login() -> None: # Open the browser for user login callback_uri = "http://localhost:9905/callback" params = urlencode({"callback_uri": callback_uri, "state": state}) - # TODO: make this configurable - login_url = f"http://localhost:8001/api/v1/auth/cli_login?{params}" + login_url = f"https://{host}/api/v1/auth/cli_login?{params}" console.print("Opening a browser to log you in...") webbrowser.open(login_url) @@ -131,12 +139,42 @@ def chat( stream: bool = typer.Option( False, "-s", "--stream", is_flag=True, help="Stream the tool output." ), + host: str = typer.Option( + None, + "-h", + "--host", + help="The Arcade Engine address to send chat requests to.", + ), + port: int = typer.Option( + None, + "-p", + "--port", + help="The port of the Arcade Engine.", + ), + force_tls: bool = typer.Option( + False, + "--tls", + help="Whether to force TLS for the connection to the Arcade Engine. If not specified, the connection will use TLS if the engine URL uses a 'https' scheme.", + ), + force_no_tls: bool = typer.Option( + False, + "--no-tls", + help="Whether to disable TLS for the connection to the Arcade Engine.", + ), ) -> None: """ Chat with a language model. """ config = validate_and_get_config() + if not force_tls and not force_no_tls: + tls_input = None + elif force_no_tls: + tls_input = False + else: + tls_input = True + apply_config_overrides(config, host, port, tls_input) + client = Arcade(api_key=config.api.key, base_url=config.engine_url) user_email = config.user.email if config.user else None user_attribution = f"({user_email})" if user_email else "" @@ -153,10 +191,17 @@ def chat( ), "\n", "\n", - "Chatting with Arcade Engine at " + config.engine_url, + "Chatting with Arcade Engine at ", + ( + config.engine_url, + "bold blue", + ), ) console.print(chat_header) + # Try to hit /health endpoint on engine and warn if it is down + log_engine_health(client) + while True: console.print(f"\n[magenta][bold]User[/bold] {user_attribution}:[/magenta] ") @@ -277,6 +322,28 @@ def config( raise typer.Exit(code=1) +def log_engine_health(client: Arcade) -> None: + try: + client.health.check() + + except EngineNotHealthyError as e: + console.print( + "[bold][yellow]⚠️ Warning: " + + str(e) + + " (" + + "[/yellow]" + + "[red]" + + str(e.status_code) + + "[/red]" + + "[yellow])[/yellow][/bold]" + ) + except EngineOfflineError: + console.print( + "⚠️ Warning: Arcade Engine was unreachable. (Is it running?)", + style="bold yellow", + ) + + def display_config_as_table(config) -> None: # type: ignore[no-untyped-def] """ Display the configuration details as a table using Rich library. diff --git a/arcade/arcade/cli/utils.py b/arcade/arcade/cli/utils.py index 030e62c5..cc8b325e 100644 --- a/arcade/arcade/cli/utils.py +++ b/arcade/arcade/cli/utils.py @@ -1,5 +1,3 @@ -from typing import TYPE_CHECKING - import typer from openai.resources.chat.completions import ChatCompletionChunk, Stream from rich.console import Console @@ -8,11 +6,9 @@ from typer.core import TyperGroup from typer.models import Context from arcade.core.catalog import ToolCatalog +from arcade.core.config_model import Config from arcade.core.toolkit import Toolkit -if TYPE_CHECKING: - from arcade.core.config import Config - console = Console() @@ -101,7 +97,7 @@ def validate_and_get_config( validate_engine: bool = True, validate_api: bool = True, validate_user: bool = True, -) -> "Config": +) -> Config: """ Validates the configuration, user, and returns the Config object """ @@ -125,3 +121,32 @@ def validate_and_get_config( raise typer.Exit(code=1) return config + + +def apply_config_overrides( + config: Config, host_input: str | None, port_input: int | None, tls_input: bool | None +) -> None: + """ + Apply optional config overrides (passed by the user) to the config object. + """ + + if not config.engine: + # Should not happen, validate_and_get_config ensures that `engine` is set + raise ValueError("Engine configuration not found in config.") + + # Special case for "localhost" and nothing else specified: + # default to dev port and no TLS for convenience + if host_input == "localhost": + if port_input is None: + port_input = 9099 + if tls_input is None: + tls_input = False + + if host_input: + config.engine.host = host_input + + if port_input is not None: + config.engine.port = port_input + + if tls_input is not None: + config.engine.tls = tls_input diff --git a/arcade/arcade/client/client.py b/arcade/arcade/client/client.py index 2a97b16b..48d22d0c 100644 --- a/arcade/arcade/client/client.py +++ b/arcade/arcade/client/client.py @@ -10,11 +10,13 @@ from arcade.client.base import ( BaseResource, SyncArcadeClient, ) +from arcade.client.errors import APIStatusError, EngineNotHealthyError, EngineOfflineError from arcade.client.schema import ( AuthProvider, AuthRequest, AuthResponse, ExecuteToolResponse, + HealthCheckResponse, ) from arcade.core.schema import ToolDefinition @@ -143,6 +145,41 @@ class ToolResource(BaseResource[ClientT]): return AuthResponse(**data) +class HealthResource(BaseResource[ClientT]): + """Health check resource.""" + + def check(self) -> None: + """ + Check the health of the Arcade Engine. + Raises an error if the health check fails. + """ + + try: + data = self._client._execute_request( # type: ignore[attr-defined] + "GET", + f"/{API_VERSION}/health", + timeout=5, + ) + + except APIStatusError as e: + raise EngineNotHealthyError( + "Arcade Engine health check returned an unhealthy status code", + status_code=e.status_code, + ) + except Exception as e: + # Catches everything else including httpx.ConnectError (most common) + raise EngineOfflineError(f"Arcade Engine was unreachable: {e}") + + health_check_response = HealthCheckResponse(**data) + + # Raise an error if the health payload is not `healthy: true` + if health_check_response.healthy is not True: + raise EngineNotHealthyError( + "Arcade Engine health check was not healthy", + status_code=200, + ) + + class AsyncAuthResource(BaseResource[AsyncArcadeClient]): """Asynchronous Authentication resource.""" @@ -234,6 +271,41 @@ class AsyncToolResource(BaseResource[AsyncArcadeClient]): return AuthResponse(**data) +class AsyncHealthResource(BaseResource[AsyncArcadeClient]): + """Asynchronous Health check resource.""" + + async def check(self) -> None: + """ + Check the health of the Arcade Engine. + Raises an error if the health check fails. + """ + + try: + data = await self._client._execute_request( # type: ignore[attr-defined] + "GET", + f"/{API_VERSION}/health", + timeout=5, + ) + + except APIStatusError as e: + raise EngineNotHealthyError( + "Arcade Engine health check returned an unhealthy status code", + status_code=e.status_code, + ) + except Exception as e: + # Catches everything else including httpx.ConnectError (most common) + raise EngineOfflineError(f"Arcade Engine was unreachable: {e}") + + health_check_response = HealthCheckResponse(**data) + + # Raise an error if the health payload is not `healthy: true` + if health_check_response.healthy is not True: + raise EngineNotHealthyError( + "Arcade Engine health check was not healthy", + status_code=200, + ) + + class Arcade(SyncArcadeClient): """Synchronous Arcade client.""" @@ -241,6 +313,7 @@ class Arcade(SyncArcadeClient): super().__init__(*args, **kwargs) self.auth: AuthResource = AuthResource(self) self.tool: ToolResource = ToolResource(self) + self.health: HealthResource = HealthResource(self) chat_url = self._chat_url(self._base_url) self._openai_client = OpenAI(base_url=chat_url, api_key=self._api_key) @@ -266,6 +339,7 @@ class AsyncArcade(AsyncArcadeClient): super().__init__(*args, **kwargs) self.auth: AsyncAuthResource = AsyncAuthResource(self) self.tool: AsyncToolResource = AsyncToolResource(self) + self.health: AsyncHealthResource = AsyncHealthResource(self) chat_url = self._chat_url(self._base_url) self._openai_client = AsyncOpenAI(base_url=chat_url, api_key=self._api_key) diff --git a/arcade/arcade/client/errors.py b/arcade/arcade/client/errors.py index e9879ca1..c9f6c648 100644 --- a/arcade/arcade/client/errors.py +++ b/arcade/arcade/client/errors.py @@ -9,6 +9,25 @@ class ArcadeError(Exception): pass +class EngineOfflineError(ArcadeError): + """Raised when the Arcade Engine is offline.""" + + def __init__(self, message: str): + super().__init__(message) + + +class EngineNotHealthyError(ArcadeError): + """Raised when the Arcade Engine is not healthy.""" + + def __init__( + self, + message: str, + status_code: int, + ): + super().__init__(message) + self.status_code = status_code + + class APIError(ArcadeError): """Base class for API-related errors.""" diff --git a/arcade/arcade/client/schema.py b/arcade/arcade/client/schema.py index 10f4285d..5ef4664c 100644 --- a/arcade/arcade/client/schema.py +++ b/arcade/arcade/client/schema.py @@ -42,6 +42,13 @@ class AuthStatus(str, Enum): completed = "completed" +class HealthCheckResponse(BaseModel): + """Response from a health check request.""" + + healthy: bool + """Whether the health check was successful.""" + + class AuthResponse(BaseModel): """Response from an authorization request.""" diff --git a/arcade/arcade/core/config.py b/arcade/arcade/core/config.py index d0342c3a..d77edfba 100644 --- a/arcade/arcade/core/config.py +++ b/arcade/arcade/core/config.py @@ -1,233 +1,6 @@ -import ipaddress -from functools import cached_property, lru_cache -from pathlib import Path -from urllib.parse import urlparse +from functools import lru_cache -import idna -import toml -from pydantic import BaseModel, ValidationError - -from arcade.core.env import settings - - -class ApiConfig(BaseModel): - """ - Arcade API configuration. - """ - - key: str - """ - Arcade API key. - """ - - -class UserConfig(BaseModel): - """ - Arcade user configuration. - """ - - email: str | None = None - """ - User email. - """ - - -class EngineConfig(BaseModel): - """ - Arcade Engine configuration. - """ - - host: str = "api.arcade-ai.com" - """ - Arcade Engine host. - """ - port: int | None = None - """ - Arcade Engine port. - """ - tls: bool = True - """ - Whether to use TLS for the connection to Arcade Engine. - """ - - -class Config(BaseModel): - """ - Configuration for Arcade. - """ - - api: ApiConfig - """ - Arcade API configuration. - """ - user: UserConfig | None = None - """ - Arcade user configuration. - """ - engine: EngineConfig | None = EngineConfig() - """ - Arcade Engine configuration. - """ - - @classmethod - def get_config_dir_path(cls) -> Path: - """ - Get the path to the Arcade configuration directory. - """ - return settings.WORK_DIR if settings.WORK_DIR else Path.home() / ".arcade" - - @classmethod - def get_config_file_path(cls) -> Path: - """ - Get the path to the Arcade configuration file. - """ - return cls.get_config_dir_path() / "arcade.toml" - - @cached_property - def engine_url(self) -> str: - """ - Get the cached URL of the Arcade Engine. - - This property is cached after its first access to improve performance. - The cache is automatically invalidated if any of the underlying data changes. - - The port is included in the URL unless the host is a fully qualified domain name - (excluding IP addresses) and no port is specified. Handles IPv4, IPv6, IDNs, and - hostnames with underscores. - - This property exists to provide a consistent and correctly formatted URL for - connecting to the Arcade Engine, taking into account various configuration - options and edge cases. It ensures that: - - 1. The correct protocol (http/https) is used based on the TLS setting. - 2. IPv4 and IPv6 addresses are properly formatted. - 3. Internationalized Domain Names (IDNs) are correctly encoded. - 4. Fully Qualified Domain Names (FQDNs) are identified and handled appropriately. - 5. Ports are included when necessary, respecting common conventions for FQDNs. - 6. Hostnames with underscores (common in development environments) are supported. - 7. Pre-existing port specifications in the host are respected. - - The resulting URL is always suffixed with '/v1' to specify the API version. - - Returns: - str: The fully constructed URL for the Arcade Engine. - - Raises: - ValueError: If the engine configuration is missing or incomplete. - """ - if self.engine is None: - raise ValueError("Configuration for Engine is not set in arcade.toml") - if not self.engine.host: - raise ValueError("Configuration for Engine host is not set in arcade.toml") - - protocol = "https" if self.engine.tls else "http" - - # Handle potential IDNs - try: - encoded_host = idna.encode(self.engine.host).decode("ascii") - except idna.IDNAError: - encoded_host = self.engine.host - - # Check if the host is a valid IP address (IPv4 or IPv6) - try: - ipaddress.ip_address(encoded_host) - is_ip = True - except ValueError: - is_ip = False - - # Parse the host, handling potential IPv6 addresses - host_for_parsing = f"[{encoded_host}]" if is_ip and ":" in encoded_host else encoded_host - parsed_host = urlparse(f"//{host_for_parsing}") - - # Check if the host is a fully qualified domain name (excluding IP addresses) - is_fqdn = "." in parsed_host.netloc and not is_ip and "_" not in parsed_host.netloc - - # Handle hosts that might already include a port - if ":" in parsed_host.netloc and not is_ip: - host, existing_port = parsed_host.netloc.rsplit(":", 1) - if existing_port.isdigit(): - return f"{protocol}://{parsed_host.netloc}/v1" - - if is_fqdn and self.engine.port is None: - return f"{protocol}://{encoded_host}/v1" - elif self.engine.port is not None: - return f"{protocol}://{encoded_host}:{self.engine.port}/v1" - else: - return f"{protocol}://{encoded_host}/v1" - - @classmethod - def ensure_config_dir_exists(cls) -> None: - """ - Create the configuration directory if it does not exist. - """ - config_dir = Config.get_config_dir_path() - if not config_dir.exists(): - config_dir.mkdir(parents=True, exist_ok=True) - - @classmethod - def load_from_file(cls) -> "Config": - """ - Load the configuration from the TOML file in the configuration directory. - - If no configuration file exists, this method will create a new one with default values. - The default configuration includes: - - An empty API configuration - - A default Engine configuration (host: "api.arcade-ai.com", port: None, tls: True) - - No user configuration - - This behavior ensures that the application always has a valid configuration to work with, - but it may not be suitable for all use cases. If a specific configuration is required, - ensure that the configuration file exists before calling this method. - - Returns: - Config: The loaded or newly created configuration. - - Raises: - ValueError: If the existing configuration file is invalid. - """ - cls.ensure_config_dir_exists() - - config_file_path = cls.get_config_file_path() - if not config_file_path.exists(): - # Create a file using the default configuration - default_config = cls.model_construct( - api=ApiConfig.model_construct(), engine=EngineConfig() - ) - default_config.save_to_file() - - config_data = toml.loads(config_file_path.read_text()) - - try: - return cls(**config_data) - except ValidationError as e: - # Get only the errors with {type:missing} and combine them - # into a nicely-formatted string message. - # Any other errors without {type:missing} should just be str()ed - missing_field_errors = [ - ".".join(map(str, error["loc"])) - for error in e.errors() - if error["type"] == "missing" - ] - other_errors = [str(error) for error in e.errors() if error["type"] != "missing"] - - missing_field_errors_str = ", ".join(missing_field_errors) - other_errors_str = "\n".join(other_errors) - - pretty_str: str = "Invalid Arcade configuration." - if missing_field_errors_str: - pretty_str += f"\nMissing fields: {missing_field_errors_str}\n" - if other_errors_str: - pretty_str += f"\nOther errors:\n{other_errors_str}" - - raise ValueError(pretty_str) from e - - def save_to_file(self) -> None: - """ - Save the configuration to the TOML file in the configuration directory. - """ - Config.ensure_config_dir_exists() - config_file_path = Config.get_config_file_path() - config_file_path.write_text(toml.dumps(self.model_dump())) +from arcade.core.config_model import Config @lru_cache(maxsize=1) diff --git a/arcade/arcade/core/config_model.py b/arcade/arcade/core/config_model.py new file mode 100644 index 00000000..16d749ff --- /dev/null +++ b/arcade/arcade/core/config_model.py @@ -0,0 +1,251 @@ +import ipaddress +from pathlib import Path +from typing import Any +from urllib.parse import urlparse + +import idna +import toml +from pydantic import BaseModel, ValidationError + +from arcade.core.env import settings + + +class ApiConfig(BaseModel): + """ + Arcade API configuration. + """ + + key: str + """ + Arcade API key. + """ + + +class UserConfig(BaseModel): + """ + Arcade user configuration. + """ + + email: str | None = None + """ + User email. + """ + + +class EngineConfig(BaseModel): + """ + Arcade Engine configuration. + """ + + host: str = "api.arcade-ai.com" + """ + Arcade Engine host. + """ + port: int | None = None + """ + Arcade Engine port. + """ + tls: bool = True + """ + Whether to use TLS for the connection to Arcade Engine. + """ + + +class Config(BaseModel): + """ + Configuration for Arcade. + """ + + api: ApiConfig + """ + Arcade API configuration. + """ + user: UserConfig | None = None + """ + Arcade user configuration. + """ + engine: EngineConfig | None = EngineConfig() + """ + Arcade Engine configuration. + """ + + def __init__(self, **data: Any): + super().__init__(**data) + self._engine_url_cache: str | None = None + self._engine_url_cache_key: str | None = None + + @classmethod + def get_config_dir_path(cls) -> Path: + """ + Get the path to the Arcade configuration directory. + """ + return settings.WORK_DIR if settings.WORK_DIR else Path.home() / ".arcade" + + @classmethod + def get_config_file_path(cls) -> Path: + """ + Get the path to the Arcade configuration file. + """ + return cls.get_config_dir_path() / "arcade.toml" + + def _generate_engine_url_cache_key(self) -> str: + """ + Generate a cache key for the engine_url property, based on its underlying data. + """ + if self.engine is None: + return "" + + return f"{self.engine.host}:{self.engine.port}:{self.engine.tls}" + + @property + def engine_url(self) -> str: + """ + Get the cached URL of the Arcade Engine. + + This property is cached after its first access to improve performance. + The cache is automatically invalidated if any of the underlying data changes. + + The port is included in the URL unless the host is a fully qualified domain name + (excluding IP addresses) and no port is specified. Handles IPv4, IPv6, IDNs, and + hostnames with underscores. + + This property exists to provide a consistent and correctly formatted URL for + connecting to the Arcade Engine, taking into account various configuration + options and edge cases. It ensures that: + + 1. The correct protocol (http/https) is used based on the TLS setting. + 2. IPv4 and IPv6 addresses are properly formatted. + 3. Internationalized Domain Names (IDNs) are correctly encoded. + 4. Fully Qualified Domain Names (FQDNs) are identified and handled appropriately. + 5. Ports are included when necessary, respecting common conventions for FQDNs. + 6. Hostnames with underscores (common in development environments) are supported. + 7. Pre-existing port specifications in the host are respected. + + The resulting URL is always suffixed with '/v1' to specify the API version. + + Returns: + str: The fully constructed URL for the Arcade Engine. + + Raises: + ValueError: If the engine configuration is missing or incomplete. + """ + current_cache_key = self._generate_engine_url_cache_key() + if self._engine_url_cache is None or self._engine_url_cache_key != current_cache_key: + self._engine_url_cache = self._compute_engine_url() + self._engine_url_cache_key = current_cache_key + return self._engine_url_cache + + def _compute_engine_url(self) -> str: + if self.engine is None: + raise ValueError("Configuration for Engine is not set in arcade.toml") + if not self.engine.host: + raise ValueError("Configuration for Engine host is not set in arcade.toml") + + protocol = "https" if self.engine.tls else "http" + + # Handle potential IDNs + try: + encoded_host = idna.encode(self.engine.host).decode("ascii") + except idna.IDNAError: + encoded_host = self.engine.host + + # Check if the host is a valid IP address (IPv4 or IPv6) + try: + ipaddress.ip_address(encoded_host) + is_ip = True + except ValueError: + is_ip = False + + # Parse the host, handling potential IPv6 addresses + host_for_parsing = f"[{encoded_host}]" if is_ip and ":" in encoded_host else encoded_host + parsed_host = urlparse(f"//{host_for_parsing}") + + # Check if the host is a fully qualified domain name (excluding IP addresses) + is_fqdn = "." in parsed_host.netloc and not is_ip and "_" not in parsed_host.netloc + + # Handle hosts that might already include a port + if ":" in parsed_host.netloc and not is_ip: + host, existing_port = parsed_host.netloc.rsplit(":", 1) + if existing_port.isdigit(): + return f"{protocol}://{parsed_host.netloc}/v1" + + if is_fqdn and self.engine.port is None: + return f"{protocol}://{encoded_host}/v1" + elif self.engine.port is not None: + return f"{protocol}://{encoded_host}:{self.engine.port}/v1" + else: + return f"{protocol}://{encoded_host}/v1" + + @classmethod + def ensure_config_dir_exists(cls) -> None: + """ + Create the configuration directory if it does not exist. + """ + config_dir = Config.get_config_dir_path() + if not config_dir.exists(): + config_dir.mkdir(parents=True, exist_ok=True) + + @classmethod + def load_from_file(cls) -> "Config": + """ + Load the configuration from the TOML file in the configuration directory. + + If no configuration file exists, this method will create a new one with default values. + The default configuration includes: + - An empty API configuration + - A default Engine configuration (host: "api.arcade-ai.com", port: None, tls: True) + - No user configuration + + This behavior ensures that the application always has a valid configuration to work with, + but it may not be suitable for all use cases. If a specific configuration is required, + ensure that the configuration file exists before calling this method. + + Returns: + Config: The loaded or newly created configuration. + + Raises: + ValueError: If the existing configuration file is invalid. + """ + cls.ensure_config_dir_exists() + + config_file_path = cls.get_config_file_path() + if not config_file_path.exists(): + # Create a file using the default configuration + default_config = cls.model_construct( + api=ApiConfig.model_construct(), engine=EngineConfig() + ) + default_config.save_to_file() + + config_data = toml.loads(config_file_path.read_text()) + + try: + return cls(**config_data) + except ValidationError as e: + # Get only the errors with {type:missing} and combine them + # into a nicely-formatted string message. + # Any other errors without {type:missing} should just be str()ed + missing_field_errors = [ + ".".join(map(str, error["loc"])) + for error in e.errors() + if error["type"] == "missing" + ] + other_errors = [str(error) for error in e.errors() if error["type"] != "missing"] + + missing_field_errors_str = ", ".join(missing_field_errors) + other_errors_str = "\n".join(other_errors) + + pretty_str: str = "Invalid Arcade configuration." + if missing_field_errors_str: + pretty_str += f"\nMissing fields: {missing_field_errors_str}\n" + if other_errors_str: + pretty_str += f"\nOther errors:\n{other_errors_str}" + + raise ValueError(pretty_str) from e + + def save_to_file(self) -> None: + """ + Save the configuration to the TOML file in the configuration directory. + """ + Config.ensure_config_dir_exists() + config_file_path = Config.get_config_file_path() + config_file_path.write_text(toml.dumps(self.model_dump())) diff --git a/arcade/tests/cli/test_utils.py b/arcade/tests/cli/test_utils.py new file mode 100644 index 00000000..b99fe4ce --- /dev/null +++ b/arcade/tests/cli/test_utils.py @@ -0,0 +1,135 @@ +import pytest + +from arcade.cli.utils import apply_config_overrides +from arcade.core.config_model import ApiConfig, Config, EngineConfig + +DEFAULT_HOST = "api.arcade-ai.com" +DEFAULT_PORT = None +DEFAULT_TLS = True + + +@pytest.mark.parametrize( + "inputs, expected_outputs", + [ + pytest.param( + { + "host_input": None, + "port_input": None, + "tls_input": None, + }, + { + "host": DEFAULT_HOST, + "port": DEFAULT_PORT, + "tls": DEFAULT_TLS, + }, + id="noop", + ), + pytest.param( + { + "host_input": "api2.arcade-ai.com", + "port_input": None, + "tls_input": None, + }, + { + "host": "api2.arcade-ai.com", + "port": DEFAULT_PORT, + "tls": DEFAULT_TLS, + }, + id="set host", + ), + pytest.param( + { + "host_input": None, + "port_input": 6789, + "tls_input": None, + }, + { + "host": DEFAULT_HOST, + "port": 6789, + "tls": DEFAULT_TLS, + }, + id="set port", + ), + pytest.param( + { + "host_input": None, + "port_input": None, + "tls_input": False, + }, + { + "host": DEFAULT_HOST, + "port": DEFAULT_PORT, + "tls": False, + }, + id="set TLS to False", + ), + pytest.param( + { + "host_input": None, + "port_input": None, + "tls_input": True, + }, + { + "host": DEFAULT_HOST, + "port": DEFAULT_PORT, + "tls": True, + }, + id="set TLS to True", + ), + pytest.param( + { + "host_input": "localhost", + "port_input": None, + "tls_input": None, + }, + { + "host": "localhost", + "port": 9099, + "tls": False, + }, + id="localhost and no port or TLS specified", + ), + pytest.param( + { + "host_input": "localhost", + "port_input": 1234, + "tls_input": None, + }, + { + "host": "localhost", + "port": 1234, + "tls": False, + }, + id="localhost and port specified", + ), + pytest.param( + { + "host_input": "localhost", + "port_input": None, + "tls_input": True, + }, + { + "host": "localhost", + "port": 9099, + "tls": True, + }, + id="localhost and TLS specified", + ), + ], +) +def test_apply_config_overrides(inputs: dict, expected_outputs: dict): + # Set fake default values for testing + config = Config( + api=ApiConfig(key="fake_api_key"), + engine=EngineConfig( + host=DEFAULT_HOST, + port=DEFAULT_PORT, + tls=DEFAULT_TLS, + ), + ) + + apply_config_overrides(config, inputs["host_input"], inputs["port_input"], inputs["tls_input"]) + + assert config.engine.host == expected_outputs["host"] + assert config.engine.port == expected_outputs["port"] + assert config.engine.tls == expected_outputs["tls"] diff --git a/arcade/tests/client/test_client.py b/arcade/tests/client/test_client.py index 9853d8a6..4cb8f7c7 100644 --- a/arcade/tests/client/test_client.py +++ b/arcade/tests/client/test_client.py @@ -6,6 +6,7 @@ from httpx import HTTPStatusError, Response from arcade.client import Arcade, AsyncArcade, AuthProvider from arcade.client.errors import ( BadRequestError, + EngineNotHealthyError, InternalServerError, NotFoundError, PermissionDeniedError, @@ -51,6 +52,15 @@ TOOL_AUTHORIZE_RESPONSE_DATA = { "status": "pending", } +HEALTH_CHECK_HEALTHY_RESPONSE_DATA = { + "healthy": True, +} + +HEALTH_CHECK_UNHEALTHY_RESPONSE_DATA = { + "healthy": False, + "reason": "Cannot reticulate splines", +} + @pytest.fixture def mock_response(): @@ -87,7 +97,7 @@ def test_handle_http_error(error_code, expected_error, mock_response): mock_http_error = Mock(spec=HTTPStatusError) mock_http_error.response = mock_response - client = Arcade() # Create an instance of Arcade + client = Arcade(api_key="fake_api_key") # Create an instance of Arcade with pytest.raises(expected_error): client._handle_http_error(mock_http_error) # Call the method on the instance @@ -95,7 +105,7 @@ def test_handle_http_error(error_code, expected_error, mock_response): def test_arcade_auth_authorize(mock_response, monkeypatch): """Test Arcade.auth.authorize method.""" monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: AUTH_RESPONSE_DATA) - client = Arcade() + client = Arcade(api_key="fake_api_key") auth_response = client.auth.authorize( provider=AuthProvider.google, scopes=["https://www.googleapis.com/auth/gmail.readonly"], @@ -107,7 +117,7 @@ def test_arcade_auth_authorize(mock_response, monkeypatch): def test_arcade_auth_poll_authorization(mock_response, monkeypatch): """Test Arcade.auth.poll_authorization method.""" monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: AUTH_RESPONSE_DATA) - client = Arcade() + client = Arcade(api_key="fake_api_key") auth_response = client.auth.status("auth_123") assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA) @@ -115,7 +125,7 @@ def test_arcade_auth_poll_authorization(mock_response, monkeypatch): def test_arcade_tool_run(mock_response, monkeypatch): """Test Arcade.tool.run method.""" monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: TOOL_RESPONSE_DATA) - client = Arcade() + client = Arcade(api_key="fake_api_key") tool_response = client.tool.run( tool_name="GetEmails", user_id="sam@arcade-ai.com", @@ -128,7 +138,7 @@ def test_arcade_tool_run(mock_response, monkeypatch): def test_arcade_tool_get(mock_response, monkeypatch): """Test Arcade.tool.get method.""" monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: TOOL_DEFINITION_DATA) - client = Arcade() + client = Arcade(api_key="fake_api_key") tool_definition = client.tool.get(director_id="default", tool_id="GetEmails") assert tool_definition == ToolDefinition(**TOOL_DEFINITION_DATA) @@ -138,11 +148,31 @@ def test_arcade_tool_authorize(mock_response, monkeypatch): monkeypatch.setattr( Arcade, "_execute_request", lambda *args, **kwargs: TOOL_AUTHORIZE_RESPONSE_DATA ) - client = Arcade() + client = Arcade(api_key="fake_api_key") auth_response = client.tool.authorize(tool_name="GetEmails", user_id="sam@arcade-ai.com") assert auth_response == AuthResponse(**TOOL_AUTHORIZE_RESPONSE_DATA) +def test_arcade_health_check(mock_response, monkeypatch): + """Test Arcade.health.check method.""" + monkeypatch.setattr( + Arcade, "_execute_request", lambda *args, **kwargs: HEALTH_CHECK_HEALTHY_RESPONSE_DATA + ) + client = Arcade(api_key="fake_api_key") + client.health.check() + assert True # If no exception is raised, the test passes + + +def test_arcade_health_check_raises_error(mock_response, monkeypatch): + """Test Arcade.health.check method.""" + monkeypatch.setattr( + Arcade, "_execute_request", lambda *args, **kwargs: HEALTH_CHECK_UNHEALTHY_RESPONSE_DATA + ) + client = Arcade(api_key="fake_api_key") + with pytest.raises(EngineNotHealthyError): + client.health.check() + + @pytest.mark.asyncio async def test_async_arcade_auth_authorize(mock_async_response, monkeypatch): """Test AsyncArcade.auth.authorize method.""" @@ -151,7 +181,7 @@ async def test_async_arcade_auth_authorize(mock_async_response, monkeypatch): return AUTH_RESPONSE_DATA monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request) - client = AsyncArcade() + client = AsyncArcade(api_key="fake_api_key") auth_response = await client.auth.authorize( provider=AuthProvider.google, scopes=["https://www.googleapis.com/auth/gmail.readonly"], @@ -168,7 +198,7 @@ async def test_async_arcade_auth_poll_authorization(mock_async_response, monkeyp return AUTH_RESPONSE_DATA monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request) - client = AsyncArcade() + client = AsyncArcade(api_key="fake_api_key") auth_response = await client.auth.status("auth_123") assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA) @@ -181,7 +211,7 @@ async def test_async_arcade_tool_run(mock_async_response, monkeypatch): return TOOL_RESPONSE_DATA monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request) - client = AsyncArcade() + client = AsyncArcade(api_key="fake_api_key") tool_response = await client.tool.run( tool_name="GetEmails", user_id="sam@arcade-ai.com", @@ -199,7 +229,7 @@ async def test_async_arcade_tool_get(mock_async_response, monkeypatch): return TOOL_DEFINITION_DATA monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request) - client = AsyncArcade() + client = AsyncArcade(api_key="fake_api_key") tool_definition = await client.tool.get(director_id="default", tool_id="GetEmails") assert tool_definition == ToolDefinition(**TOOL_DEFINITION_DATA) @@ -212,6 +242,32 @@ async def test_async_arcade_tool_authorize(mock_async_response, monkeypatch): return TOOL_AUTHORIZE_RESPONSE_DATA monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request) - client = AsyncArcade() + client = AsyncArcade(api_key="fake_api_key") auth_response = await client.tool.authorize(tool_name="GetEmails", user_id="sam@arcade-ai.com") assert auth_response == AuthResponse(**TOOL_AUTHORIZE_RESPONSE_DATA) + + +@pytest.mark.asyncio +async def test_async_arcade_health_check(mock_async_response, monkeypatch): + """Test AsyncArcade.health.check method.""" + + async def mock_execute_request(*args, **kwargs): + return HEALTH_CHECK_HEALTHY_RESPONSE_DATA + + monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request) + client = AsyncArcade(api_key="fake_api_key") + await client.health.check() + assert True # If no exception is raised, the test passes + + +@pytest.mark.asyncio +async def test_async_arcade_health_check_raises_error(mock_async_response, monkeypatch): + """Test AsyncArcade.health.check method.""" + + async def mock_execute_request(*args, **kwargs): + return HEALTH_CHECK_UNHEALTHY_RESPONSE_DATA + + monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request) + client = AsyncArcade(api_key="fake_api_key") + with pytest.raises(EngineNotHealthyError): + await client.health.check()