Config Refactor (#116)

# PR Description
1. Removes `arcade config` CLI command and it's helper function.
2. Upon `arcade login`, if the user does not have an `arcade.env` file,
then a templated environment file is created for the user.
3. Removed `EngineConfig` and all references to it. Since there is no
longer an `EngineConfig`, this PR refactors the CLI to compute the
engine URL based on the command-line flags that were provided.
4. Renamed `arcade.toml` to `credentials.yaml`. If a user is using
`arcade.toml`, then we will display a deprecation message and then
automatically migrate their `arcade.toml` to `credentials.yaml`. NOTE:
Eventually this auto-migration support should be removed.
5. `arcade.env` is now an optional file
6. Make `arcade show` default to `https://api.arcade-ai.com/v1` instead
of localhost.
-------




## Ensuring engine url is still computed correctly:
I used the following matrix to ensure that the behavior has not changed
after the refactor. This matrix is tested in `test_utils.py`

DEFAULT_HOST = "api.arcade-ai.com"  
DEFAULT_PORT = None  
DEFAULT_FORCE_TLS = False  
DEFAULT_FORCE_NO_TLS = False  


| Command Line Arguments | Host | Port | Force TLS | Force No TLS |
Main's URL | This PR's URL |

|----------------------------------------|-----------------|---------------|-----------|--------------|-----------------------------------|-----------------------------------|
| | DEFAULT_HOST | DEFAULT_PORT | False | False |
https://api.arcade-ai.com/v1 | https://api.arcade-ai.com/v1 |
| --host localhost | localhost | DEFAULT_PORT | False | False |
http://localhost:9099/v1 | http://localhost:9099/v1 |
| -p 9099 | DEFAULT_HOST | 9099 | False | False |
https://api.arcade-ai.com:9099/v1 | https://api.arcade-ai.com:9099/v1 |
| --host localhost -p 9099 | localhost | 9099 | False | False |
http://localhost:9099/v1 | http://localhost:9099/v1 |
| --tls | DEFAULT_HOST | DEFAULT_PORT | True | False |
https://api.arcade-ai.com/v1 | https://api.arcade-ai.com/v1 |
| --host localhost --tls | localhost | DEFAULT_PORT | True | False |
https://localhost:9099/v1 | https://localhost:9099/v1 |
| -p 9099 --tls | DEFAULT_HOST | 9099 | True | False |
https://api.arcade-ai.com:9099/v1 | https://api.arcade-ai.com:9099/v1 |
| --host localhost -p 9099 --tls | localhost | 9099 | True | False |
https://localhost:9099/v1 | https://localhost:9099/v1 |
| --no-tls | DEFAULT_HOST | DEFAULT_PORT | False | True |
http://api.arcade-ai.com/v1 | http://api.arcade-ai.com/v1 |
| --host localhost --no-tls | localhost | DEFAULT_PORT | False | True |
http://localhost:9099/v1 | http://localhost:9099/v1 |
| -p 9099 --no-tls | DEFAULT_HOST | 9099 | False | True |
http://api.arcade-ai.com:9099/v1 | http://api.arcade-ai.com:9099/v1 |
| --host localhost -p 9099 --no-tls | localhost | 9099 | False | True |
http://localhost:9099/v1 | http://localhost:9099/v1 |
| --tls --no-tls | DEFAULT_HOST | DEFAULT_PORT | True | True |
http://api.arcade-ai.com/v1 | http://api.arcade-ai.com/v1 |
| --host localhost --tls --no-tls | localhost | DEFAULT_PORT | True |
True | http://localhost:9099/v1 | http://localhost:9099/v1 |
| -p 9099 --tls --no-tls | DEFAULT_HOST | 9099 | True | True |
http://api.arcade-ai.com:9099/v1 | http://api.arcade-ai.com:9099/v1 |
| --host localhost -p 9099 --tls --no-tls| localhost | 9099 | True |
True | http://localhost:9099/v1 | http://localhost:9099/v1 |
| --host arandomhost.com | arandomhost.com | DEFAULT_PORT | False |
False | https://arandomhost.com/v1 | https://arandomhost.com/v1 |
This commit is contained in:
Eric Gustin 2024-10-24 11:34:33 -07:00 committed by GitHub
parent 10030c6a12
commit 8508a28f54
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 451 additions and 380 deletions

6
.gitignore vendored
View file

@ -1,6 +1,8 @@
.DS_Store
arcade.toml
docker/arcade.toml
arcade.toml # Deprecated in favor of credentials.yaml
credentials.yaml
docker/arcade.toml # Deprecated in favor of credentials.yaml
docker/credentials.yaml
*.lock

View file

@ -4,10 +4,11 @@ from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any
from urllib.parse import parse_qs
import toml
import yaml
from rich.console import Console
from arcade.cli.constants import LOGIN_FAILED_HTML, LOGIN_SUCCESS_HTML
from arcade.cli.utils import create_new_env_file, is_config_file_deprecated
console = Console()
@ -61,10 +62,10 @@ class LoginCallbackHandler(BaseHTTPRequestHandler):
os.makedirs(os.path.expanduser("~/.arcade"), exist_ok=True)
# TODO don't overwrite existing config
config_file_path = os.path.expanduser("~/.arcade/arcade.toml")
new_config = {"api": {"key": api_key}, "user": {"email": email}}
config_file_path = os.path.expanduser("~/.arcade/credentials.yaml")
new_config = {"cloud": {"api": {"key": api_key}, "user": {"email": email}}}
with open(config_file_path, "w") as f:
toml.dump(new_config, f)
yaml.dump(new_config, f)
# Send a success response to the browser
console.print(
@ -115,26 +116,37 @@ def check_existing_login() -> bool:
Check if the user is already logged in by verifying the config file.
Returns:
bool: True if the user is already logged in, False otherwise.
bool: True if the user is already logged in or is using the deprecated config file, False otherwise.
"""
config_file_path = os.path.expanduser("~/.arcade/arcade.toml")
if is_config_file_deprecated():
return True
# Create a new env file if one doesn't already exist
create_new_env_file()
config_file_path = os.path.expanduser("~/.arcade/credentials.yaml")
if not os.path.exists(config_file_path):
return False
try:
config: dict[str, Any] = toml.load(config_file_path)
api_key = config.get("api", {}).get("key")
email = config.get("user", {}).get("email")
if os.path.exists(config_file_path):
try:
with open(config_file_path) as f:
config: dict[str, Any] = yaml.safe_load(f)
api_key = config.get("api", {}).get("key")
email = config.get("user", {}).get("email")
if api_key and email:
if api_key and email:
console.print(
f"You're already logged in as {email}. "
f"Delete {config_file_path} to log in as a different user."
)
return True
except yaml.YAMLError:
console.print(
f"You're already logged in as {email}. "
f"Delete {config_file_path} to log in as a different user."
f"Error: Invalid configuration file at {config_file_path}", style="bold red"
)
return True
except toml.TomlDecodeError:
console.print(f"Error: Invalid configuration file at {config_file_path}", style="bold red")
except Exception as e:
console.print(f"Error: Unable to read configuration file: {e!s}", style="bold red")
except Exception as e:
console.print(f"Error: Unable to read configuration file: {e!s}", style="bold red")
return False
return True

View file

@ -1,3 +1,6 @@
DEFAULT_CLOUD_HOST = "cloud.arcade-ai.com"
DEFAULT_ENGINE_HOST = "api.arcade-ai.com"
_style_block = b"""
<link rel="icon" href="https://cdn.arcade-ai.com/favicons/favicon.ico" sizes="any">
<link rel="apple-touch-icon" href="https://cdn.arcade-ai.com/favicons/apple-touch-icon.png">

View file

@ -5,7 +5,6 @@ from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from arcade.core.config_model import Config
from arcade.core.schema import ToolDefinition
if TYPE_CHECKING:
@ -220,7 +219,7 @@ def _format_evaluation(evaluation: "EvaluationResult") -> str:
return "\n".join(result_lines)
def display_arcade_chat_header(config: Config, stream: bool) -> None:
def display_arcade_chat_header(base_url: str, stream: bool) -> None:
chat_header = Text.assemble(
"\n",
(
@ -231,35 +230,10 @@ def display_arcade_chat_header(config: Config, stream: bool) -> None:
"\n",
"Chatting with Arcade Engine at ",
(
config.engine_url,
base_url,
"bold blue",
),
)
if stream:
chat_header.append(" (streaming)")
console.print(chat_header)
def display_config_as_table(config) -> None: # type: ignore[no-untyped-def]
"""
Display the configuration details as a table using Rich library.
"""
table = Table(show_header=True, header_style="bold magenta")
table.add_column("Section")
table.add_column("Name")
table.add_column("Value")
for section_name in config.model_dump():
section = getattr(config, section_name)
if section:
section = section.dict()
first = True
for name, value in section.items():
if first:
table.add_row(section_name, name, str(value))
first = False
else:
table.add_row("", name, str(value))
table.add_row("", "", "")
console.print(table)

View file

@ -52,7 +52,7 @@ def start_servers(
engine_config = _get_config_file(engine_config, default_filename="engine.yaml")
# Ensure engine_env is provided or found and either way, validated
env_file = _get_config_file(engine_env, default_filename="arcade.env")
env_file = _get_config_file(engine_env, default_filename="arcade.env", optional=True)
# Prepare command-line arguments for the actor server and engine
actor_cmd = _build_actor_command(host, port, debug)
@ -107,19 +107,22 @@ def _validate_port(port: int) -> int:
return port
def _get_config_file(file_path: str | None, default_filename: str = "engine.yaml") -> str:
def _get_config_file(
file_path: str | None, default_filename: str = "engine.yaml", optional: bool = False
) -> str | None:
"""
Determines and validates the config file path.
Args:
file_path: Optional path provided by the user.
default_filename: The default filename to look for.
optional: Whether the config file is optional.
Returns:
The resolved config file path.
The resolved config file path. None if the file is optional and not found.
Raises:
RuntimeError: If the config file is not found.
RuntimeError: If the config file is not found and is not optional.
"""
if file_path:
config_path = Path(os.path.expanduser(file_path)).resolve()
@ -147,8 +150,19 @@ def _get_config_file(file_path: str | None, default_filename: str = "engine.yaml
console.print(f"Using config file at {etc_path}", style="bold green")
return str(etc_path)
if optional:
console.print(
f"⚠️ Optional config file '{default_filename}' not found in either of the default locations: "
f"1) current working directory: {Path.cwd() / default_filename}, or "
f"2) user's home directory: {Path.home() / '.arcade' / default_filename}.",
style="bold yellow",
)
return None
console.print(
f"❌ Config file '{default_filename}' not found in any of the default locations.",
f"❌ Config file '{default_filename}' not found in any of the default locations: "
f"1) current working directory: {Path.cwd() / default_filename}, or "
f"2) user's home directory: {Path.home() / '.arcade' / default_filename}.",
style="bold red",
)
raise RuntimeError(f"Config file '{default_filename}' not found.")
@ -187,7 +201,7 @@ def _build_actor_command(host: str, port: int, debug: bool) -> list[str]:
return cmd
def _build_engine_command(engine_config: str, engine_env: str | None = None) -> list[str]:
def _build_engine_command(engine_config: str | None, engine_env: str | None = None) -> list[str]:
"""
Builds the command to start the engine.
@ -198,6 +212,11 @@ def _build_engine_command(engine_config: str, engine_env: str | None = None) ->
Returns:
The command as a list.
"""
# This should never happen, but we'll check regardless
if not engine_config:
console.print("❌ Engine configuration file not found", style="bold red")
sys.exit(1)
engine_bin = shutil.which("arcade-engine")
if not engine_bin:
console.print(

View file

@ -16,9 +16,9 @@ from rich.markup import escape
from rich.text import Text
from arcade.cli.authn import LocalAuthCallbackServer, check_existing_login
from arcade.cli.constants import DEFAULT_CLOUD_HOST, DEFAULT_ENGINE_HOST
from arcade.cli.display import (
display_arcade_chat_header,
display_config_as_table,
display_eval_results,
display_tool_details,
display_tool_messages,
@ -27,8 +27,9 @@ from arcade.cli.display import (
from arcade.cli.launcher import start_servers
from arcade.cli.utils import (
OrderCommands,
compute_base_url,
create_cli_catalog,
get_config_with_overrides,
delete_deprecated_config_file,
get_eval_files,
get_tools_from_engine,
handle_chat_interaction,
@ -53,7 +54,7 @@ console = Console()
@cli.command(help="Log in to Arcade Cloud", rich_help_panel="User")
def login(
host: str = typer.Option(
"cloud.arcade-ai.com",
DEFAULT_CLOUD_HOST,
"-h",
"--host",
help="The Arcade Cloud host to log in to.",
@ -99,9 +100,10 @@ def logout() -> None:
"""
Logs the user out of Arcade Cloud.
"""
delete_deprecated_config_file()
# If ~/.arcade/arcade.toml exists, delete it
config_file_path = os.path.expanduser("~/.arcade/arcade.toml")
# If ~/.arcade/credentials.yaml exists, delete it
config_file_path = os.path.expanduser("~/.arcade/credentials.yaml")
if os.path.exists(config_file_path):
os.remove(config_file_path)
console.print("You're now logged out.", style="bold")
@ -136,7 +138,7 @@ def show(
tool: Optional[str] = typer.Option(
None, "-t", "--tool", help="The specific tool to show details for"
),
host: Optional[str] = typer.Option(
host: str = typer.Option(
None,
"-h",
"--host",
@ -207,7 +209,7 @@ def chat(
prompt: str = typer.Option(None, "--prompt", help="The system prompt to use for the chat."),
debug: bool = typer.Option(False, "--debug", "-d", help="Show debug information"),
host: str = typer.Option(
None,
DEFAULT_ENGINE_HOST,
"-h",
"--host",
help="The Arcade Engine address to send chat requests to.",
@ -232,9 +234,10 @@ def chat(
"""
Chat with a language model.
"""
config = get_config_with_overrides(force_tls, force_no_tls, host, port)
config = validate_and_get_config()
base_url = compute_base_url(force_tls, force_no_tls, host, port)
client = Arcade(api_key=config.api.key, base_url=config.engine_url)
client = Arcade(api_key=config.api.key, base_url=base_url)
user_email = config.user.email if config.user else None
try:
@ -244,7 +247,7 @@ def chat(
if prompt:
history.append({"role": "system", "content": prompt})
display_arcade_chat_header(config, stream)
display_arcade_chat_header(base_url, stream)
# Try to hit /health endpoint on engine and warn if it is down
log_engine_health(client)
@ -262,7 +265,7 @@ def chat(
try:
# TODO fixup configuration to remove this + "/v1" workaround
openai_client = OpenAI(api_key=config.api.key, base_url=config.engine_url + "/v1")
openai_client = OpenAI(api_key=config.api.key, base_url=base_url + "/v1")
chat_result = handle_chat_interaction(
openai_client, model, history, user_email, stream
)
@ -301,48 +304,6 @@ def chat(
raise typer.Exit()
@cli.command(help="Show/edit the local Arcade configuration", rich_help_panel="User")
def config(
action: str = typer.Argument("show", help="The action to take (show/edit)"),
key: str = typer.Option(
None, "--key", "-k", help="The configuration key to edit (e.g., 'api.key')"
),
val: str = typer.Option(None, "--val", "-v", help="The value of the configuration to edit"),
) -> None:
"""
Show/edit configuration details of the Arcade Engine
"""
config = validate_and_get_config()
if action == "show":
display_config_as_table(config)
elif action == "edit":
if not key or val is None:
console.print("❌ Key and value must be provided for editing.", style="bold red")
raise typer.Exit(code=1)
keys = key.split(".")
if len(keys) != 2:
console.print("❌ Invalid key format. Use 'section.name' format.", style="bold red")
raise typer.Exit(code=1)
section, name = keys
section_dict = getattr(config, section, None)
if section_dict and hasattr(section_dict, name):
setattr(section_dict, name, val)
config.save_to_file()
console.print("✅ Configuration updated successfully.", style="bold green")
else:
console.print(
f"❌ Invalid configuration name: {name} in section: {section}",
style="bold red",
)
raise typer.Exit(code=1)
else:
console.print(f"❌ Invalid action: {action}", style="bold red")
raise typer.Exit(code=1)
@cli.command(help="Run tool calling evaluations", rich_help_panel="Tool Development")
def evals(
directory: str = typer.Argument(".", help="Directory containing evaluation files"),
@ -360,7 +321,7 @@ def evals(
help="The models to use for evaluation (default: gpt-4o)",
),
host: str = typer.Option(
None,
DEFAULT_ENGINE_HOST,
"-h",
"--host",
help="The Arcade Engine address to send chat requests to.",
@ -386,7 +347,8 @@ def evals(
Find all files starting with 'eval_' in the given directory,
execute any functions decorated with @tool_eval, and display the results.
"""
config = get_config_with_overrides(force_tls, force_no_tls, host, port)
config = validate_and_get_config()
base_url = compute_base_url(force_tls, force_no_tls, host, port)
models_list = models.split(",") # Use 'models_list' to avoid shadowing
@ -398,12 +360,12 @@ def evals(
console.print(
Text.assemble(
("\nRunning evaluations against Arcade Engine at ", "bold"),
(config.engine_url, "bold blue"),
(base_url, "bold blue"),
)
)
# Try to hit /health endpoint on engine and warn if it is down
with Arcade(api_key=config.api.key, base_url=config.engine_url) as client:
with Arcade(api_key=config.api.key, base_url=base_url) as client:
log_engine_health(client)
# Use the new function to load eval suites
@ -432,7 +394,12 @@ def evals(
)
for model in models_list:
task = asyncio.create_task(
suite_func(config=config, model=model, max_concurrency=max_concurrent)
suite_func(
config=config,
base_url=base_url,
model=model,
max_concurrency=max_concurrent,
)
)
tasks.append(task)

View file

@ -1,9 +1,13 @@
import importlib.util
import ipaddress
import os
import webbrowser
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Union, cast
from urllib.parse import urlparse
import idna
import typer
from arcadepy import NOT_GIVEN, APIConnectionError, APIStatusError, APITimeoutError, Arcade
from arcadepy.types import AuthorizationResponse
@ -66,25 +70,85 @@ def create_cli_catalog(
return catalog
def get_config_with_overrides(
def compute_base_url(
force_tls: bool,
force_no_tls: bool,
host_input: str | None = None,
port_input: int | None = None,
) -> Config:
host: str,
port: int | None,
) -> str:
"""
Get the config with CLI-specific optional overrides applied.
"""
config = validate_and_get_config()
Compute the base URL for the Arcade Engine from the provided overrides.
if not force_tls and not force_no_tls:
tls_input = None
elif force_no_tls:
tls_input = False
force_no_tls takes precedence over force_tls. For example, if both are set to True,
the resulting URL will use http.
The port is included in the URL unless the host is a fully qualified domain name
(excluding IP addresses) and no port is specified. Handles IPv4, IPv6, IDNs, and
hostnames with underscores.
This property exists to provide a consistent and correctly formatted URL for
connecting to the Arcade Engine, taking into account various configuration
options and edge cases. It ensures that:
1. The correct protocol (http/https) is used based on the TLS setting.
2. IPv4 and IPv6 addresses are properly formatted.
3. Internationalized Domain Names (IDNs) are correctly encoded.
4. Fully Qualified Domain Names (FQDNs) are identified and handled appropriately.
5. Ports are included when necessary, respecting common conventions for FQDNs.
6. Hostnames with underscores (common in development environments) are supported.
7. Pre-existing port specifications in the host are respected.
The resulting URL is always suffixed with api_version to specify the API version.
Returns:
str: The fully constructed URL for the Arcade Engine.
"""
# Determine TLS setting based on input flags
if force_no_tls:
is_tls = False
elif force_tls:
is_tls = True
else:
tls_input = True
apply_config_overrides(config, host_input, port_input, tls_input)
return config
is_tls = host != "localhost"
# "localhost" defaults to dev port if not specified
if host == "localhost" and port is None:
port = 9099
protocol = "https" if is_tls else "http"
# Handle potential IDNs
try:
encoded_host = idna.encode(host).decode("ascii")
except idna.IDNAError:
encoded_host = host
# Check if the host is a valid IP address (IPv4 or IPv6)
try:
ipaddress.ip_address(encoded_host)
is_ip = True
except ValueError:
is_ip = False
# Parse the host, handling potential IPv6 addresses
host_for_parsing = f"[{encoded_host}]" if is_ip and ":" in encoded_host else encoded_host
parsed_host = urlparse(f"//{host_for_parsing}")
# Check if the host is a fully qualified domain name (excluding IP addresses)
is_fqdn = "." in parsed_host.netloc and not is_ip and "_" not in parsed_host.netloc
# Handle hosts that might already include a port
if ":" in parsed_host.netloc and not is_ip:
host, existing_port = parsed_host.netloc.rsplit(":", 1)
if existing_port.isdigit():
return f"{protocol}://{parsed_host.netloc}"
if is_fqdn and port is None:
return f"{protocol}://{encoded_host}"
elif port is not None:
return f"{protocol}://{encoded_host}:{port}"
else:
return f"{protocol}://{encoded_host}"
def get_tools_from_engine(
@ -94,8 +158,9 @@ def get_tools_from_engine(
force_no_tls: bool = False,
toolkit: str | None = None,
) -> list[ToolDefinition]:
config = get_config_with_overrides(force_tls, force_no_tls, host, port)
client = Arcade(api_key=config.api.key, base_url=config.engine_url)
config = validate_and_get_config()
base_url = compute_base_url(force_tls, force_no_tls, host, port)
client = Arcade(api_key=config.api.key, base_url=base_url)
tools = []
page_iterator = client.tools.list(toolkit=toolkit or NOT_GIVEN)
@ -177,7 +242,6 @@ def markdownify_urls(message: str) -> str:
def validate_and_get_config(
validate_engine: bool = True,
validate_api: bool = True,
validate_user: bool = True,
) -> Config:
@ -186,10 +250,6 @@ def validate_and_get_config(
"""
from arcade.core.config import config
if validate_engine and (not config.engine or not config.engine_url):
console.print("❌ Engine configuration not found or URL is missing.", style="bold red")
raise typer.Exit(code=1)
if validate_api and (not config.api or not config.api.key):
console.print(
"❌ API configuration not found or key is missing. Please run `arcade login`.",
@ -206,35 +266,6 @@ def validate_and_get_config(
return config
def apply_config_overrides(
config: Config, host_input: str | None, port_input: int | None, tls_input: bool | None
) -> None:
"""
Apply optional config overrides (passed by the user) to the config object.
"""
if not config.engine:
# Should not happen, validate_and_get_config ensures that `engine` is set
raise ValueError("Engine configuration not found in config.")
# Special case for "localhost" and nothing else specified:
# default to dev port and no TLS for convenience
if host_input == "localhost":
if port_input is None:
port_input = 9099
if tls_input is None:
tls_input = False
if host_input:
config.engine.host = host_input
if port_input is not None:
config.engine.port = port_input
if tls_input is not None:
config.engine.tls = tls_input
def log_engine_health(client: Arcade) -> None:
try:
result = client.health.check(timeout=2)
@ -488,3 +519,48 @@ def load_eval_suites(eval_files: list[Path]) -> list[Callable]:
eval_suites.extend(eval_suite_funcs)
return eval_suites
def create_new_env_file() -> None:
"""
Create a new env file if one doesn't already exist.
"""
env_file = os.path.expanduser("~/.arcade/arcade.env")
if not os.path.exists(env_file):
template_path = os.path.join(
os.path.dirname(__file__), "..", "templates", "arcade.template.env"
)
os.makedirs(os.path.dirname(env_file), exist_ok=True)
with open(template_path) as template_file, open(env_file, "w") as new_env_file:
template_contents = template_file.read()
new_env_file.write(template_contents)
console.print(f"Created new environment file at {env_file}", style="bold green")
def is_config_file_deprecated() -> bool:
"""
Check if the user is using the deprecated config file.
Returns:
bool: True if the user is using the deprecated config file, False otherwise.
"""
deprecated_config_file_path = os.path.expanduser("~/.arcade/arcade.toml")
if os.path.exists(deprecated_config_file_path):
console.print(
f"Deprecation Notice: You are using a deprecated config file at {deprecated_config_file_path}. Please migrate to the new format by running,\n\n\t$ arcade logout && arcade login\n",
style="bold yellow",
)
return True
return False
def delete_deprecated_config_file() -> None:
"""
Delete the deprecated config file if it exists.
"""
deprecated_config_file_path = os.path.expanduser("~/.arcade/arcade.toml")
if os.path.exists(deprecated_config_file_path):
os.remove(deprecated_config_file_path)

View file

@ -1,11 +1,9 @@
import ipaddress
import os
from pathlib import Path
from typing import Any
from urllib.parse import urlparse
import idna
import toml
import yaml
from pydantic import BaseModel, ConfigDict, ValidationError
@ -39,25 +37,6 @@ class UserConfig(BaseConfig):
"""
class EngineConfig(BaseConfig):
"""
Arcade Engine configuration.
"""
host: str = "api.arcade-ai.com"
"""
Arcade Engine host.
"""
port: int | None = None
"""
Arcade Engine port.
"""
tls: bool = True
"""
Whether to use TLS for the connection to Arcade Engine.
"""
class Config(BaseConfig):
"""
Configuration for Arcade.
@ -71,15 +50,9 @@ class Config(BaseConfig):
"""
Arcade user configuration.
"""
engine: EngineConfig | None = EngineConfig()
"""
Arcade Engine configuration.
"""
def __init__(self, **data: Any):
super().__init__(**data)
self._engine_url_cache: str | None = None
self._engine_url_cache_key: str | None = None
@classmethod
def get_config_dir_path(cls) -> Path:
@ -89,98 +62,19 @@ class Config(BaseConfig):
config_path = os.getenv("ARCADE_WORK_DIR") or Path.home() / ".arcade"
return Path(config_path).resolve()
@classmethod
def get_deprecated_config_file_path(cls) -> Path:
"""
Get the path to the deprecated Arcade configuration file.
"""
return cls.get_config_dir_path() / "arcade.toml"
@classmethod
def get_config_file_path(cls) -> Path:
"""
Get the path to the Arcade configuration file.
"""
return cls.get_config_dir_path() / "arcade.toml"
def _generate_engine_url_cache_key(self) -> str:
"""
Generate a cache key for the engine_url property, based on its underlying data.
"""
if self.engine is None:
return ""
return f"{self.engine.host}:{self.engine.port}:{self.engine.tls}"
@property
def engine_url(self) -> str:
"""
Get the cached URL of the Arcade Engine.
This property is cached after its first access to improve performance.
The cache is automatically invalidated if any of the underlying data changes.
The port is included in the URL unless the host is a fully qualified domain name
(excluding IP addresses) and no port is specified. Handles IPv4, IPv6, IDNs, and
hostnames with underscores.
This property exists to provide a consistent and correctly formatted URL for
connecting to the Arcade Engine, taking into account various configuration
options and edge cases. It ensures that:
1. The correct protocol (http/https) is used based on the TLS setting.
2. IPv4 and IPv6 addresses are properly formatted.
3. Internationalized Domain Names (IDNs) are correctly encoded.
4. Fully Qualified Domain Names (FQDNs) are identified and handled appropriately.
5. Ports are included when necessary, respecting common conventions for FQDNs.
6. Hostnames with underscores (common in development environments) are supported.
7. Pre-existing port specifications in the host are respected.
Returns:
str: The fully constructed URL for the Arcade Engine.
Raises:
ValueError: If the engine configuration is missing or incomplete.
"""
current_cache_key = self._generate_engine_url_cache_key()
if self._engine_url_cache is None or self._engine_url_cache_key != current_cache_key:
self._engine_url_cache = self._compute_engine_url()
self._engine_url_cache_key = current_cache_key
return self._engine_url_cache
def _compute_engine_url(self) -> str:
if self.engine is None:
raise ValueError("Configuration for Engine is not set in arcade.toml")
if not self.engine.host:
raise ValueError("Configuration for Engine host is not set in arcade.toml")
protocol = "https" if self.engine.tls else "http"
# Handle potential IDNs
try:
encoded_host = idna.encode(self.engine.host).decode("ascii")
except idna.IDNAError:
encoded_host = self.engine.host
# Check if the host is a valid IP address (IPv4 or IPv6)
try:
ipaddress.ip_address(encoded_host)
is_ip = True
except ValueError:
is_ip = False
# Parse the host, handling potential IPv6 addresses
host_for_parsing = f"[{encoded_host}]" if is_ip and ":" in encoded_host else encoded_host
parsed_host = urlparse(f"//{host_for_parsing}")
# Check if the host is a fully qualified domain name (excluding IP addresses)
is_fqdn = "." in parsed_host.netloc and not is_ip and "_" not in parsed_host.netloc
# Handle hosts that might already include a port
if ":" in parsed_host.netloc and not is_ip:
host, existing_port = parsed_host.netloc.rsplit(":", 1)
if existing_port.isdigit():
return f"{protocol}://{parsed_host.netloc}"
if is_fqdn and self.engine.port is None:
return f"{protocol}://{encoded_host}"
elif self.engine.port is not None:
return f"{protocol}://{encoded_host}:{self.engine.port}"
else:
return f"{protocol}://{encoded_host}"
return cls.get_config_dir_path() / "credentials.yaml"
@classmethod
def ensure_config_dir_exists(cls) -> None:
@ -194,7 +88,7 @@ class Config(BaseConfig):
@classmethod
def load_from_file(cls) -> "Config":
"""
Load the configuration from the TOML file in the configuration directory.
Load the configuration from the YAML file in the configuration directory.
If no configuration file exists, this method will create a new one with default values.
The default configuration includes:
@ -202,30 +96,30 @@ class Config(BaseConfig):
- A default Engine configuration (host: "api.arcade-ai.com", port: None, tls: True)
- No user configuration
This behavior ensures that the application always has a valid configuration to work with,
but it may not be suitable for all use cases. If a specific configuration is required,
ensure that the configuration file exists before calling this method.
If a deprecated TOML configuration file is found, it will be automatically converted
to the new YAML format. This ensures that the application always has a valid configuration
to work with, but it may not be suitable for all use cases. If a specific configuration
is required, ensure that the configuration file exists before calling this method.
Returns:
Config: The loaded or newly created configuration.
Raises:
ValueError: If the existing configuration file is invalid.
ValueError: If the existing configuration file is invalid or cannot be converted.
"""
cls.ensure_config_dir_exists()
config_file_path = cls.get_config_file_path()
if not config_file_path.exists():
if not config_file_path.exists() and not cls._migrate_deprecated_config_file():
# Create a file using the default configuration
default_config = cls.model_construct(
api=ApiConfig.model_construct(), engine=EngineConfig()
)
default_config = cls.model_construct(api=ApiConfig.model_construct())
default_config.save_to_file()
config_data = toml.loads(config_file_path.read_text())
config_data = yaml.safe_load(config_file_path.read_text())
try:
return cls(**config_data)
return cls(**config_data["cloud"])
except ValidationError as e:
# Get only the errors with {type:missing} and combine them
# into a nicely-formatted string message.
@ -250,8 +144,36 @@ class Config(BaseConfig):
def save_to_file(self) -> None:
"""
Save the configuration to the TOML file in the configuration directory.
Save the configuration to the YAML file in the configuration directory.
"""
Config.ensure_config_dir_exists()
config_file_path = Config.get_config_file_path()
config_file_path.write_text(toml.dumps(self.model_dump()))
config_file_path.write_text(yaml.dump(self.model_dump()))
@classmethod
def _migrate_deprecated_config_file(cls) -> bool:
"""
Migrate the deprecated config file to the new format if the deprecated config file exists.
Returns:
bool: True if the migration occurred, False otherwise.
"""
deprecated_config_file_path = Config.get_deprecated_config_file_path()
if deprecated_config_file_path.exists():
# If the user is using the deprecated config file, then convert it to the new yaml format
try:
old_config: dict[str, Any] = toml.load(deprecated_config_file_path)
old_config = {"cloud": old_config}
with open(cls.get_config_file_path(), "w") as f:
yaml.dump(old_config, f)
os.remove(deprecated_config_file_path)
print(
f"\033[1;33mAutomatically migrated the deprecated config file {deprecated_config_file_path} to {cls.get_config_file_path()}\033[0m"
)
except Exception as e:
raise OSError(
f"Invalid configuration file at {deprecated_config_file_path} could not be automatically converted to the new format. Please manually migrate to {cls.get_config_file_path()} by running `arcade logout && arcade login`."
) from e
return True
return False

View file

@ -644,6 +644,7 @@ def tool_eval() -> Callable[[Callable], Callable]:
@functools.wraps(func)
async def wrapper(
config: Config,
base_url: str,
model: str,
max_concurrency: int = 1,
) -> list[dict[str, Any]]:
@ -654,7 +655,7 @@ def tool_eval() -> Callable[[Callable], Callable]:
results = []
async with AsyncOpenAI(
api_key=config.api.key,
base_url=config.engine_url + "/v1", # TODO remove
base_url=base_url + "/v1",
) as client:
result = await suite.run(client, model)
results.append(result)

View file

@ -0,0 +1,34 @@
### LLM API KEY ###
# ANTHROPIC_API_KEY=...
# OPENAI_API_KEY=...
# ...
### Integrations ###
# GITHUB_CLIENT_ID=...
# GITHUB_CLIENT_SECRET=...
# GOOGLE_CLIENT_ID=...
# GOOGLE_CLIENT_SECRET=...
# LINKEDIN_CLIENT_ID=...
# LINKEDIN_CLIENT_SECRET=...
# MICROSOFT_CLIENT_ID=...
# MICROSOFT_CLIENT_SECRET=...
# SLACK_CLIENT_ID=...
# SLACK_CLIENT_SECRET=...
# SPOTIFY_CLIENT_ID=...
# SPOTIFY_CLIENT_SECRET=...
# X_CLIENT_ID=...
# X_CLIENT_SECRET=...
# ZOOM_CLIENT_ID=...
# ZOOM_CLIENT_SECRET=...
#...
# ...

View file

@ -50,6 +50,7 @@ pytest-asyncio = "^0.23.7"
types-toml = "^0.10.8"
types-pytz = "^2024.1"
types-python-dateutil = "^2.8.2"
types-PyYAML = "^6.0.0"
poetry-plugin-export = "^1.7.0"
[tool.poetry.scripts]

View file

@ -1,135 +1,195 @@
import pytest
from arcade.cli.utils import apply_config_overrides
from arcade.core.config_model import ApiConfig, Config, EngineConfig
from arcade.cli.utils import compute_base_url
DEFAULT_HOST = "api.arcade-ai.com"
LOCALHOST = "localhost"
DEFAULT_PORT = None
DEFAULT_TLS = True
DEFAULT_FORCE_TLS = False
DEFAULT_FORCE_NO_TLS = False
@pytest.mark.parametrize(
"inputs, expected_outputs",
"inputs, expected_output",
[
pytest.param(
{
"host_input": None,
"port_input": None,
"tls_input": None,
"host_input": DEFAULT_HOST,
"port_input": DEFAULT_PORT,
"force_tls": DEFAULT_FORCE_TLS,
"force_no_tls": DEFAULT_FORCE_NO_TLS,
},
{
"host": DEFAULT_HOST,
"port": DEFAULT_PORT,
"tls": DEFAULT_TLS,
},
id="noop",
"https://api.arcade-ai.com",
id="default",
),
pytest.param(
{
"host_input": "api2.arcade-ai.com",
"port_input": None,
"tls_input": None,
"host_input": LOCALHOST,
"port_input": DEFAULT_PORT,
"force_tls": DEFAULT_FORCE_TLS,
"force_no_tls": DEFAULT_FORCE_NO_TLS,
},
{
"host": "api2.arcade-ai.com",
"port": DEFAULT_PORT,
"tls": DEFAULT_TLS,
},
id="set host",
"http://localhost:9099",
id="localhost",
),
pytest.param(
{
"host_input": None,
"port_input": 6789,
"tls_input": None,
"host_input": DEFAULT_HOST,
"port_input": 9099,
"force_tls": DEFAULT_FORCE_TLS,
"force_no_tls": DEFAULT_FORCE_NO_TLS,
},
{
"host": DEFAULT_HOST,
"port": 6789,
"tls": DEFAULT_TLS,
},
id="set port",
"https://api.arcade-ai.com:9099",
id="custom port",
),
pytest.param(
{
"host_input": None,
"port_input": None,
"tls_input": False,
"host_input": LOCALHOST,
"port_input": 9099,
"force_tls": DEFAULT_FORCE_TLS,
"force_no_tls": DEFAULT_FORCE_NO_TLS,
},
{
"host": DEFAULT_HOST,
"port": DEFAULT_PORT,
"tls": False,
},
id="set TLS to False",
"http://localhost:9099",
id="localhost with custom port",
),
pytest.param(
{
"host_input": None,
"port_input": None,
"tls_input": True,
"host_input": DEFAULT_HOST,
"port_input": DEFAULT_PORT,
"force_tls": True,
"force_no_tls": DEFAULT_FORCE_NO_TLS,
},
{
"host": DEFAULT_HOST,
"port": DEFAULT_PORT,
"tls": True,
},
id="set TLS to True",
"https://api.arcade-ai.com",
id="force TLS",
),
pytest.param(
{
"host_input": "localhost",
"port_input": None,
"tls_input": None,
"host_input": LOCALHOST,
"port_input": DEFAULT_PORT,
"force_tls": True,
"force_no_tls": DEFAULT_FORCE_NO_TLS,
},
{
"host": "localhost",
"port": 9099,
"tls": False,
},
id="localhost and no port or TLS specified",
"https://localhost:9099",
id="localhost with force TLS",
),
pytest.param(
{
"host_input": "localhost",
"port_input": 1234,
"tls_input": None,
"host_input": DEFAULT_HOST,
"port_input": 9099,
"force_tls": True,
"force_no_tls": DEFAULT_FORCE_NO_TLS,
},
{
"host": "localhost",
"port": 1234,
"tls": False,
},
id="localhost and port specified",
"https://api.arcade-ai.com:9099",
id="custom port with force TLS",
),
pytest.param(
{
"host_input": "localhost",
"port_input": None,
"tls_input": True,
"host_input": LOCALHOST,
"port_input": 9099,
"force_tls": True,
"force_no_tls": DEFAULT_FORCE_NO_TLS,
},
"https://localhost:9099",
id="localhost with custom port and force TLS",
),
pytest.param(
{
"host": "localhost",
"port": 9099,
"tls": True,
"host_input": DEFAULT_HOST,
"port_input": DEFAULT_PORT,
"force_tls": DEFAULT_FORCE_TLS,
"force_no_tls": True,
},
id="localhost and TLS specified",
"http://api.arcade-ai.com",
id="force no TLS",
),
pytest.param(
{
"host_input": LOCALHOST,
"port_input": DEFAULT_PORT,
"force_tls": DEFAULT_FORCE_TLS,
"force_no_tls": True,
},
"http://localhost:9099",
id="localhost with force no TLS",
),
pytest.param(
{
"host_input": DEFAULT_HOST,
"port_input": 9099,
"force_tls": DEFAULT_FORCE_TLS,
"force_no_tls": True,
},
"http://api.arcade-ai.com:9099",
id="custom port with force no TLS",
),
pytest.param(
{
"host_input": LOCALHOST,
"port_input": 9099,
"force_tls": DEFAULT_FORCE_TLS,
"force_no_tls": True,
},
"http://localhost:9099",
id="localhost with custom port and force no TLS",
),
pytest.param(
{
"host_input": DEFAULT_HOST,
"port_input": DEFAULT_PORT,
"force_tls": True,
"force_no_tls": True,
},
"http://api.arcade-ai.com",
id="force TLS and no TLS",
),
pytest.param(
{
"host_input": LOCALHOST,
"port_input": DEFAULT_PORT,
"force_tls": True,
"force_no_tls": True,
},
"http://localhost:9099",
id="localhost with force TLS and no TLS",
),
pytest.param(
{
"host_input": DEFAULT_HOST,
"port_input": 9099,
"force_tls": True,
"force_no_tls": True,
},
"http://api.arcade-ai.com:9099",
id="custom port with force TLS and no TLS",
),
pytest.param(
{
"host_input": LOCALHOST,
"port_input": 9099,
"force_tls": True,
"force_no_tls": True,
},
"http://localhost:9099",
id="localhost with custom port, force TLS and no TLS",
),
pytest.param(
{
"host_input": "arandomhost.com",
"port_input": DEFAULT_PORT,
"force_tls": DEFAULT_FORCE_TLS,
"force_no_tls": DEFAULT_FORCE_NO_TLS,
},
"https://arandomhost.com",
id="random host",
),
],
)
def test_apply_config_overrides(inputs: dict, expected_outputs: dict):
# Set fake default values for testing
config = Config(
api=ApiConfig(key="fake_api_key"),
engine=EngineConfig(
host=DEFAULT_HOST,
port=DEFAULT_PORT,
tls=DEFAULT_TLS,
),
def test_compute_base_url(inputs: dict, expected_output: str):
base_url = compute_base_url(
inputs["force_tls"],
inputs["force_no_tls"],
inputs["host_input"],
inputs["port_input"],
)
apply_config_overrides(config, inputs["host_input"], inputs["port_input"], inputs["tls_input"])
assert config.engine.host == expected_outputs["host"]
assert config.engine.port == expected_outputs["port"]
assert config.engine.tls == expected_outputs["tls"]
assert base_url == expected_output