arcade login in CLI (#25)

Working now:
- `arcade login` works against the Cloud
- `arcade logout` deletes your local credentials

---------

Co-authored-by: Sam Partee <sam@arcade-ai.com>
This commit is contained in:
Nate Barbettini 2024-08-30 11:20:00 -07:00 committed by GitHub
parent aee706e118
commit 950e075750
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 607 additions and 551 deletions

View file

@ -26,8 +26,8 @@ build: clean-build ## Build wheel file using poetry
@echo "🚀 Creating wheel file"
@cd arcade && poetry build
.PHONY: clean
clean: ## clean build artifacts
.PHONY: clean-build
clean-build: ## clean build artifacts
@cd arcade && rm -rf dist
.PHONY: publish

135
arcade/arcade/cli/authn.py Normal file
View file

@ -0,0 +1,135 @@
import os
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any
from urllib.parse import parse_qs
import toml
from rich.console import Console
from arcade.cli.constants import LOGIN_FAILED_HTML, LOGIN_SUCCESS_HTML
console = Console()
class LoginCallbackHandler(BaseHTTPRequestHandler):
def __init__(self, *args, state: str, **kwargs): # type: ignore[no-untyped-def]
self.state = state # Simple CSRF protection
super().__init__(*args, **kwargs)
def log_message(self, format: str, *args: Any) -> None: # noqa: A002 Argument `format` is shadowing a Python builtin
# Override to suppress logging to stdout
pass
def _parse_login_response(self) -> tuple[str, str, str] | None:
# Parse the query string from the URL
query_string = self.path.split("?", 1)[-1]
params = parse_qs(query_string)
returned_state = params.get("state", [None])[0]
if returned_state != self.state:
console.print(
"❌ Login failed: Invalid login attempt. Please try again.", style="bold red"
)
return None
api_key = params.get("api_key", [None])[0] or ""
email = params.get("email", [None])[0] or ""
warning = params.get("warning", [None])[0] or ""
return api_key, email, warning
def _handle_login_response(self) -> bool:
result = self._parse_login_response()
if result is None:
return False
api_key, email, warning = result
if warning:
console.print(warning, style="bold yellow")
# If API key and email are received, store them in a file
if not api_key or not email:
console.print(
"❌ Login failed: No credentials received. Please try again.", style="bold red"
)
return False
# TODO don't overwrite existing config
config_file_path = os.path.expanduser("~/.arcade/arcade.toml")
new_config = {"api": {"key": api_key}, "user": {"email": email}}
with open(config_file_path, "w") as f:
toml.dump(new_config, f)
# Send a success response to the browser
console.print(
f"""✅ Hi there, {email}!
Your Arcade API key is: {api_key}
Stored in: {config_file_path}""",
style="bold green",
)
return True
def do_GET(self) -> None: # This naming is correct, required by BaseHTTPRequestHandler
success = self._handle_login_response()
if success:
self.send_response(200)
self.end_headers()
self.wfile.write(LOGIN_SUCCESS_HTML)
else:
self.send_response(400)
self.end_headers()
self.wfile.write(LOGIN_FAILED_HTML)
# Always shut down the server so it doesn't keep running
threading.Thread(target=self.server.shutdown).start()
class LocalAuthCallbackServer:
def __init__(self, state: str, port: int = 9905):
self.state = state
self.port = port
self.httpd: HTTPServer | None = None
def run_server(self) -> None:
# Initialize and run the server
server_address = ("", self.port)
handler = lambda *args, **kwargs: LoginCallbackHandler(*args, state=self.state, **kwargs)
self.httpd = HTTPServer(server_address, handler)
self.httpd.serve_forever()
def shutdown_server(self) -> None:
# Shut down the server gracefully
if self.httpd:
self.httpd.shutdown()
def check_existing_login() -> bool:
"""
Check if the user is already logged in by verifying the config file.
Returns:
bool: True if the user is already logged in, False otherwise.
"""
config_file_path = os.path.expanduser("~/.arcade/arcade.toml")
if not os.path.exists(config_file_path):
return False
try:
config: dict[str, Any] = toml.load(config_file_path)
api_key = config.get("api", {}).get("key")
email = config.get("user", {}).get("email")
if api_key and email:
console.print(
f"You're already logged in as {email}. "
f"Delete {config_file_path} to log in as a different user."
)
return True
except toml.TomlDecodeError:
console.print(f"Error: Invalid configuration file at {config_file_path}", style="bold red")
except Exception as e:
console.print(f"Error: Unable to read configuration file: {e!s}", style="bold red")
return False

View file

@ -0,0 +1,143 @@
_style_block = b"""
<link rel="icon" href="https://cdn.arcade-ai.com/favicons/favicon.ico" sizes="any">
<link rel="apple-touch-icon" href="https://cdn.arcade-ai.com/favicons/apple-touch-icon.png">
<link rel="icon" type="image/png" sizes="32x32" href="https://cdn.arcade-ai.com/favicons/favicon-32x32.png">
<link rel="icon" type="image/png" sizes="16x16" href="https://cdn.arcade-ai.com/favicons/favicon-16x16.png">
<link rel="apple-touch-icon" sizes="180x180" href="https://cdn.arcade-ai.com/favicons/apple-touch-icon.png">
<style>
body {
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
margin: 0;
background: linear-gradient(135deg, #1a1a1a, #0f0f0f);
font-family: Arial, sans-serif;
}
.container {
background-color: #333;
padding: 40px;
border-radius: 8px;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
width: 300px;
}
.container h2 {
color: #fff;
margin-bottom: 20px;
text-align: center;
}
.container label {
display: block;
color: #bbb;
margin-bottom: 5px;
font-size: 14px;
}
.container input[type="text"],
.container input[type="password"] {
width: 100%;
padding: 10px;
margin-bottom: 15px;
border: none;
border-radius: 4px;
background-color: #444;
color: #ddd;
font-size: 16px;
box-sizing: border-box;
}
.container input[type="text"]::placeholder,
.container input[type="password"]::placeholder {
color: #aaa;
}
.container input[type="submit"] {
width: 100%;
padding: 10px;
border: none;
border-radius: 4px;
background-color: #ED155D;
color: #fff;
font-size: 16px;
cursor: pointer;
transition: background-color 0.3s ease;
}
.container input[type="submit"]:hover {
background-color: #C0104A;
}
.message {
background-color: #1e1e1e;
padding: 10px;
border-radius: 4px;
margin-bottom: 15px;
font-size: 14px;
text-align: center;
}
.info {
color: #fff;
}
.error {
color: #ff4d4d;
}
.logo {
display: block;
max-width: 100%;
max-height: 90px;
margin: 0 auto 20px;
}
</style>
"""
LOGIN_SUCCESS_HTML = (
b"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Success!</title>
"""
+ _style_block
+ b"""
</head>
<body>
<div class="container">
<img src="https://cdn.arcade-ai.com/logos/a-icon.png" alt="Arcade logo" class="logo">
<h2>Log in to Arcade CLI</h2>
<p class="message info">Success! You can close this window.</p>
</div>
</body>
</html>
"""
)
LOGIN_FAILED_HTML = (
b"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Login failed</title>
"""
+ _style_block
+ b"""
</head>
<body>
<div class="container">
<img src="https://cdn.arcade-ai.com/logos/a-icon.png" alt="Arcade logo" class="logo">
<h2>Log in to Arcade CLI</h2>
<p class="message error">Something went wrong. Please close this window and try again.</p>
</div>
</body>
</html>
"""
)

View file

@ -1,46 +1,63 @@
import asyncio
import os
import threading
import uuid
import webbrowser
from typing import Any, Optional
from urllib.parse import urlencode
import typer
from openai.resources.chat.completions import ChatCompletionChunk, Stream
from rich.console import Console
from rich.markdown import Markdown
from rich.markup import escape
from rich.table import Table
from rich.text import Text
from typer.core import TyperGroup
from typer.models import Context
from arcade.core.catalog import ToolCatalog
from arcade.core.client import EngineClient
from arcade.core.config import Config
from arcade.core.schema import ToolCallOutput, ToolContext
from arcade.core.toolkit import Toolkit
from arcade.cli.authn import check_existing_login, LocalAuthCallbackServer
from arcade.cli.utils import (
OrderCommands,
create_cli_catalog,
display_streamed_markdown,
validate_and_get_config,
)
from arcade.client import Arcade
class OrderCommands(TyperGroup):
def list_commands(self, ctx: Context) -> list[str]: # type: ignore[override]
"""Return list of commands in the order appear."""
return list(self.commands) # get commands using self.commands
console = Console()
cli = typer.Typer(
cls=OrderCommands,
)
console = Console()
@cli.command(help="Log in to Arcade Cloud")
def login(
username: str = typer.Option(..., prompt="Username", help="Your Arcade Cloud username"),
api_key: str = typer.Option(None, prompt="API Key", help="Your Arcade Cloud API Key"),
) -> None:
def login() -> None:
"""
Logs the user into Arcade Cloud.
"""
# Here you would add the logic to authenticate the user with Arcade Cloud
raise NotImplementedError("This feature is not yet implemented.")
if check_existing_login():
return
# Start the HTTP server in a new thread
state = str(uuid.uuid4())
auth_server = LocalAuthCallbackServer(state)
server_thread = threading.Thread(target=auth_server.run_server)
server_thread.start()
try:
# Open the browser for user login
callback_uri = "http://localhost:9905/callback"
params = urlencode({"callback_uri": callback_uri, "state": state})
# TODO: make this configurable
login_url = f"http://localhost:8001/api/v1/auth/cli_login?{params}"
console.print("Opening a browser to log you in...")
webbrowser.open(login_url)
# Wait for the server thread to finish
server_thread.join()
except KeyboardInterrupt:
auth_server.shutdown_server()
finally:
if server_thread.is_alive():
server_thread.join() # Ensure the server thread completes and cleans up
@cli.command(help="Log out of Arcade Cloud")
@ -48,8 +65,14 @@ def logout() -> None:
"""
Logs the user out of Arcade Cloud.
"""
# Here you would add the logic to log the user out of Arcade Cloud
raise NotImplementedError("This feature is not yet implemented.")
# If ~/.arcade/arcade.toml exists, delete it
config_file_path = os.path.expanduser("~/.arcade/arcade.toml")
if os.path.exists(config_file_path):
os.remove(config_file_path)
console.print("You're now logged out.", style="bold")
else:
console.print("You're not logged in.", style="bold red")
@cli.command(help="Create a new toolkit package directory")
@ -100,135 +123,21 @@ def show(
console.print(error_message, style="bold red")
@cli.command(help="Run a tool using an LLM to predict the arguments")
def run(
toolkit: Optional[str] = typer.Option(
None, "-t", "--toolkit", help="The toolkit to include in the run"
),
model: str = typer.Option("gpt-4o", "-m", help="The model to use for prediction."),
tool: str = typer.Option(None, "--tool", help="The name of the tool to run."),
choice: str = typer.Option(
"generate", "-c", "--choice", help="The value of the tool choice argument"
),
stream: bool = typer.Option(
False, "-s", "--stream", is_flag=True, help="Stream the tool output."
),
prompt: str = typer.Argument(..., help="The prompt to use for context"),
) -> None:
"""
Run a tool using an LLM to predict the arguments.
"""
from arcade.core.client import EngineClient
from arcade.core.executor import ToolExecutor
try:
catalog = create_cli_catalog(toolkit=toolkit)
tools = [catalog[tool]] if tool else list(catalog)
config = Config.load_from_file()
if not config.engine or not config.engine_url:
console.print("❌ Engine configuration not found or URL is missing.", style="bold red")
typer.Exit(code=1)
if not config.api or not config.api.key:
console.print(
"❌ API configuration not found or key is missing. Please run `arcade login`.",
style="bold red",
)
typer.Exit(code=1)
client = EngineClient(api_key=config.api.key, base_url=config.engine_url)
# TODO better way of doing this
tool_choice = "auto" if choice in ["execute", "generate"] else choice
calls = client.call_tool(tools, tool_choice=tool_choice, prompt=prompt, model=model)
if len(calls) == 0:
console.print("[bold red]No tools were called[/bold red]")
messages = [
{"role": "user", "content": prompt},
]
for tool_name, parameters in calls:
called_tool = catalog[tool_name]
console.print(f"Calling tool: {tool_name} with params: {parameters}", style="bold blue")
# TODO async.gather instead of loop.
output: ToolCallOutput = asyncio.run(
ToolExecutor.run(
called_tool.tool,
called_tool.definition,
called_tool.input_model,
called_tool.output_model,
ToolContext(),
**parameters,
)
)
if output.error:
console.print(output.error.message, style="bold red")
typer.Exit(code=1)
else:
messages += [
{
"role": "assistant",
# TODO: escape the output and ensure serialization works
"content": f"Results of Tool {tool_name}: {output.value!s}",
},
]
if choice == "execute":
console.print(output.value, style="green")
raise typer.Exit(0)
else:
if stream:
stream_response = client.stream_complete(model=model, messages=messages)
display_streamed_markdown(stream_response)
else:
response = client.complete(model=model, messages=messages)
if not len(response.choices) and not response.choices[0].message.content:
console.print("No response from the tool.", style="bold red")
else:
console.print(Markdown(response.choices[0].message.content or ""))
except RuntimeError as e:
error_message = f"❌ Failed to run tool{': ' + escape(str(e)) if str(e) else ''}"
console.print(error_message, style="bold red")
@cli.command(help="Chat with a language model")
def chat(
model: str = typer.Option("gpt-4o", "-m", help="The model to use for prediction."),
stream: bool = typer.Option(
True, "-s", "--stream", is_flag=True, help="Stream the tool output."
False, "-s", "--stream", is_flag=True, help="Stream the tool output."
),
) -> None:
"""
Chat with a language model.
"""
config = validate_and_get_config()
config = Config.load_from_file()
if not config.engine or not config.engine_url:
console.print("❌ Engine configuration not found or URL is missing.", style="bold red")
typer.Exit(code=1)
if not config.api or not config.api.key:
console.print(
"❌ API configuration not found or key is missing. Please run `arcade login`.",
style="bold red",
)
typer.Exit(code=1)
client = EngineClient(api_key=config.api.key, base_url=config.engine_url)
if config.user and config.user.email:
user_email = config.user.email
user_attribution = f"({user_email})"
else:
console.print(
"❌ User email not found in configuration. Please run `arcade login`.", style="bold red"
)
typer.Exit(code=1)
client = Arcade(api_key=config.api.key, base_url=config.engine_url)
user_email = config.user.email if config.user else None
user_attribution = f"({user_email})" if user_email else ""
try:
# start messages conversation
@ -237,7 +146,7 @@ def chat(
chat_header = Text.assemble(
"\n",
(
"======== Arcade AI Chat ========",
"=== Arcade AI Chat ===",
"bold magenta underline",
),
"\n",
@ -251,20 +160,24 @@ def chat(
messages.append({"role": "user", "content": user_input})
if stream:
stream_response = client.stream_complete(
# TODO Fix this in the client so users don't deal with these
# typing issues
stream_response = client.chat.completions.create( # type: ignore[call-overload]
model=model,
messages=messages,
tool_choice="generate",
user=user_email,
stream=True,
)
role, message = display_streamed_markdown(stream_response)
messages.append({"role": role, "content": message})
else:
response = client.complete(
response = client.chat.completions.create( # type: ignore[call-overload]
model=model,
messages=messages,
tool_choice="generate",
user=user_email,
stream=False,
)
message_content = response.choices[0].message.content or ""
role = response.choices[0].message.role
@ -310,30 +223,6 @@ def dev(
raise typer.Exit(code=1)
@cli.command(help="Manage the Arcade Engine (start/stop/restart)")
def engine(
action: str = typer.Argument("start", help="The action to take (start/stop/restart)"),
host: str = typer.Option("localhost", "--host", "-h", help="The host of the engine"),
port: int = typer.Option(6901, "--port", "-p", help="The port of the engine"),
) -> None:
"""
Manage the Arcade Engine (start/stop/restart)
"""
raise NotImplementedError("This feature is not yet implemented.")
@cli.command(help="Manage credientials stored in the Arcade Engine")
def credentials(
action: str = typer.Argument("show", help="The action to take (add/remove/show)"),
name: str = typer.Option(None, "--name", "-n", help="The name of the credential to add/remove"),
val: str = typer.Option(None, "--val", "-v", help="The value of the credential to add/remove"),
) -> None:
"""
Manage credientials stored in the Arcade Engine
"""
raise NotImplementedError("This feature is not yet implemented.")
@cli.command(help="Show/edit configuration details of the Arcade Engine")
def config(
action: str = typer.Argument("show", help="The action to take (show/edit)"),
@ -345,8 +234,7 @@ def config(
"""
Show/edit configuration details of the Arcade Engine
"""
config = Config.load_from_file()
config = validate_and_get_config()
if action == "show":
display_config_as_table(config)
@ -376,7 +264,7 @@ def config(
raise typer.Exit(code=1)
def display_config_as_table(config: Config) -> None:
def display_config_as_table(config) -> None: # type: ignore[no-untyped-def]
"""
Display the configuration details as a table using Rich library.
"""
@ -399,58 +287,3 @@ def display_config_as_table(config: Config) -> None:
table.add_row("", "", "")
console.print(table)
def display_streamed_markdown(stream: Stream[ChatCompletionChunk]) -> tuple[str, str]:
"""
Display the streamed markdown chunks as a single line.
"""
from rich.live import Live
full_message = ""
role = ""
with Live(console=console, refresh_per_second=10) as live:
for chunk in stream:
choice = chunk.choices[0]
chunk_message = choice.delta.content
if role == "":
role = choice.delta.role or ""
if role == "assistant":
console.print("\n[bold blue]Assistant:[/bold blue] ")
if chunk_message:
full_message += chunk_message
markdown_chunk = Markdown(full_message)
live.update(markdown_chunk)
return role, full_message
def create_cli_catalog(
toolkit: str | None = None,
show_toolkits: bool = False,
) -> ToolCatalog:
"""
Load toolkits from the python environment.
"""
if toolkit:
try:
prefixed_toolkit = "arcade_" + toolkit
toolkits = [Toolkit.from_package(prefixed_toolkit)]
except ValueError:
try: # try without prefix
toolkits = [Toolkit.from_package(toolkit)]
except ValueError as e:
console.print(f"{e}", style="bold red")
typer.Exit(code=1)
else:
toolkits = Toolkit.find_all_arcade_toolkits()
if not toolkits:
console.print("❌ No toolkits found or specified", style="bold red")
typer.Exit(code=1)
catalog = ToolCatalog()
for loaded_toolkit in toolkits:
if show_toolkits:
console.print(f"Loading toolkit: {loaded_toolkit.name}", style="bold blue")
catalog.add_toolkit(loaded_toolkit)
return catalog

View file

@ -13,7 +13,7 @@ console = Console()
DEFAULT_VERSIONS = {
"python": "^3.10",
"arcade-ai": f"^{VERSION}",
"pytest": "^7.4.0",
"pytest": "^8.3.0",
}

107
arcade/arcade/cli/utils.py Normal file
View file

@ -0,0 +1,107 @@
from typing import TYPE_CHECKING
import typer
from openai.resources.chat.completions import ChatCompletionChunk, Stream
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
from typer.core import TyperGroup
from typer.models import Context
from arcade.core.catalog import ToolCatalog
from arcade.core.toolkit import Toolkit
if TYPE_CHECKING:
from arcade.core.config import Config
console = Console()
class OrderCommands(TyperGroup):
def list_commands(self, ctx: Context) -> list[str]: # type: ignore[override]
"""Return list of commands in the order appear."""
return list(self.commands) # get commands using self.commands
def create_cli_catalog(
toolkit: str | None = None,
show_toolkits: bool = False,
) -> ToolCatalog:
"""
Load toolkits from the python environment.
"""
if toolkit:
try:
prefixed_toolkit = "arcade_" + toolkit
toolkits = [Toolkit.from_package(prefixed_toolkit)]
except ValueError:
try: # try without prefix
toolkits = [Toolkit.from_package(toolkit)]
except ValueError as e:
console.print(f"{e}", style="bold red")
typer.Exit(code=1)
else:
toolkits = Toolkit.find_all_arcade_toolkits()
if not toolkits:
console.print("❌ No toolkits found or specified", style="bold red")
typer.Exit(code=1)
catalog = ToolCatalog()
for loaded_toolkit in toolkits:
if show_toolkits:
console.print(f"Loading toolkit: {loaded_toolkit.name}", style="bold blue")
catalog.add_toolkit(loaded_toolkit)
return catalog
def display_streamed_markdown(stream: Stream[ChatCompletionChunk]) -> tuple[str, str]:
"""
Display the streamed markdown chunks as a single line.
"""
full_message = ""
role = ""
with Live(console=console, refresh_per_second=10) as live:
for chunk in stream:
choice = chunk.choices[0]
chunk_message = choice.delta.content
if role == "":
role = choice.delta.role or ""
if role == "assistant":
console.print("\n[bold blue]Assistant:[/bold blue] ")
if chunk_message:
full_message += chunk_message
markdown_chunk = Markdown(full_message)
live.update(markdown_chunk)
return role, full_message
def validate_and_get_config(
validate_engine: bool = True,
validate_api: bool = True,
validate_user: bool = True,
) -> "Config":
"""
Validates the configuration, user, and returns the Config object
"""
from arcade.core.config import config
if validate_engine and (not config.engine or not config.engine_url):
console.print("❌ Engine configuration not found or URL is missing.", style="bold red")
raise typer.Exit(code=1)
if validate_api and (not config.api or not config.api.key):
console.print(
"❌ API configuration not found or key is missing. Please run `arcade login`.",
style="bold red",
)
raise typer.Exit(code=1)
if validate_user and (not config.user or not config.user.email):
console.print(
"❌ User email not found in configuration. Please run `arcade login`.", style="bold red"
)
raise typer.Exit(code=1)
return config

View file

@ -5,8 +5,6 @@ from urllib.parse import urljoin
import httpx
from httpx import Timeout
from arcade.core.config import config
T = TypeVar("T")
ResponseT = TypeVar("ResponseT")
@ -26,7 +24,6 @@ class BaseArcadeClient:
base_url: str,
api_key: str | None = None,
headers: dict[str, str] | None = None,
proxies: str | dict[str, str] | None = None,
timeout: float | Timeout = 10.0,
retries: int = 3,
):
@ -37,28 +34,20 @@ class BaseArcadeClient:
base_url: The base URL for the Arcade API.
api_key: The API key for authentication.
headers: Additional headers to include in requests.
proxies: Proxy configuration for requests.
timeout: Request timeout in seconds.
retries: Number of retries for failed requests.
"""
self._base_url = base_url
self._api_key = api_key or os.environ.get("ARCADE_API_KEY") or config.api.key
self._api_key = api_key or os.environ.get("ARCADE_API_KEY")
self._headers = headers or {}
self._headers.setdefault("Authorization", f"Bearer {self._api_key}")
self._headers.setdefault("Content-Type", "application/json")
self._proxies = proxies
self._timeout = timeout
self._retries = retries
def _build_url(self, path: str) -> str:
"""
Build the full URL for a given path.
Args:
path: The path to append to the base URL.
Returns:
The full URL.
"""
return urljoin(self._base_url, path)
@ -71,7 +60,6 @@ class SyncArcadeClient(BaseArcadeClient):
self._client = httpx.Client(
base_url=self._base_url,
headers=self._headers,
proxies=self._proxies,
timeout=self._timeout,
)
@ -116,7 +104,6 @@ class AsyncArcadeClient(BaseArcadeClient):
self._client = httpx.AsyncClient(
base_url=self._base_url,
headers=self._headers,
proxies=self._proxies,
timeout=self._timeout,
)
return self._client

View file

@ -24,11 +24,14 @@ from arcade.core.schema import ToolDefinition
T = TypeVar("T")
ClientT = TypeVar("ClientT", SyncArcadeClient, AsyncArcadeClient)
API_VERSION = "v1"
BASE_URL = "https://api.arcade-ai.com"
class AuthResource(BaseResource[ClientT]):
"""Authentication resource."""
_base_path = "/v1/auth"
_base_path = f"/{API_VERSION}/auth"
def authorize(
self,
@ -66,11 +69,9 @@ class AuthResource(BaseResource[ClientT]):
return AuthResponse(**data)
def poll_authorization(self, auth_id: str) -> AuthResponse:
"""
Poll for the status of an authorization request.
"""Poll for the status of an authorization
Args:
auth_id: The authorization ID.
Polls using the authorization ID returned from the authorize method.
Example:
auth_status = client.auth.poll_authorization("auth_123")
@ -84,7 +85,7 @@ class AuthResource(BaseResource[ClientT]):
class ToolResource(BaseResource[ClientT]):
"""Tool resource."""
_base_path = "/v1/tools"
_base_path = f"/{API_VERSION}/tool"
def run(
self,
@ -116,10 +117,6 @@ class ToolResource(BaseResource[ClientT]):
def get(self, director_id: str, tool_id: str) -> ToolDefinition:
"""
Get the specification for a tool.
Args:
director_id: The director ID.
tool_id: The tool ID.
"""
data = self._client._execute_request( # type: ignore[attr-defined]
"GET",
@ -132,12 +129,10 @@ class ToolResource(BaseResource[ClientT]):
class ArcadeClientMixin(Generic[ClientT]):
"""Mixin for Arcade clients."""
def __init__(self, base_url: str, *args: Any, **kwargs: Any):
super().__init__(base_url, *args, **kwargs)
self._openai_client: OpenAI | AsyncOpenAI | None = None
def __init__(self, base_url: str = BASE_URL, *args: Any, **kwargs: Any):
super().__init__(base_url, *args, **kwargs) # type: ignore[call-arg]
self.auth: AuthResource = AuthResource(self)
self.tool: ToolResource = ToolResource(self)
self.chat: Chat | AsyncChat | None = None
def _handle_http_error(
self,
@ -148,25 +143,39 @@ class ArcadeClientMixin(Generic[ClientT]):
error_class = error_map.get(status_code, InternalServerError)
raise error_class(str(e), response=e.response)
def _chat_url(self, base_url: str) -> str:
# TODO (sam): make chat a Resource like others but maintain
# the ability to call chat directly like the openai clients
chat_url = str(base_url)
if not base_url.endswith(API_VERSION):
chat_url = f"{base_url}/{API_VERSION}"
return chat_url
class Arcade(ArcadeClientMixin[SyncArcadeClient], SyncArcadeClient):
"""Synchronous Arcade client."""
"""Synchronous Arcade client.
Example:
from arcade.client import Arcade
client = Arcade(api_key="your-api-key")
client.auth.authorize(...)
client.tool.run(...)
"""
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
# Assume we are using the LLM API of the Engine for now
self._openai_client = OpenAI(base_url=self._base_url + "/v1", api_key=self._api_key)
self.chat = self._openai_client.chat
chat_url = self._chat_url(self._base_url)
self._openai_client = OpenAI(base_url=chat_url, api_key=self._api_key)
@property
def chat(self) -> Chat:
return self._openai_client.chat
def _execute_request(self, method: str, url: str, **kwargs: Any) -> Any:
"""
Execute a synchronous request.
Args:
method: The HTTP method.
url: The URL to request.
**kwargs: Additional arguments for the request.
"""
try:
response = self._request(method, url, **kwargs)
@ -189,17 +198,17 @@ class AsyncArcade(ArcadeClientMixin[AsyncArcadeClient], AsyncArcadeClient):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self._openai_client = AsyncOpenAI(base_url=self._base_url + "/v1")
self.chat = self._openai_client.chat
chat_url = self._chat_url(self._base_url)
self._openai_client = AsyncOpenAI(base_url=chat_url, api_key=self._api_key)
@property
def chat(self) -> AsyncChat:
return self._openai_client.chat
async def _execute_request(self, method: str, url: str, **kwargs: Any) -> Any:
"""
Execute an asynchronous request.
Args:
method: The HTTP method.
url: The URL to request.
**kwargs: Additional arguments for the request.
"""
try:
response = await self._request(method, url, **kwargs)

View file

@ -24,9 +24,10 @@ class AuthProvider(str, Enum):
class AuthRequest(BaseModel):
"""
The requirements for authorization for a tool
# TODO (Nate): Make a validator here
"""
authority: AnyUrl | None = None
authority: AnyUrl | str | None = None
"""The URL of the OAuth 2.0 authorization server."""
scope: list[str]

View file

@ -1,187 +0,0 @@
import json
from enum import Enum
from typing import Any, Optional
from openai import OpenAI
from openai.resources.chat.completions import ChatCompletion, ChatCompletionChunk, Stream
from pydantic import BaseModel
from pydantic_core import PydanticUndefined
from arcade.core.catalog import MaterializedTool
PYTHON_TO_JSON_TYPES: dict[type, str] = {
str: "string",
int: "integer",
float: "number",
bool: "boolean",
list: "array",
dict: "object",
}
ToolCalls = dict[str, dict[str, Any]]
def python_type_to_json_type(python_type: type[Any]) -> dict[str, Any] | str:
"""
Map Python types to JSON Schema types, including handling of
complex types such as lists and dictionaries.
"""
if hasattr(python_type, "__origin__"):
origin = python_type.__origin__
if origin is list:
item_type = python_type_to_json_type(python_type.__args__[0])
return {"type": "array", "items": item_type}
elif origin is dict:
value_type = python_type_to_json_type(python_type.__args__[1])
return {"type": "object", "additionalProperties": value_type}
elif issubclass(python_type, BaseModel):
return model_to_json_schema(python_type)
return PYTHON_TO_JSON_TYPES.get(python_type, "string")
def model_to_json_schema(model: type[BaseModel]) -> dict[str, Any]:
"""
Convert a Pydantic model to a JSON schema.
"""
properties = {}
required = []
for field_name, model_field in model.model_fields.items():
type_json = python_type_to_json_type(model_field.annotation) # type: ignore[arg-type]
if isinstance(type_json, dict):
field_schema = type_json
else:
field_schema = {
"type": type_json,
"description": model_field.description or "",
}
if model_field.default not in [None, PydanticUndefined]:
if isinstance(model_field.default, Enum):
field_schema["default"] = model_field.default.value
else:
field_schema["default"] = model_field.default
if model_field.is_required():
required.append(field_name)
properties[field_name] = field_schema
return {
"type": "object",
"properties": properties,
"required": required,
}
def schema_to_openai_tool(tool: MaterializedTool) -> dict[str, Any]:
"""
Convert a ToolDefinition object to a JSON schema dictionary in the specified function format.
"""
input_model_schema = model_to_json_schema(tool.input_model)
function_schema = {
"type": "function",
"function": {
"name": tool.definition.name,
"description": tool.definition.description,
"parameters": input_model_schema,
},
}
return function_schema
def called_tool(chat_completion: ChatCompletion) -> bool:
"""
Return true if the chat completion called a tool.
"""
choice = chat_completion.choices[0]
if choice.message.tool_calls:
return True
return False
def get_tool_args(chat_completion: ChatCompletion) -> list[tuple[str, dict[str, Any]]]:
"""
Returns the tool arguments from the chat completion object.
"""
tool_args_list = []
message = chat_completion.choices[0].message
if message.tool_calls:
for tool_call in message.tool_calls:
tool_args_list.append(
(
tool_call.function.name,
json.loads(tool_call.function.arguments),
)
)
return tool_args_list
class EngineClient:
def __init__(self, api_key: str, base_url: str | None = None):
self.client = OpenAI(api_key=api_key, base_url=base_url)
def __getattr__(self, name: str) -> Any:
return getattr(self.client, name)
def call_tool(
self,
tools: list[MaterializedTool],
model: str,
messages: Optional[list[dict[str, Any]]] = None,
tool_choice: Optional[str] = "required",
parallel_tool_calls: Optional[bool] = True,
prompt: Optional[str] = "",
**kwargs: Any,
) -> list[tuple[str, dict[str, Any]]]:
"""
Infer the arguments for a given tool and call the OpenAI API.
"""
specs = [schema_to_openai_tool(tool) for tool in tools]
if messages is None:
messages = [{"role": "user", "content": prompt}]
try:
completion = self.complete(
model=model,
messages=messages,
tools=specs,
tool_choice=tool_choice,
parallel_tool_calls=parallel_tool_calls,
**kwargs,
)
if not called_tool(completion):
raise ValueError("No tool call was made.")
except (KeyError, IndexError) as e:
raise ValueError("Invalid response format from OpenAI API.") from e
return get_tool_args(completion)
def complete(
self,
model: str,
messages: list[dict[str, Any]],
**kwargs: Any,
) -> ChatCompletion:
"""
Call the OpenAI API with the given messages.
"""
completion = self.client.chat.completions.create(
model=model,
messages=messages, # type: ignore[arg-type]
**kwargs,
)
return completion
def stream_complete( # type: ignore[misc]
self,
model: str,
messages: list[dict[str, Any]],
**kwargs: Any,
) -> Stream[ChatCompletionChunk]:
stream = self.client.chat.completions.create(
model=model,
messages=messages, # type: ignore[arg-type]
stream=True,
**kwargs,
)
yield from stream

View file

@ -1,5 +1,9 @@
import ipaddress
from functools import cached_property, lru_cache
from pathlib import Path
from urllib.parse import urlparse
import idna
import toml
from pydantic import BaseModel, ValidationError
@ -33,15 +37,15 @@ class EngineConfig(BaseModel):
Arcade Engine configuration.
"""
host: str = "localhost"
host: str = "api.arcade-ai.com"
"""
Arcade Engine host.
"""
port: int = 6901
port: int | None = None
"""
Arcade Engine port.
"""
tls: bool = False
tls: bool = True
"""
Whether to use TLS for the connection to Arcade Engine.
"""
@ -60,7 +64,7 @@ class Config(BaseModel):
"""
Arcade user configuration.
"""
engine: EngineConfig | None = None
engine: EngineConfig | None = EngineConfig()
"""
Arcade Engine configuration.
"""
@ -79,15 +83,77 @@ class Config(BaseModel):
"""
return cls.get_config_dir_path() / "arcade.toml"
@property
@cached_property
def engine_url(self) -> str:
"""
Get the URL of the Arcade Engine.
Get the cached URL of the Arcade Engine.
This property is cached after its first access to improve performance.
The cache is automatically invalidated if any of the underlying data changes.
The port is included in the URL unless the host is a fully qualified domain name
(excluding IP addresses) and no port is specified. Handles IPv4, IPv6, IDNs, and
hostnames with underscores.
This property exists to provide a consistent and correctly formatted URL for
connecting to the Arcade Engine, taking into account various configuration
options and edge cases. It ensures that:
1. The correct protocol (http/https) is used based on the TLS setting.
2. IPv4 and IPv6 addresses are properly formatted.
3. Internationalized Domain Names (IDNs) are correctly encoded.
4. Fully Qualified Domain Names (FQDNs) are identified and handled appropriately.
5. Ports are included when necessary, respecting common conventions for FQDNs.
6. Hostnames with underscores (common in development environments) are supported.
7. Pre-existing port specifications in the host are respected.
The resulting URL is always suffixed with '/v1' to specify the API version.
Returns:
str: The fully constructed URL for the Arcade Engine.
Raises:
ValueError: If the engine configuration is missing or incomplete.
"""
if self.engine is None:
raise ValueError("Engine not set")
raise ValueError("Configuration for Engine is not set in arcade.toml")
if not self.engine.host:
raise ValueError("Configuration for Engine host is not set in arcade.toml")
protocol = "https" if self.engine.tls else "http"
return f"{protocol}://{self.engine.host}:{self.engine.port}/v1"
# Handle potential IDNs
try:
encoded_host = idna.encode(self.engine.host).decode("ascii")
except idna.IDNAError:
encoded_host = self.engine.host
# Check if the host is a valid IP address (IPv4 or IPv6)
try:
ipaddress.ip_address(encoded_host)
is_ip = True
except ValueError:
is_ip = False
# Parse the host, handling potential IPv6 addresses
host_for_parsing = f"[{encoded_host}]" if is_ip and ":" in encoded_host else encoded_host
parsed_host = urlparse(f"//{host_for_parsing}")
# Check if the host is a fully qualified domain name (excluding IP addresses)
is_fqdn = "." in parsed_host.netloc and not is_ip and "_" not in parsed_host.netloc
# Handle hosts that might already include a port
if ":" in parsed_host.netloc and not is_ip:
host, existing_port = parsed_host.netloc.rsplit(":", 1)
if existing_port.isdigit():
return f"{protocol}://{parsed_host.netloc}/v1"
if is_fqdn and self.engine.port is None:
return f"{protocol}://{encoded_host}/v1"
elif self.engine.port is not None:
return f"{protocol}://{encoded_host}:{self.engine.port}/v1"
else:
return f"{protocol}://{encoded_host}/v1"
@classmethod
def ensure_config_dir_exists(cls) -> None:
@ -102,7 +168,22 @@ class Config(BaseModel):
def load_from_file(cls) -> "Config":
"""
Load the configuration from the TOML file in the configuration directory.
If no configuration file exists, create a new one with default values.
If no configuration file exists, this method will create a new one with default values.
The default configuration includes:
- An empty API configuration
- A default Engine configuration (host: "api.arcade-ai.com", port: None, tls: True)
- No user configuration
This behavior ensures that the application always has a valid configuration to work with,
but it may not be suitable for all use cases. If a specific configuration is required,
ensure that the configuration file exists before calling this method.
Returns:
Config: The loaded or newly created configuration.
Raises:
ValueError: If the existing configuration file is invalid.
"""
cls.ensure_config_dir_exists()
@ -149,5 +230,21 @@ class Config(BaseModel):
config_file_path.write_text(toml.dumps(self.model_dump()))
# Singleton instance of Config
config = Config.load_from_file()
@lru_cache(maxsize=1)
def get_config() -> Config:
"""
Get the Arcade configuration.
This function is cached, so subsequent calls will return the same Config object
without reloading from the file, unless the cache is cleared.
remember to clear the cache if the configuration file is modified.
use `get_config.cache_clear()` to clear the cache.
Returns:
Config: The Arcade configuration.
"""
return Config.load_from_file()
config = get_config()

View file

@ -1,21 +1,17 @@
from arcade_github.tools import repo, user
from arcade_gmail.tools import gmail
from arcade_slack.tools import chat
from fastapi import FastAPI, HTTPException
from openai import AsyncOpenAI
from pydantic import BaseModel
from arcade_gmail.tools import gmail
from arcade_github.tools import repo, user
from arcade_slack.tools import chat
from arcade.core.config import config
from arcade.actor.fastapi.actor import FastAPIActor
from arcade.client import AsyncArcade
from arcade.core.config import config
if not config.api or not config.api.key:
raise ValueError("Arcade API key not set. Please run `arcade login`.")
client = AsyncOpenAI(
api_key=config.api.key,
base_url="http://localhost:9099/v1",
)
client = AsyncArcade(api_key=config.api.key)
app = FastAPI()

View file

@ -30,7 +30,7 @@ from langgraph.prebuilt import create_react_agent
# Step 3 (Option 2) Use the Arcade SDK to authenticate with Gmail
from arcade.client import Arcade, AuthProvider
client = Arcade(base_url="http://localhost:9099", api_key=os.environ["ARCADE_API_KEY"])
client = Arcade(api_key=os.environ["ARCADE_API_KEY"])
challenge = client.auth.authorize(
provider=AuthProvider.google,

View file

@ -1,65 +0,0 @@
import os
from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import InstalledAppFlow
from google.auth.exceptions import RefreshError
from googleapiclient.discovery import build
from typing import Annotated
from arcade.sdk import tool
SECRET_FILE = "/Users/spartee/Dropbox/Arcade/gcp/credentials.json"
DRIVE_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"]
@tool
async def list_drive_files(
n_files: Annotated[int, "Number of files to search"] = 5,
) -> list[str]:
"""List files from a Google Drive account and return their details."""
creds = None
# The file token.json stores the user's access and refresh tokens, and is
# created automatically when the authorization flow completes for the first time.
# TODO: use context.authorization.token like gmail.py
if os.path.exists("token.json"):
creds = Credentials.from_authorized_user_file("token.json")
# If there are no (valid) credentials available, let the user log in.
if not creds or not creds.valid:
if creds and creds.expired and creds.refresh_token:
try:
creds.refresh(Request())
except RefreshError:
flow = InstalledAppFlow.from_client_secrets_file(
SECRET_FILE, DRIVE_SCOPES
)
creds = flow.run_local_server(port=0)
# Save the credentials for the next run
with open("token.json", "w") as token:
token.write(creds.to_json())
else:
flow = InstalledAppFlow.from_client_secrets_file(SECRET_FILE, DRIVE_SCOPES)
creds = flow.run_local_server(port=0)
# Save the credentials for the next run
with open("token.json", "w") as token:
token.write(creds.to_json())
# Call the Drive v3 API
service = build("drive", "v3", credentials=creds)
# Request a list of all the files
results = (
service.files()
.list(pageSize=n_files, fields="nextPageToken, files(id, name)")
.execute()
)
items = results.get("files", [])
if not items:
print("No files found.")
else:
print("Files:")
for item in items:
print("{0} ({1})".format(item["name"], item["id"]))
return items