Client Fixes and LangGraph Examples (#50)
This PR includes several improvements to the Arcade client and adds LangGraph examples: 1. Enhanced error handling in the Arcade client: - Improved HTTP error handling in `BaseArcadeClient` - Simplified request methods in `SyncArcadeClient` and `AsyncArcadeClient` 2. Updated `ToolResource` class: - Changed base path from `/v1/tool` to `/v1/tools` - Added `tool_version` parameter to `authorize` method 3. Improved Toolkit discovery: - Updated `find_all_arcade_toolkits` to search only in the current Python interpreter's site-packages 5. Added LangGraph examples: - New `langgraph_auth.py` example demonstrating Gmail authentication - New `langgraph_with_tool_exec.py` example showing tool execution within a LangGraph 6. Minor updates: - Changed default `BASE_URL` to `https://api.arcade.com/` - Updated import error message for eval dependencies --------- Co-authored-by: Nate Barbettini <nate@arcade-ai.com>
This commit is contained in:
parent
8d66b52512
commit
2eb46a3a98
32 changed files with 1291 additions and 403 deletions
6
.vscode/launch.json
vendored
6
.vscode/launch.json
vendored
|
|
@ -34,15 +34,15 @@
|
|||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
{
|
||||
"name": "Debug `arcade evals -d`",
|
||||
"name": "Debug `arcade evals -d` on current file",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/arcade/run_cli.py",
|
||||
"args": ["evals", "-d"],
|
||||
"args": ["evals", "-d", "${fileDirname}", "-h", "localhost"],
|
||||
"console": "integratedTerminal",
|
||||
"jinja": true,
|
||||
"justMyCode": true,
|
||||
"cwd": "${workspaceFolder}"
|
||||
"cwd": ""
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ class BaseActor(Actor):
|
|||
"""
|
||||
return [tool.definition for tool in self.catalog]
|
||||
|
||||
def register_tool(self, tool: Callable, toolkit_name: str | None = None) -> None:
|
||||
def register_tool(self, tool: Callable, toolkit_name: str) -> None:
|
||||
"""
|
||||
Register a tool to the catalog.
|
||||
"""
|
||||
|
|
|
|||
378
arcade/arcade/cli/launcher.py
Normal file
378
arcade/arcade/cli/launcher.py
Normal file
|
|
@ -0,0 +1,378 @@
|
|||
import io
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
console = Console(highlight=False)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def start_servers(
|
||||
host: str,
|
||||
port: int,
|
||||
engine_config: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
Start the actor and engine servers.
|
||||
|
||||
Args:
|
||||
host: Host for the actor server.
|
||||
port: Port for the actor server.
|
||||
engine_config: Path to the engine configuration file.
|
||||
"""
|
||||
# Validate host and port
|
||||
host = _validate_host(host)
|
||||
port = _validate_port(port)
|
||||
|
||||
# Ensure engine_config is provided and validated
|
||||
engine_config = _get_engine_config(engine_config)
|
||||
|
||||
# Prepare command-line arguments for the actor server and engine
|
||||
actor_cmd = _build_actor_command(host, port)
|
||||
engine_cmd = _build_engine_command(engine_config)
|
||||
|
||||
# Start and manage the processes
|
||||
_manage_processes(actor_cmd, engine_cmd)
|
||||
|
||||
|
||||
def _validate_host(host: str) -> str:
|
||||
"""
|
||||
Validates the host input.
|
||||
|
||||
Args:
|
||||
host: Host for the actor server.
|
||||
|
||||
Returns:
|
||||
The validated host as a string.
|
||||
|
||||
Raises:
|
||||
ValueError: If the host is invalid.
|
||||
"""
|
||||
try:
|
||||
# Validate IP address
|
||||
ipaddress.ip_address(host)
|
||||
except ValueError:
|
||||
# Optionally, validate hostname
|
||||
if not host.isalnum() and "-" not in host and "." not in host:
|
||||
console.print(f"❌ Invalid host: {host}", style="bold red")
|
||||
raise ValueError("Invalid host.")
|
||||
return host
|
||||
|
||||
|
||||
def _validate_port(port: int) -> int:
|
||||
"""
|
||||
Validates the port input.
|
||||
|
||||
Args:
|
||||
port: Port for the actor server.
|
||||
|
||||
Returns:
|
||||
The validated port as an integer.
|
||||
|
||||
Raises:
|
||||
ValueError: If the port is out of the valid range.
|
||||
"""
|
||||
if not (1 <= port <= 65535):
|
||||
console.print(f"❌ Invalid port: {port}", style="bold red")
|
||||
raise ValueError("Invalid port.")
|
||||
return port
|
||||
|
||||
|
||||
def _get_engine_config(engine_config: str | None) -> str:
|
||||
"""
|
||||
Determines and validates the engine config file path.
|
||||
|
||||
Args:
|
||||
engine_config: Optional path provided by the user.
|
||||
|
||||
Returns:
|
||||
The resolved engine config file path.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the config file is not found or invalid.
|
||||
"""
|
||||
if engine_config:
|
||||
engine_config_path = Path(os.path.expanduser(engine_config)).resolve()
|
||||
if not engine_config_path.is_file():
|
||||
console.print(
|
||||
f"❌ Engine config file not found at {engine_config_path}", style="bold red"
|
||||
)
|
||||
raise RuntimeError("Engine config file not found.")
|
||||
else:
|
||||
# Look for engine.yaml in the current directory
|
||||
engine_config_path = Path(os.getcwd()) / "engine.yaml"
|
||||
if not engine_config_path.is_file():
|
||||
console.print(
|
||||
"❌ Engine config file not specified and not found in current directory.",
|
||||
style="bold red",
|
||||
)
|
||||
raise RuntimeError("Engine config file not specified.")
|
||||
return str(engine_config_path)
|
||||
|
||||
|
||||
def _build_actor_command(host: str, port: int) -> list[str]:
|
||||
"""
|
||||
Builds the command to start the actor server.
|
||||
|
||||
Args:
|
||||
host: Host for the actor server.
|
||||
port: Port for the actor server.
|
||||
|
||||
Returns:
|
||||
The command as a list.
|
||||
"""
|
||||
# Expand full path to "arcade" executable
|
||||
arcade_bin = shutil.which("arcade")
|
||||
if not arcade_bin:
|
||||
console.print(
|
||||
"❌ Arcade binary not found, please install with `pip install arcade-ai`",
|
||||
style="bold red",
|
||||
)
|
||||
sys.exit(1)
|
||||
cmd = [
|
||||
arcade_bin,
|
||||
"dev",
|
||||
"--host",
|
||||
host,
|
||||
"--port",
|
||||
str(port),
|
||||
]
|
||||
return cmd
|
||||
|
||||
|
||||
def _build_engine_command(engine_config: str) -> list[str]:
|
||||
"""
|
||||
Builds the command to start the engine.
|
||||
|
||||
Args:
|
||||
engine_config: Path to the engine configuration file.
|
||||
|
||||
Returns:
|
||||
The command as a list.
|
||||
"""
|
||||
engine_bin = shutil.which("engine")
|
||||
if not engine_bin:
|
||||
console.print(
|
||||
"❌ Engine binary not found, refer to the installation guide at "
|
||||
"https://docs.arcade-ai.com/docs/home/deployment for how to install the engine",
|
||||
style="bold red",
|
||||
)
|
||||
sys.exit(1)
|
||||
cmd = [
|
||||
engine_bin,
|
||||
"dev",
|
||||
"-c",
|
||||
engine_config,
|
||||
]
|
||||
return cmd
|
||||
|
||||
|
||||
def _manage_processes(actor_cmd: list[str], engine_cmd: list[str]) -> None:
|
||||
"""
|
||||
Manages the lifecycle of the actor and engine processes.
|
||||
|
||||
Args:
|
||||
actor_cmd: The command to start the actor server.
|
||||
engine_cmd: The command to start the engine.
|
||||
"""
|
||||
actor_process: subprocess.Popen | None = None
|
||||
engine_process: subprocess.Popen | None = None
|
||||
|
||||
def terminate_processes(exit_program: bool = False) -> None:
|
||||
console.print("Terminating child processes...", style="bold yellow")
|
||||
_terminate_process(actor_process)
|
||||
_terminate_process(engine_process)
|
||||
if exit_program:
|
||||
sys.exit(0)
|
||||
|
||||
_setup_signal_handlers(terminate_processes)
|
||||
|
||||
retry_count = 0
|
||||
max_retries = 3 # Define the maximum number of retries
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
# Start the actor server
|
||||
console.print("Starting actor server...", style="bold green")
|
||||
actor_process = _start_process("Actor", actor_cmd)
|
||||
|
||||
# Wait a bit to ensure actor is up
|
||||
time.sleep(2)
|
||||
|
||||
# Start the engine
|
||||
console.print("Starting engine...", style="bold green")
|
||||
engine_process = _start_process("Engine", engine_cmd)
|
||||
|
||||
# Monitor processes
|
||||
_monitor_processes(actor_process, engine_process)
|
||||
|
||||
# If we reach here, one of the processes has exited
|
||||
retry_count += 1
|
||||
console.print(
|
||||
f"Processes exited. Retry {retry_count} of {max_retries}.", style="bold yellow"
|
||||
)
|
||||
|
||||
if retry_count > max_retries:
|
||||
console.print(f"❌ Exiting after {retry_count - 1} retries", style="bold red")
|
||||
terminate_processes(exit_program=True)
|
||||
break # Exit the loop
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"❌ Exception occurred: {e}", style="bold red")
|
||||
terminate_processes()
|
||||
retry_count += 1
|
||||
if retry_count > max_retries:
|
||||
console.print(
|
||||
f"❌ Exiting after {retry_count - 1} retries due to exceptions",
|
||||
style="bold red",
|
||||
)
|
||||
sys.exit(1)
|
||||
break # Not strictly necessary, but good practice
|
||||
|
||||
console.print("Exiting...", style="bold red")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _start_process(name: str, cmd: list[str]) -> subprocess.Popen:
|
||||
"""
|
||||
Starts a subprocess and begins streaming its output.
|
||||
|
||||
Args:
|
||||
name: Name of the process.
|
||||
cmd: Command to execute.
|
||||
|
||||
Returns:
|
||||
The subprocess.Popen object.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the process fails to start.
|
||||
"""
|
||||
try:
|
||||
process = subprocess.Popen( # noqa: S603, RUF100
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
universal_newlines=True,
|
||||
bufsize=1,
|
||||
shell=False,
|
||||
)
|
||||
_stream_output(process, name)
|
||||
return process # noqa: TRY300
|
||||
except Exception as e:
|
||||
console.print(f"❌ Failed to start {name}: {e}", style="bold red")
|
||||
raise RuntimeError(f"Failed to start {name}")
|
||||
|
||||
|
||||
def _stream_output(process: subprocess.Popen, name: str) -> None:
|
||||
"""
|
||||
Streams the output from a subprocess to the console.
|
||||
|
||||
Args:
|
||||
process: The subprocess.Popen object.
|
||||
name: Name of the process.
|
||||
"""
|
||||
stdout_style = "green" if name == "Actor" else "#87CEFA"
|
||||
|
||||
def stream(pipe: io.TextIOWrapper | None, style: str) -> None:
|
||||
if pipe is None:
|
||||
return
|
||||
with pipe:
|
||||
for line in iter(pipe.readline, ""):
|
||||
console.print(f"[{style}]{name}>[/{style}] {line.rstrip()}")
|
||||
|
||||
threading.Thread(target=stream, args=(process.stdout, stdout_style), daemon=True).start()
|
||||
threading.Thread(target=stream, args=(process.stderr, "red"), daemon=True).start()
|
||||
|
||||
|
||||
def _monitor_processes(actor_process: subprocess.Popen, engine_process: subprocess.Popen) -> None:
|
||||
"""
|
||||
Monitors the actor and engine processes, restarts them if they exit.
|
||||
|
||||
Args:
|
||||
actor_process: The actor subprocess.
|
||||
engine_process: The engine subprocess.
|
||||
"""
|
||||
while True:
|
||||
actor_status = actor_process.poll()
|
||||
engine_status = engine_process.poll()
|
||||
|
||||
if actor_status is not None or engine_status is not None:
|
||||
if actor_status is not None:
|
||||
console.print(
|
||||
f"Actor process exited with code {actor_status}. Restarting both processes...",
|
||||
style="bold red",
|
||||
)
|
||||
if engine_status is not None:
|
||||
console.print(
|
||||
f"Engine process exited with code {engine_status}. Restarting both processes...",
|
||||
style="bold red",
|
||||
)
|
||||
_terminate_process(actor_process)
|
||||
_terminate_process(engine_process)
|
||||
time.sleep(1)
|
||||
break # Exit to restart both processes
|
||||
else:
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def _terminate_process(process: subprocess.Popen | None) -> None:
|
||||
"""
|
||||
Terminates a subprocess if it's running.
|
||||
|
||||
Args:
|
||||
process: The subprocess.Popen object.
|
||||
"""
|
||||
if process and process.poll() is None:
|
||||
process.terminate()
|
||||
try:
|
||||
process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
|
||||
|
||||
def _setup_signal_handlers(terminate_processes: Callable[[bool], None]) -> None:
|
||||
"""
|
||||
Setup signal handlers to handle process termination signals.
|
||||
|
||||
Args:
|
||||
terminate_processes: Function to call to terminate child processes.
|
||||
"""
|
||||
signals_to_handle = ["SIGINT", "SIGTERM", "SIGQUIT", "SIGHUP"]
|
||||
|
||||
for sig_name in signals_to_handle:
|
||||
sig = getattr(signal, sig_name, None)
|
||||
if sig is None:
|
||||
continue # Signal not available on this platform
|
||||
try:
|
||||
# Use a lambda to pass the terminate_processes function
|
||||
signal.signal(
|
||||
sig,
|
||||
lambda signum, frame: _handle_signal(signum, terminate_processes),
|
||||
)
|
||||
except (ValueError, RuntimeError):
|
||||
# Signal handling not allowed in this thread or invalid signal
|
||||
console.print(f"Warning: Cannot set handler for {sig_name}", style="bold yellow")
|
||||
continue
|
||||
|
||||
|
||||
def _handle_signal(signum: int, terminate_processes: Callable[[bool], None]) -> None:
|
||||
"""
|
||||
Handle received signal and terminate child processes.
|
||||
|
||||
Args:
|
||||
signum: The signal number received.
|
||||
terminate_processes: Function to call to terminate child processes.
|
||||
"""
|
||||
signal_name = signal.Signals(signum).name
|
||||
console.print(f"Received {signal_name}. Shutting down...", style="bold yellow")
|
||||
terminate_processes(exit_program=True) # type: ignore[call-arg]
|
||||
|
|
@ -15,6 +15,7 @@ from rich.table import Table
|
|||
from rich.text import Text
|
||||
|
||||
from arcade.cli.authn import LocalAuthCallbackServer, check_existing_login
|
||||
from arcade.cli.launcher import start_servers
|
||||
from arcade.cli.utils import (
|
||||
OrderCommands,
|
||||
apply_config_overrides,
|
||||
|
|
@ -28,14 +29,41 @@ from arcade.cli.utils import (
|
|||
)
|
||||
from arcade.client import Arcade
|
||||
from arcade.client.errors import EngineNotHealthyError, EngineOfflineError
|
||||
from arcade.core.config_model import Config
|
||||
|
||||
cli = typer.Typer(
|
||||
cls=OrderCommands,
|
||||
add_completion=False,
|
||||
no_args_is_help=True,
|
||||
pretty_exceptions_enable=False,
|
||||
pretty_exceptions_show_locals=False,
|
||||
pretty_exceptions_short=True,
|
||||
)
|
||||
console = Console()
|
||||
|
||||
|
||||
@cli.command(help="Log in to Arcade Cloud")
|
||||
def _get_config_with_overrides(
|
||||
force_tls: bool,
|
||||
force_no_tls: bool,
|
||||
host_input: str | None = None,
|
||||
port_input: int | None = None,
|
||||
) -> Config:
|
||||
"""
|
||||
Get the config with CLI-specific optional overrides applied.
|
||||
"""
|
||||
config = validate_and_get_config()
|
||||
|
||||
if not force_tls and not force_no_tls:
|
||||
tls_input = None
|
||||
elif force_no_tls:
|
||||
tls_input = False
|
||||
else:
|
||||
tls_input = True
|
||||
apply_config_overrides(config, host_input, port_input, tls_input)
|
||||
return config
|
||||
|
||||
|
||||
@cli.command(help="Log in to Arcade Cloud", rich_help_panel="User")
|
||||
def login(
|
||||
host: str = typer.Option(
|
||||
"cloud.arcade-ai.com",
|
||||
|
|
@ -74,7 +102,7 @@ def login(
|
|||
server_thread.join() # Ensure the server thread completes and cleans up
|
||||
|
||||
|
||||
@cli.command(help="Log out of Arcade Cloud")
|
||||
@cli.command(help="Log out of Arcade Cloud", rich_help_panel="User")
|
||||
def logout() -> None:
|
||||
"""
|
||||
Logs the user out of Arcade Cloud.
|
||||
|
|
@ -89,7 +117,7 @@ def logout() -> None:
|
|||
console.print("You're not logged in.", style="bold red")
|
||||
|
||||
|
||||
@cli.command(help="Create a new toolkit package directory")
|
||||
@cli.command(help="Create a new toolkit package directory", rich_help_panel="Tool Development")
|
||||
def new(
|
||||
directory: str = typer.Option(os.getcwd(), "--dir", help="tools directory path"),
|
||||
) -> None:
|
||||
|
|
@ -105,7 +133,10 @@ def new(
|
|||
console.print(error_message, style="bold red")
|
||||
|
||||
|
||||
@cli.command(help="Show the available tools in an actor or toolkit directory")
|
||||
@cli.command(
|
||||
help="Show the installed toolkits",
|
||||
rich_help_panel="Tool Development",
|
||||
)
|
||||
def show(
|
||||
toolkit: Optional[str] = typer.Option(
|
||||
None, "-t", "--toolkit", help="The toolkit to show the tools of"
|
||||
|
|
@ -139,12 +170,13 @@ def show(
|
|||
console.print(error_message, style="bold red")
|
||||
|
||||
|
||||
@cli.command(help="Chat with a language model")
|
||||
@cli.command(help="Start Arcade Chat in the terminal", rich_help_panel="Launch")
|
||||
def chat(
|
||||
model: str = typer.Option("gpt-4o", "-m", help="The model to use for prediction."),
|
||||
stream: bool = typer.Option(
|
||||
False, "-s", "--stream", is_flag=True, help="Stream the tool output."
|
||||
),
|
||||
debug: bool = typer.Option(False, "--debug", "-d", help="Show debug information"),
|
||||
host: str = typer.Option(
|
||||
None,
|
||||
"-h",
|
||||
|
|
@ -167,20 +199,11 @@ def chat(
|
|||
"--no-tls",
|
||||
help="Whether to disable TLS for the connection to the Arcade Engine.",
|
||||
),
|
||||
debug: bool = typer.Option(False, "--debug", "-d", help="Show debug information"),
|
||||
) -> None:
|
||||
"""
|
||||
Chat with a language model.
|
||||
"""
|
||||
config = validate_and_get_config()
|
||||
|
||||
if not force_tls and not force_no_tls:
|
||||
tls_input = None
|
||||
elif force_no_tls:
|
||||
tls_input = False
|
||||
else:
|
||||
tls_input = True
|
||||
apply_config_overrides(config, host, port, tls_input)
|
||||
config = _get_config_with_overrides(force_tls, force_no_tls, host, port)
|
||||
|
||||
client = Arcade(api_key=config.api.key, base_url=config.engine_url)
|
||||
user_email = config.user.email if config.user else None
|
||||
|
|
@ -276,7 +299,7 @@ def chat(
|
|||
raise typer.Exit()
|
||||
|
||||
|
||||
@cli.command(help="Start an Actor server with specified configurations.")
|
||||
@cli.command(help="Start a local Arcade Actor server", rich_help_panel="Launch")
|
||||
def dev(
|
||||
host: str = typer.Option(
|
||||
"127.0.0.1", help="Host for the app, from settings by default.", show_default=True
|
||||
|
|
@ -300,7 +323,6 @@ def dev(
|
|||
try:
|
||||
serve_default_actor(host, port, disable_auth)
|
||||
except KeyboardInterrupt:
|
||||
console.print("actor stopped by user.", style="bold red")
|
||||
typer.Exit()
|
||||
except Exception as e:
|
||||
error_message = f"❌ Failed to start Arcade Actor: {escape(str(e))}"
|
||||
|
|
@ -308,7 +330,7 @@ def dev(
|
|||
raise typer.Exit(code=1)
|
||||
|
||||
|
||||
@cli.command(help="Show/edit configuration details of the Arcade Engine")
|
||||
@cli.command(help="Show/edit the local Arcade configuration", rich_help_panel="User")
|
||||
def config(
|
||||
action: str = typer.Argument("show", help="The action to take (show/edit)"),
|
||||
key: str = typer.Option(
|
||||
|
|
@ -396,7 +418,7 @@ def display_config_as_table(config) -> None: # type: ignore[no-untyped-def]
|
|||
console.print(table)
|
||||
|
||||
|
||||
@cli.command(help="Run evaluation suites in a directory")
|
||||
@cli.command(help="Run tool calling evaluations", rich_help_panel="Tool Development")
|
||||
def evals(
|
||||
directory: str = typer.Argument(".", help="Directory containing evaluation files"),
|
||||
show_details: bool = typer.Option(False, "--details", "-d", help="Show detailed results"),
|
||||
|
|
@ -409,11 +431,35 @@ def evals(
|
|||
models: str = typer.Option(
|
||||
"gpt-4o", "--models", "-m", help="The models to use for evaluation (default: gpt-4o)"
|
||||
),
|
||||
host: str = typer.Option(
|
||||
None,
|
||||
"-h",
|
||||
"--host",
|
||||
help="The Arcade Engine address to send chat requests to.",
|
||||
),
|
||||
port: int = typer.Option(
|
||||
None,
|
||||
"-p",
|
||||
"--port",
|
||||
help="The port of the Arcade Engine.",
|
||||
),
|
||||
force_tls: bool = typer.Option(
|
||||
False,
|
||||
"--tls",
|
||||
help="Whether to force TLS for the connection to the Arcade Engine. If not specified, the connection will use TLS if the engine URL uses a 'https' scheme.",
|
||||
),
|
||||
force_no_tls: bool = typer.Option(
|
||||
False,
|
||||
"--no-tls",
|
||||
help="Whether to disable TLS for the connection to the Arcade Engine.",
|
||||
),
|
||||
) -> None:
|
||||
"""
|
||||
Find all files starting with 'eval_' in the given directory,
|
||||
execute any functions decorated with @tool_eval, and display the results.
|
||||
"""
|
||||
config = _get_config_with_overrides(force_tls, force_no_tls, host, port)
|
||||
|
||||
models = models.split(",") # type: ignore[assignment]
|
||||
eval_files = [f for f in os.listdir(directory) if f.startswith("eval_") and f.endswith(".py")]
|
||||
|
||||
|
|
@ -421,6 +467,18 @@ def evals(
|
|||
console.print("No evaluation files found.", style="bold yellow")
|
||||
return
|
||||
|
||||
if show_details:
|
||||
console.print(
|
||||
Text.assemble(
|
||||
("\nRunning evaluations against Arcade Engine at ", "bold"),
|
||||
(config.engine_url, "bold blue"),
|
||||
)
|
||||
)
|
||||
|
||||
# Try to hit /health endpoint on engine and warn if it is down
|
||||
client = Arcade(api_key=config.api.key, base_url=config.engine_url)
|
||||
log_engine_health(client)
|
||||
|
||||
for file in eval_files:
|
||||
file_path = os.path.join(directory, file)
|
||||
module_name = file[:-3] # Remove .py extension
|
||||
|
|
@ -432,17 +490,47 @@ def evals(
|
|||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module) # type: ignore[union-attr]
|
||||
|
||||
eval_functions = [
|
||||
eval_suites = [
|
||||
obj
|
||||
for name, obj in module.__dict__.items()
|
||||
if callable(obj) and hasattr(obj, "__tool_eval__")
|
||||
]
|
||||
|
||||
if not eval_functions:
|
||||
if not eval_suites:
|
||||
console.print(f"No @tool_eval functions found in {file}", style="bold yellow")
|
||||
continue
|
||||
|
||||
for func in eval_functions:
|
||||
console.print(f"\nRunning evaluation from {file}: {func.__name__}", style="bold blue")
|
||||
results = func(models=models, max_concurrency=max_concurrent)
|
||||
if show_details:
|
||||
suite_label = "suite" if len(eval_suites) == 1 else "suites"
|
||||
console.print(f"\nFound {len(eval_suites)} {suite_label} in {file}", style="bold")
|
||||
|
||||
for suite_func in eval_suites:
|
||||
console.print(
|
||||
Text.assemble(
|
||||
("\nRunning evaluations in ", "bold"),
|
||||
(suite_func.__name__, "bold blue"),
|
||||
)
|
||||
)
|
||||
results = suite_func(config=config, models=models, max_concurrency=max_concurrent)
|
||||
display_eval_results(results, show_details=show_details)
|
||||
|
||||
|
||||
@cli.command(help="Start an Arcade Cluster instance", rich_help_panel="Launch")
|
||||
def up(
|
||||
host: str = typer.Option("127.0.0.1", help="Host for the actor server.", show_default=True),
|
||||
port: int = typer.Option(
|
||||
8002, "-p", "--port", help="Port for the actor server.", show_default=True
|
||||
),
|
||||
engine_config: str = typer.Option(
|
||||
None, "-c", "--config", help="Path to the engine configuration file."
|
||||
),
|
||||
) -> None:
|
||||
"""
|
||||
Start both the actor and engine servers.
|
||||
"""
|
||||
try:
|
||||
start_servers(host, port, engine_config)
|
||||
except Exception as e:
|
||||
error_message = f"❌ Failed to start servers: {escape(str(e))}"
|
||||
console.print(error_message, style="bold red")
|
||||
raise typer.Exit(code=1)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from rich.console import Console
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
import fastapi
|
||||
|
|
@ -18,29 +22,73 @@ except ImportError:
|
|||
from arcade.actor.fastapi.actor import FastAPIActor
|
||||
from arcade.core.toolkit import Toolkit
|
||||
|
||||
DEVELOPMENT_SECRET = "dev" # noqa: S105
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
console = Console()
|
||||
class InterceptHandler(logging.Handler):
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
# Get corresponding Loguru level if it exists
|
||||
try:
|
||||
level = logger.level(record.levelname).name
|
||||
except ValueError:
|
||||
level = record.levelno # type: ignore[assignment]
|
||||
|
||||
# Find caller from where originated the logged message
|
||||
frame, depth = sys._getframe(6), 6
|
||||
while frame and frame.f_code.co_filename == logging.__file__:
|
||||
frame = frame.f_back # type: ignore[assignment]
|
||||
depth += 1
|
||||
|
||||
logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
|
||||
|
||||
|
||||
def setup_logging(log_level: int = logging.INFO) -> None:
|
||||
# Intercept everything at the root logger
|
||||
logging.root.handlers = [InterceptHandler()]
|
||||
logging.root.setLevel(log_level)
|
||||
|
||||
# Remove every other logger's handlers
|
||||
# and propagate to root logger
|
||||
for name in logging.root.manager.loggerDict:
|
||||
logging.getLogger(name).handlers = []
|
||||
logging.getLogger(name).propagate = True
|
||||
|
||||
# Configure loguru with custom format, no colors
|
||||
logger.configure(
|
||||
handlers=[
|
||||
{
|
||||
"sink": sys.stdout,
|
||||
"serialize": False,
|
||||
"level": log_level,
|
||||
"format": "{time:MM-DD HH:mm:ss} | {level: <8} | {message}"
|
||||
+ (" {name}:{function}:{line}" if log_level <= logging.DEBUG else "")
|
||||
+ ("{exception}\n" if "{exception}" in "{message}" else ""),
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: fastapi.FastAPI): # type: ignore[no-untyped-def]
|
||||
try:
|
||||
yield
|
||||
except asyncio.CancelledError:
|
||||
# This is necessary to prevent an unhandled error
|
||||
# when the user presses Ctrl+C
|
||||
logger.debug("Lifespan cancelled.")
|
||||
|
||||
|
||||
def serve_default_actor(
|
||||
host: str = "127.0.0.1", port: int = 8000, disable_auth: bool = False
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 8002,
|
||||
disable_auth: bool = False,
|
||||
workers: int = 1,
|
||||
timeout_keep_alive: int = 5,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Get an instance of a FastAPI server with the Arcade Actor.
|
||||
"""
|
||||
# Use Uvicorn's default log config for Arcade logging,
|
||||
# to ensure a nice consistent style for all logs.
|
||||
logging_config = uvicorn.config.LOGGING_CONFIG
|
||||
logging_config["loggers"]["arcade"] = {
|
||||
"handlers": ["default"],
|
||||
"level": "INFO",
|
||||
"propagate": False,
|
||||
}
|
||||
|
||||
# TODO: Pass in a logging config from the CLI, to set the log level.
|
||||
logging.config.dictConfig(logging_config)
|
||||
# Setup unified logging
|
||||
setup_logging()
|
||||
|
||||
toolkits = Toolkit.find_all_arcade_toolkits()
|
||||
if not toolkits:
|
||||
|
|
@ -56,12 +104,13 @@ def serve_default_actor(
|
|||
logger.warning(
|
||||
"Warning: ARCADE_ACTOR_SECRET environment variable is not set. Using 'dev' as the actor secret.",
|
||||
)
|
||||
actor_secret = DEVELOPMENT_SECRET
|
||||
actor_secret = actor_secret or "dev"
|
||||
|
||||
app = fastapi.FastAPI(
|
||||
title="Arcade AI Actor",
|
||||
description="Arcade AI default Actor implementation using FastAPI.",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan, # Use custom lifespan to catch errors, notably KeyboardInterrupt (Ctrl+C)
|
||||
)
|
||||
actor = FastAPIActor(app, secret=actor_secret, disable_auth=disable_auth)
|
||||
for toolkit in toolkits:
|
||||
|
|
@ -69,9 +118,27 @@ def serve_default_actor(
|
|||
|
||||
logger.info("Starting FastAPI server...")
|
||||
|
||||
uvicorn.run(
|
||||
class CustomUvicornServer(uvicorn.Server):
|
||||
def install_signal_handlers(self) -> None:
|
||||
pass # Disable Uvicorn's default signal handlers
|
||||
|
||||
config = uvicorn.Config(
|
||||
app=app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_config=logging_config,
|
||||
workers=workers,
|
||||
timeout_keep_alive=timeout_keep_alive,
|
||||
log_config=None,
|
||||
**kwargs,
|
||||
)
|
||||
server = CustomUvicornServer(config=config)
|
||||
|
||||
async def serve() -> None:
|
||||
await server.serve()
|
||||
|
||||
try:
|
||||
asyncio.run(serve())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Server stopped by user.")
|
||||
finally:
|
||||
logger.debug("Server shutdown complete.")
|
||||
|
|
|
|||
|
|
@ -249,18 +249,21 @@ def _format_evaluation(evaluation: "EvaluationResult") -> str:
|
|||
A formatted string representation of the evaluation details.
|
||||
"""
|
||||
result_lines = []
|
||||
for critic_result in evaluation.results:
|
||||
match_color = "green" if critic_result["match"] else "red"
|
||||
field = critic_result["field"]
|
||||
score = critic_result["score"]
|
||||
weight = critic_result["weight"]
|
||||
expected = critic_result["expected"]
|
||||
actual = critic_result["actual"]
|
||||
result_lines.append(
|
||||
f"[bold]{field}:[/bold] "
|
||||
f"[{match_color}]Match: {critic_result['match']}, "
|
||||
f"Score: {score:.2f}/{weight:.2f}[/{match_color}]"
|
||||
f"\n Expected: {expected}"
|
||||
f"\n Actual: {actual}"
|
||||
)
|
||||
if evaluation.failure_reason:
|
||||
result_lines.append(f"[bold red]Failure Reason:[/bold red] {evaluation.failure_reason}")
|
||||
else:
|
||||
for critic_result in evaluation.results:
|
||||
match_color = "green" if critic_result["match"] else "red"
|
||||
field = critic_result["field"]
|
||||
score = critic_result["score"]
|
||||
weight = critic_result["weight"]
|
||||
expected = critic_result["expected"]
|
||||
actual = critic_result["actual"]
|
||||
result_lines.append(
|
||||
f"[bold]{field}:[/bold] "
|
||||
f"[{match_color}]Match: {critic_result['match']}, "
|
||||
f"Score: {score:.2f}/{weight:.2f}[/{match_color}]"
|
||||
f"\n Expected: {expected}"
|
||||
f"\n Actual: {actual}"
|
||||
)
|
||||
return "\n".join(result_lines)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import os
|
||||
from typing import Any, Generic, TypeVar
|
||||
from urllib.parse import urljoin
|
||||
|
||||
|
|
@ -13,19 +12,20 @@ from arcade.client.errors import (
|
|||
RateLimitError,
|
||||
UnauthorizedError,
|
||||
)
|
||||
from arcade.client.schema import OPENAI_API_VERSION
|
||||
|
||||
T = TypeVar("T")
|
||||
ResponseT = TypeVar("ResponseT")
|
||||
|
||||
API_VERSION = "v1"
|
||||
BASE_URL = "http://localhost:9099"
|
||||
|
||||
|
||||
class BaseResource(Generic[T]):
|
||||
"""Base class for all resources."""
|
||||
|
||||
def __init__(self, client: T):
|
||||
_path: str
|
||||
|
||||
def __init__(self, client: T) -> None:
|
||||
self._client = client
|
||||
self._resource_path = self._client._base_url + self._path # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class BaseArcadeClient:
|
||||
|
|
@ -33,7 +33,7 @@ class BaseArcadeClient:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str = BASE_URL,
|
||||
base_url: str | None = None,
|
||||
api_key: str | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float | Timeout = 10.0,
|
||||
|
|
@ -49,8 +49,14 @@ class BaseArcadeClient:
|
|||
timeout: Request timeout in seconds.
|
||||
retries: Number of retries for failed requests.
|
||||
"""
|
||||
if base_url is None or api_key is None:
|
||||
from arcade.core.config import config
|
||||
|
||||
base_url = base_url or config.engine_url
|
||||
api_key = api_key or config.api.key
|
||||
self._base_url = base_url
|
||||
self._api_key = api_key or os.environ.get("ARCADE_API_KEY")
|
||||
self._api_key = api_key
|
||||
|
||||
self._headers = headers or {}
|
||||
self._headers.setdefault("Authorization", f"Bearer {self._api_key}")
|
||||
self._headers.setdefault("Content-Type", "application/json")
|
||||
|
|
@ -65,8 +71,8 @@ class BaseArcadeClient:
|
|||
|
||||
def _chat_url(self, base_url: str) -> str:
|
||||
chat_url = str(base_url)
|
||||
if not base_url.endswith(API_VERSION):
|
||||
chat_url = f"{base_url}/{API_VERSION}"
|
||||
if not base_url.endswith(OPENAI_API_VERSION):
|
||||
chat_url = f"{base_url}/{OPENAI_API_VERSION}"
|
||||
return chat_url
|
||||
|
||||
def _handle_http_error(self, e: httpx.HTTPStatusError) -> None:
|
||||
|
|
@ -80,7 +86,10 @@ class BaseArcadeClient:
|
|||
}
|
||||
status_code = e.response.status_code
|
||||
error_class = error_map.get(status_code, InternalServerError)
|
||||
raise error_class(str(e), response=e.response)
|
||||
msg = e.response.json()
|
||||
if isinstance(msg, dict) and "error" in msg:
|
||||
raise error_class(msg["error"], response=e.response) from None
|
||||
raise error_class(msg, response=e.response) from None
|
||||
|
||||
|
||||
class SyncArcadeClient(BaseArcadeClient):
|
||||
|
|
@ -94,7 +103,7 @@ class SyncArcadeClient(BaseArcadeClient):
|
|||
timeout=self._timeout,
|
||||
)
|
||||
|
||||
def _request(self, method: str, path: str, **kwargs: Any) -> httpx.Response:
|
||||
def _request(self, method: str, path: str, **kwargs: Any) -> httpx.Response: # type: ignore[return]
|
||||
"""
|
||||
Make a synchronous HTTP request.
|
||||
"""
|
||||
|
|
@ -104,10 +113,9 @@ class SyncArcadeClient(BaseArcadeClient):
|
|||
response = self._client.request(method, url, **kwargs)
|
||||
response.raise_for_status()
|
||||
return response # noqa: TRY300
|
||||
except httpx.HTTPStatusError:
|
||||
except httpx.HTTPStatusError as e:
|
||||
if attempt == self._retries - 1:
|
||||
raise
|
||||
raise RuntimeError("This should never be reached")
|
||||
self._handle_http_error(e)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the client session."""
|
||||
|
|
@ -139,7 +147,7 @@ class AsyncArcadeClient(BaseArcadeClient):
|
|||
)
|
||||
return self._client
|
||||
|
||||
async def _request(self, method: str, path: str, **kwargs: Any) -> httpx.Response:
|
||||
async def _request(self, method: str, path: str, **kwargs: Any) -> httpx.Response: # type: ignore[return]
|
||||
"""
|
||||
Make an asynchronous HTTP request.
|
||||
"""
|
||||
|
|
@ -150,10 +158,9 @@ class AsyncArcadeClient(BaseArcadeClient):
|
|||
response = await client.request(method, url, **kwargs)
|
||||
response.raise_for_status()
|
||||
return response # noqa: TRY300
|
||||
except httpx.HTTPStatusError:
|
||||
except httpx.HTTPStatusError as e:
|
||||
if attempt == self._retries - 1:
|
||||
raise
|
||||
raise RuntimeError("This should never be reached")
|
||||
self._handle_http_error(e)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the client session."""
|
||||
|
|
|
|||
|
|
@ -1,11 +1,9 @@
|
|||
from typing import Any, TypeVar, Union
|
||||
|
||||
import httpx
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from openai.resources.chat import AsyncChat, Chat
|
||||
|
||||
from arcade.client.base import (
|
||||
API_VERSION,
|
||||
AsyncArcadeClient,
|
||||
BaseResource,
|
||||
SyncArcadeClient,
|
||||
|
|
@ -27,7 +25,7 @@ ClientT = TypeVar("ClientT", SyncArcadeClient, AsyncArcadeClient)
|
|||
class AuthResource(BaseResource[ClientT]):
|
||||
"""Authentication resource."""
|
||||
|
||||
_base_path = f"/{API_VERSION}/auth"
|
||||
_path = "/auth"
|
||||
|
||||
def authorize(
|
||||
self,
|
||||
|
|
@ -59,7 +57,7 @@ class AuthResource(BaseResource[ClientT]):
|
|||
|
||||
data = self._client._execute_request( # type: ignore[attr-defined]
|
||||
"POST",
|
||||
f"{self._base_path}/authorize",
|
||||
f"{self._resource_path}/authorize",
|
||||
json=body,
|
||||
)
|
||||
return AuthResponse(**data)
|
||||
|
|
@ -85,7 +83,7 @@ class AuthResource(BaseResource[ClientT]):
|
|||
|
||||
data = self._client._execute_request( # type: ignore[attr-defined]
|
||||
"GET",
|
||||
f"{self._base_path}/status",
|
||||
f"{self._resource_path}/status",
|
||||
params={"authorizationId": auth_id, "scopes": " ".join(scopes) if scopes else None},
|
||||
)
|
||||
return AuthResponse(**data)
|
||||
|
|
@ -94,7 +92,7 @@ class AuthResource(BaseResource[ClientT]):
|
|||
class ToolResource(BaseResource[ClientT]):
|
||||
"""Tool resource."""
|
||||
|
||||
_base_path = f"/{API_VERSION}/tool"
|
||||
_path = "/tools"
|
||||
|
||||
def run(
|
||||
self,
|
||||
|
|
@ -119,7 +117,7 @@ class ToolResource(BaseResource[ClientT]):
|
|||
"inputs": inputs,
|
||||
}
|
||||
data = self._client._execute_request( # type: ignore[attr-defined]
|
||||
"POST", f"{self._base_path}/execute", json=request_data
|
||||
"POST", f"{self._resource_path}/execute", json=request_data
|
||||
)
|
||||
return ExecuteToolResponse(**data)
|
||||
|
||||
|
|
@ -129,19 +127,21 @@ class ToolResource(BaseResource[ClientT]):
|
|||
"""
|
||||
data = self._client._execute_request( # type: ignore[attr-defined]
|
||||
"GET",
|
||||
f"{self._base_path}/definition",
|
||||
f"{self._resource_path}/definition",
|
||||
params={"directorId": director_id, "toolId": tool_id},
|
||||
)
|
||||
return ToolDefinition(**data)
|
||||
|
||||
def authorize(self, tool_name: str, user_id: str) -> AuthResponse:
|
||||
def authorize(
|
||||
self, tool_name: str, user_id: str, tool_version: str | None = None
|
||||
) -> AuthResponse:
|
||||
"""
|
||||
Get the authorization status for a tool.
|
||||
"""
|
||||
data = self._client._execute_request( # type: ignore[attr-defined]
|
||||
"POST",
|
||||
f"{self._base_path}/authorize",
|
||||
json={"tool_name": tool_name, "user_id": user_id},
|
||||
f"{self._resource_path}/authorize",
|
||||
json={"tool_name": tool_name, "tool_version": tool_version, "user_id": user_id},
|
||||
)
|
||||
return AuthResponse(**data)
|
||||
|
||||
|
|
@ -149,6 +149,8 @@ class ToolResource(BaseResource[ClientT]):
|
|||
class HealthResource(BaseResource[ClientT]):
|
||||
"""Health check resource."""
|
||||
|
||||
_path = "/health"
|
||||
|
||||
def check(self) -> None:
|
||||
"""
|
||||
Check the health of the Arcade Engine.
|
||||
|
|
@ -158,7 +160,7 @@ class HealthResource(BaseResource[ClientT]):
|
|||
try:
|
||||
data = self._client._execute_request( # type: ignore[attr-defined]
|
||||
"GET",
|
||||
f"/{API_VERSION}/health",
|
||||
f"{self._resource_path}",
|
||||
timeout=5,
|
||||
)
|
||||
|
||||
|
|
@ -184,7 +186,7 @@ class HealthResource(BaseResource[ClientT]):
|
|||
class AsyncAuthResource(BaseResource[AsyncArcadeClient]):
|
||||
"""Asynchronous Authentication resource."""
|
||||
|
||||
_base_path = f"/{API_VERSION}/auth"
|
||||
_path = "/auth"
|
||||
|
||||
async def authorize(
|
||||
self,
|
||||
|
|
@ -210,7 +212,7 @@ class AsyncAuthResource(BaseResource[AsyncArcadeClient]):
|
|||
|
||||
data = await self._client._execute_request( # type: ignore[attr-defined]
|
||||
"POST",
|
||||
f"{self._base_path}/authorize",
|
||||
f"{self._resource_path}/authorize",
|
||||
json=body,
|
||||
)
|
||||
return AuthResponse(**data)
|
||||
|
|
@ -236,7 +238,7 @@ class AsyncAuthResource(BaseResource[AsyncArcadeClient]):
|
|||
|
||||
data = await self._client._execute_request( # type: ignore[attr-defined]
|
||||
"GET",
|
||||
f"{self._base_path}/status",
|
||||
f"{self._resource_path}/status",
|
||||
params={"authorizationId": auth_id, "scopes": " ".join(scopes) if scopes else None},
|
||||
)
|
||||
return AuthResponse(**data)
|
||||
|
|
@ -245,7 +247,7 @@ class AsyncAuthResource(BaseResource[AsyncArcadeClient]):
|
|||
class AsyncToolResource(BaseResource[AsyncArcadeClient]):
|
||||
"""Asynchronous Tool resource."""
|
||||
|
||||
_base_path = f"/{API_VERSION}/tools"
|
||||
_path = "/tools"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
|
|
@ -264,7 +266,7 @@ class AsyncToolResource(BaseResource[AsyncArcadeClient]):
|
|||
"inputs": inputs,
|
||||
}
|
||||
data = await self._client._execute_request( # type: ignore[attr-defined]
|
||||
"POST", f"{self._base_path}/execute", json=request_data
|
||||
"POST", f"{self._resource_path}/execute", json=request_data
|
||||
)
|
||||
return ExecuteToolResponse(**data)
|
||||
|
||||
|
|
@ -274,19 +276,21 @@ class AsyncToolResource(BaseResource[AsyncArcadeClient]):
|
|||
"""
|
||||
data = await self._client._execute_request( # type: ignore[attr-defined]
|
||||
"GET",
|
||||
f"{self._base_path}/definition",
|
||||
f"{self._resource_path}/definition",
|
||||
params={"directorId": director_id, "toolId": tool_id},
|
||||
)
|
||||
return ToolDefinition(**data)
|
||||
|
||||
async def authorize(self, tool_name: str, user_id: str) -> AuthResponse:
|
||||
async def authorize(
|
||||
self, tool_name: str, user_id: str, tool_version: str | None = None
|
||||
) -> AuthResponse:
|
||||
"""
|
||||
Get the authorization status for a tool.
|
||||
"""
|
||||
data = await self._client._execute_request( # type: ignore[attr-defined]
|
||||
"POST",
|
||||
f"{self._base_path}/authorize",
|
||||
json={"tool_name": tool_name, "user_id": user_id},
|
||||
f"{self._resource_path}/authorize",
|
||||
json={"tool_name": tool_name, "tool_version": tool_version, "user_id": user_id},
|
||||
)
|
||||
return AuthResponse(**data)
|
||||
|
||||
|
|
@ -294,6 +298,8 @@ class AsyncToolResource(BaseResource[AsyncArcadeClient]):
|
|||
class AsyncHealthResource(BaseResource[AsyncArcadeClient]):
|
||||
"""Asynchronous Health check resource."""
|
||||
|
||||
_path = "/health"
|
||||
|
||||
async def check(self) -> None:
|
||||
"""
|
||||
Check the health of the Arcade Engine.
|
||||
|
|
@ -303,7 +309,7 @@ class AsyncHealthResource(BaseResource[AsyncArcadeClient]):
|
|||
try:
|
||||
data = await self._client._execute_request( # type: ignore[attr-defined]
|
||||
"GET",
|
||||
f"/{API_VERSION}/health",
|
||||
f"{self._resource_path}",
|
||||
timeout=5,
|
||||
)
|
||||
|
||||
|
|
@ -332,7 +338,7 @@ class Arcade(SyncArcadeClient):
|
|||
def __init__(self, *args: Any, **kwargs: Any):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.auth: AuthResource = AuthResource(self)
|
||||
self.tool: ToolResource = ToolResource(self)
|
||||
self.tools: ToolResource = ToolResource(self)
|
||||
self.health: HealthResource = HealthResource(self)
|
||||
chat_url = self._chat_url(self._base_url)
|
||||
self._openai_client = OpenAI(base_url=chat_url, api_key=self._api_key)
|
||||
|
|
@ -345,11 +351,8 @@ class Arcade(SyncArcadeClient):
|
|||
"""
|
||||
Execute a synchronous request.
|
||||
"""
|
||||
try:
|
||||
response = self._request(method, url, **kwargs)
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
self._handle_http_error(e)
|
||||
response = self._request(method, url, **kwargs)
|
||||
return response.json()
|
||||
|
||||
|
||||
class AsyncArcade(AsyncArcadeClient):
|
||||
|
|
@ -358,7 +361,7 @@ class AsyncArcade(AsyncArcadeClient):
|
|||
def __init__(self, *args: Any, **kwargs: Any):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.auth: AsyncAuthResource = AsyncAuthResource(self)
|
||||
self.tool: AsyncToolResource = AsyncToolResource(self)
|
||||
self.tools: AsyncToolResource = AsyncToolResource(self)
|
||||
self.health: AsyncHealthResource = AsyncHealthResource(self)
|
||||
chat_url = self._chat_url(self._base_url)
|
||||
self._openai_client = AsyncOpenAI(base_url=chat_url, api_key=self._api_key)
|
||||
|
|
@ -371,8 +374,5 @@ class AsyncArcade(AsyncArcadeClient):
|
|||
"""
|
||||
Execute an asynchronous request.
|
||||
"""
|
||||
try:
|
||||
response = await self._request(method, url, **kwargs)
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
self._handle_http_error(e)
|
||||
response = await self._request(method, url, **kwargs)
|
||||
return response.json()
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
import os
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import AnyUrl, BaseModel, Field
|
||||
|
||||
from arcade.core.schema import ToolAuthorizationContext, ToolCallOutput
|
||||
|
||||
OPENAI_API_VERSION = os.getenv("OPENAI_API_VERSION", "v1")
|
||||
|
||||
|
||||
class AuthProvider(str, Enum):
|
||||
"""The supported authorization providers."""
|
||||
|
|
|
|||
|
|
@ -48,8 +48,6 @@ from arcade.core.utils import (
|
|||
from arcade.sdk.annotations import Inferrable
|
||||
from arcade.sdk.auth import BaseOAuth2, ToolAuthorization
|
||||
|
||||
DEFAULT_TOOLKIT_NAME = "Tools"
|
||||
|
||||
InnerWireType = Literal["string", "integer", "number", "boolean", "json"]
|
||||
WireType = Union[InnerWireType, Literal["array"]]
|
||||
|
||||
|
|
@ -116,7 +114,7 @@ class ToolCatalog(BaseModel):
|
|||
def add_tool(
|
||||
self,
|
||||
tool_func: Callable,
|
||||
toolkit_or_name: Union[str | None, Toolkit] = None,
|
||||
toolkit_or_name: Union[str, Toolkit],
|
||||
module: ModuleType | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
|
|
@ -131,9 +129,6 @@ class ToolCatalog(BaseModel):
|
|||
elif isinstance(toolkit_or_name, str):
|
||||
toolkit = None
|
||||
toolkit_name = toolkit_or_name
|
||||
else:
|
||||
toolkit = None
|
||||
toolkit_name = DEFAULT_TOOLKIT_NAME
|
||||
|
||||
if not toolkit_name:
|
||||
raise ValueError("A toolkit name or toolkit must be provided.")
|
||||
|
|
@ -163,6 +158,13 @@ class ToolCatalog(BaseModel):
|
|||
output_model=output_model,
|
||||
)
|
||||
|
||||
def add_module(self, module: ModuleType) -> None:
|
||||
"""
|
||||
Add all the tools in a module to the catalog.
|
||||
"""
|
||||
toolkit = Toolkit.from_module(module)
|
||||
self.add_toolkit(toolkit)
|
||||
|
||||
def add_toolkit(self, toolkit: Toolkit) -> None:
|
||||
"""
|
||||
Add the tools from a loaded toolkit to the catalog.
|
||||
|
|
@ -201,6 +203,15 @@ class ToolCatalog(BaseModel):
|
|||
def get_tool_names(self) -> list[FullyQualifiedName]:
|
||||
return [tool.definition.get_fully_qualified_name() for tool in self._tools.values()]
|
||||
|
||||
def find_tool_by_func(self, func: Callable) -> ToolDefinition:
|
||||
"""
|
||||
Find a tool by its function.
|
||||
"""
|
||||
for _, tool in self._tools.items():
|
||||
if tool.tool == func:
|
||||
return tool.definition
|
||||
raise ValueError(f"Tool {func} not found in the catalog.")
|
||||
|
||||
def get_tool(self, name: FullyQualifiedName) -> MaterializedTool:
|
||||
"""
|
||||
Get a tool from the catalog by fully-qualified name and version.
|
||||
|
|
|
|||
|
|
@ -1,16 +1,19 @@
|
|||
import ipaddress
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import idna
|
||||
import toml
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from arcade.core.env import settings
|
||||
from pydantic import BaseModel, ConfigDict, ValidationError
|
||||
|
||||
|
||||
class ApiConfig(BaseModel):
|
||||
class BaseConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
|
||||
class ApiConfig(BaseConfig):
|
||||
"""
|
||||
Arcade API configuration.
|
||||
"""
|
||||
|
|
@ -19,9 +22,13 @@ class ApiConfig(BaseModel):
|
|||
"""
|
||||
Arcade API key.
|
||||
"""
|
||||
version: str = "v1"
|
||||
"""
|
||||
Arcade API version.
|
||||
"""
|
||||
|
||||
|
||||
class UserConfig(BaseModel):
|
||||
class UserConfig(BaseConfig):
|
||||
"""
|
||||
Arcade user configuration.
|
||||
"""
|
||||
|
|
@ -32,7 +39,7 @@ class UserConfig(BaseModel):
|
|||
"""
|
||||
|
||||
|
||||
class EngineConfig(BaseModel):
|
||||
class EngineConfig(BaseConfig):
|
||||
"""
|
||||
Arcade Engine configuration.
|
||||
"""
|
||||
|
|
@ -51,7 +58,7 @@ class EngineConfig(BaseModel):
|
|||
"""
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
class Config(BaseConfig):
|
||||
"""
|
||||
Configuration for Arcade.
|
||||
"""
|
||||
|
|
@ -79,7 +86,8 @@ class Config(BaseModel):
|
|||
"""
|
||||
Get the path to the Arcade configuration directory.
|
||||
"""
|
||||
return settings.WORK_DIR if settings.WORK_DIR else Path.home() / ".arcade"
|
||||
config_path = os.getenv("ARCADE_WORK_DIR") or Path.home() / ".arcade"
|
||||
return Path(config_path).resolve()
|
||||
|
||||
@classmethod
|
||||
def get_config_file_path(cls) -> Path:
|
||||
|
|
@ -167,14 +175,14 @@ class Config(BaseModel):
|
|||
if ":" in parsed_host.netloc and not is_ip:
|
||||
host, existing_port = parsed_host.netloc.rsplit(":", 1)
|
||||
if existing_port.isdigit():
|
||||
return f"{protocol}://{parsed_host.netloc}/v1"
|
||||
return f"{protocol}://{parsed_host.netloc}/{self.api.version}"
|
||||
|
||||
if is_fqdn and self.engine.port is None:
|
||||
return f"{protocol}://{encoded_host}/v1"
|
||||
return f"{protocol}://{encoded_host}/{self.api.version}"
|
||||
elif self.engine.port is not None:
|
||||
return f"{protocol}://{encoded_host}:{self.engine.port}/v1"
|
||||
return f"{protocol}://{encoded_host}:{self.engine.port}/{self.api.version}"
|
||||
else:
|
||||
return f"{protocol}://{encoded_host}/v1"
|
||||
return f"{protocol}://{encoded_host}/{self.api.version}"
|
||||
|
||||
@classmethod
|
||||
def ensure_config_dir_exists(cls) -> None:
|
||||
|
|
|
|||
|
|
@ -1,20 +0,0 @@
|
|||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_file=".env")
|
||||
|
||||
WORK_DIR: Path = Path.home() / ".arcade"
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings() -> Settings:
|
||||
# env_file = os.getenv("ARCADE_ENV_FILE")
|
||||
# TODO allow env override
|
||||
return Settings()
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
|
|
@ -1,9 +1,11 @@
|
|||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from pydantic import AnyUrl, BaseModel, Field
|
||||
|
||||
TOOL_NAME_SEPARATOR = "."
|
||||
# allow for custom tool name separator
|
||||
TOOL_NAME_SEPARATOR = os.getenv("ARCADE_TOOL_NAME_SEPARATOR", ".")
|
||||
|
||||
|
||||
class ValueSchema(BaseModel):
|
||||
|
|
|
|||
|
|
@ -108,14 +108,19 @@ class Toolkit(BaseModel):
|
|||
@classmethod
|
||||
def find_all_arcade_toolkits(cls) -> list["Toolkit"]:
|
||||
"""
|
||||
Find all installed packages prefixed with 'arcade_' and load them as Toolkits.
|
||||
Find all installed packages prefixed with 'arcade_' in the current
|
||||
Python interpreter's environment and load them as Toolkits.
|
||||
|
||||
Returns:
|
||||
List[Toolkit]: A list of Toolkit instances.
|
||||
"""
|
||||
import sysconfig
|
||||
|
||||
# Get the site-packages directory of the current interpreter
|
||||
site_packages_dir = sysconfig.get_paths()["purelib"]
|
||||
arcade_packages = [
|
||||
dist.metadata["Name"]
|
||||
for dist in importlib.metadata.distributions()
|
||||
for dist in importlib.metadata.distributions(path=[site_packages_dir])
|
||||
if dist.metadata["Name"].startswith("arcade_")
|
||||
]
|
||||
return [cls.from_package(package) for package in arcade_packages]
|
||||
|
|
|
|||
|
|
@ -1,21 +1,5 @@
|
|||
from .eval import (
|
||||
BinaryCritic,
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
ExpectedToolCall,
|
||||
NumericCritic,
|
||||
SimilarityCritic,
|
||||
tool_eval,
|
||||
)
|
||||
from .tool import tool
|
||||
|
||||
__all__ = [
|
||||
"tool",
|
||||
"EvalRubric",
|
||||
"EvalSuite",
|
||||
"ExpectedToolCall",
|
||||
"tool_eval",
|
||||
"BinaryCritic",
|
||||
"SimilarityCritic",
|
||||
"NumericCritic",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import json
|
|||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from arcade.core.config_model import Config
|
||||
from arcade.core.schema import FullyQualifiedName
|
||||
|
||||
try:
|
||||
|
|
@ -11,9 +12,10 @@ try:
|
|||
from scipy.optimize import linear_sum_assignment
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Use `pip install arcade[evals]` to install the required dependencies for evaluation."
|
||||
"Use `pip install arcade-ai[evals]` to install the required dependencies for evaluation."
|
||||
)
|
||||
|
||||
|
||||
from arcade.client.client import Arcade, AsyncArcade
|
||||
from arcade.sdk.error import WeightError
|
||||
|
||||
|
|
@ -69,12 +71,15 @@ class EvaluationResult:
|
|||
passed: Whether the evaluation passed based on the fail_threshold.
|
||||
warning: Whether the evaluation issued a warning based on the warn_threshold.
|
||||
results: A list of dictionaries containing the results for each critic.
|
||||
failure_reason: If the evaluation failed completely due to settings in the rubric,
|
||||
this field contains the reason for failure.
|
||||
"""
|
||||
|
||||
score: float = 0.0
|
||||
passed: bool = False
|
||||
warning: bool = False
|
||||
results: list[dict[str, Any]] = field(default_factory=list)
|
||||
failure_reason: str | None = None
|
||||
|
||||
@property
|
||||
def fail(self) -> bool:
|
||||
|
|
@ -120,10 +125,10 @@ class EvaluationResult:
|
|||
Returns:
|
||||
The score for the tool selection.
|
||||
"""
|
||||
score = weight if expected == actual else 0.0
|
||||
score = weight if compare_tool_name(expected, actual) else 0.0
|
||||
self.add(
|
||||
"tool_selection",
|
||||
{"match": expected == actual, "score": score},
|
||||
{"match": compare_tool_name(expected, actual), "score": score},
|
||||
weight,
|
||||
expected,
|
||||
actual,
|
||||
|
|
@ -190,7 +195,10 @@ class EvalCase:
|
|||
True if tool selection failure should occur, False otherwise.
|
||||
"""
|
||||
expected_tools = [tc.name for tc in self.expected_tool_calls]
|
||||
return self.rubric.fail_on_tool_selection and set(expected_tools) != set(actual_tools)
|
||||
return self.rubric.fail_on_tool_selection and not all(
|
||||
compare_tool_name(expected, actual)
|
||||
for expected, actual in zip(expected_tools, actual_tools)
|
||||
)
|
||||
|
||||
def check_tool_call_quantity_failure(self, actual_count: int) -> bool:
|
||||
"""
|
||||
|
|
@ -218,17 +226,30 @@ class EvalCase:
|
|||
evaluation_result = EvaluationResult()
|
||||
actual_tools = [tool for tool, _ in actual_tool_calls]
|
||||
|
||||
if self.check_tool_selection_failure(actual_tools):
|
||||
evaluation_result.score = 0.0
|
||||
evaluation_result.passed = False
|
||||
evaluation_result.warning = False
|
||||
return evaluation_result
|
||||
|
||||
actual_count = len(actual_tool_calls)
|
||||
if self.check_tool_call_quantity_failure(actual_count):
|
||||
evaluation_result.score = 0.0
|
||||
evaluation_result.passed = False
|
||||
evaluation_result.warning = False
|
||||
expected_count = len(self.expected_tool_calls)
|
||||
evaluation_result.failure_reason = (
|
||||
f"Expected {expected_count} tool call(s), but got {actual_count}"
|
||||
)
|
||||
return evaluation_result
|
||||
|
||||
# check if no tools should be called and none were called
|
||||
if not self.expected_tool_calls and not actual_tools:
|
||||
evaluation_result.score = 1.0
|
||||
evaluation_result.passed = True
|
||||
evaluation_result.warning = False
|
||||
return evaluation_result
|
||||
|
||||
if self.check_tool_selection_failure(actual_tools):
|
||||
evaluation_result.score = 0.0
|
||||
evaluation_result.passed = False
|
||||
evaluation_result.warning = False
|
||||
expected_tools = [tc.name for tc in self.expected_tool_calls]
|
||||
evaluation_result.failure_reason = f"Tool selection mismatch. Expected tools: {expected_tools}, but got: {actual_tools}"
|
||||
return evaluation_result
|
||||
|
||||
# Create a cost matrix for the assignment problem
|
||||
|
|
@ -422,12 +443,10 @@ class EvalSuite:
|
|||
max_concurrent: int = 1 # Default to sequential execution
|
||||
_client: AsyncArcade | Arcade | None = None
|
||||
|
||||
def initialize_client(self) -> None:
|
||||
def initialize_client(self, config: Config) -> None:
|
||||
"""
|
||||
Initialize the client instance for the EvalSuite.
|
||||
"""
|
||||
from arcade.core.config import config
|
||||
|
||||
if self.max_concurrent > 1:
|
||||
self._client = AsyncArcade(
|
||||
api_key=config.api.key,
|
||||
|
|
@ -443,7 +462,7 @@ class EvalSuite:
|
|||
self,
|
||||
name: str,
|
||||
user_message: str,
|
||||
expected_tool_calls: list[ExpectedToolCall],
|
||||
expected_tool_calls: list[tuple[Callable, dict[str, Any]]],
|
||||
critics: list["Critic"],
|
||||
system_message: str | None = None,
|
||||
rubric: EvalRubric | None = None,
|
||||
|
|
@ -461,11 +480,18 @@ class EvalSuite:
|
|||
rubric: The evaluation rubric for this case.
|
||||
additional_messages: Optional list of additional messages for context.
|
||||
"""
|
||||
expected = [
|
||||
ExpectedToolCall(
|
||||
name=str(self.catalog.find_tool_by_func(func).get_fully_qualified_name()),
|
||||
args=args,
|
||||
)
|
||||
for func, args in expected_tool_calls
|
||||
]
|
||||
case = EvalCase(
|
||||
name=name,
|
||||
system_message=system_message or self.system_message,
|
||||
user_message=user_message,
|
||||
expected_tool_calls=expected_tool_calls,
|
||||
expected_tool_calls=expected,
|
||||
rubric=rubric or self.rubric,
|
||||
critics=critics,
|
||||
additional_messages=additional_messages or [],
|
||||
|
|
@ -477,7 +503,7 @@ class EvalSuite:
|
|||
name: str,
|
||||
user_message: str,
|
||||
system_message: str | None = None,
|
||||
expected_tool_calls: list[ExpectedToolCall] | None = None,
|
||||
expected_tool_calls: list[tuple[Callable, dict[str, Any]]] | None = None,
|
||||
rubric: EvalRubric | None = None,
|
||||
critics: list["Critic"] | None = None,
|
||||
additional_messages: list[dict[str, str]] | None = None,
|
||||
|
|
@ -507,12 +533,22 @@ class EvalSuite:
|
|||
if additional_messages:
|
||||
new_additional_messages.extend(additional_messages)
|
||||
|
||||
expected = last_case.expected_tool_calls
|
||||
if expected_tool_calls:
|
||||
expected = [
|
||||
ExpectedToolCall(
|
||||
name=str(self.catalog.find_tool_by_func(func).get_fully_qualified_name()),
|
||||
args=args,
|
||||
)
|
||||
for func, args in expected_tool_calls
|
||||
]
|
||||
|
||||
# Create a new case, copying from the last one and updating fields
|
||||
new_case = EvalCase(
|
||||
name=name,
|
||||
system_message=system_message or last_case.system_message,
|
||||
user_message=user_message,
|
||||
expected_tool_calls=expected_tool_calls or last_case.expected_tool_calls,
|
||||
expected_tool_calls=expected,
|
||||
rubric=rubric or self.rubric,
|
||||
critics=critics or last_case.critics.copy(),
|
||||
additional_messages=new_additional_messages,
|
||||
|
|
@ -570,7 +606,7 @@ class EvalSuite:
|
|||
|
||||
return results
|
||||
|
||||
def run(self, model: str) -> dict[str, Any]:
|
||||
def run(self, config: Config, model: str) -> dict[str, Any]:
|
||||
"""
|
||||
Run the evaluation suite.
|
||||
|
||||
|
|
@ -581,7 +617,7 @@ class EvalSuite:
|
|||
A dictionary containing the evaluation results.
|
||||
"""
|
||||
if not self._client:
|
||||
self.initialize_client()
|
||||
self.initialize_client(config)
|
||||
|
||||
if self.max_concurrent > 1:
|
||||
# Run asynchronously with concurrency
|
||||
|
|
@ -614,10 +650,26 @@ def get_tool_args(chat_completion: Any) -> list[tuple[str, dict[str, Any]]]:
|
|||
return tool_args_list
|
||||
|
||||
|
||||
def compare_tool_name(expected: str, actual: str) -> bool:
|
||||
"""
|
||||
Compare the tool name without penalizing for mismatch in separators
|
||||
between module names and tool names ex. '-' vs '_' vs '.' vs ' '
|
||||
"""
|
||||
# TODO optimize this
|
||||
# Remove all separators from both names
|
||||
separators = "-_."
|
||||
expected_clean = "".join(char for char in expected if char not in separators)
|
||||
actual_clean = "".join(char for char in actual if char not in separators)
|
||||
|
||||
# Compare the cleaned names
|
||||
return expected_clean == actual_clean
|
||||
|
||||
|
||||
def tool_eval() -> Callable[[Callable], Callable]:
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
def wrapper(
|
||||
config: Config,
|
||||
models: list[str],
|
||||
max_concurrency: int = 1,
|
||||
) -> list[dict[str, Any]]:
|
||||
|
|
@ -627,7 +679,7 @@ def tool_eval() -> Callable[[Callable], Callable]:
|
|||
suite.max_concurrent = max_concurrency
|
||||
results = []
|
||||
for model in models:
|
||||
result = suite.run(model)
|
||||
result = suite.run(config, model)
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
|
|
|
|||
|
|
@ -15,14 +15,13 @@ build-backend = "poetry.core.masonry.api"
|
|||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<4.0"
|
||||
pydantic = "^2.7.0"
|
||||
pydantic-settings = "^2.2.1"
|
||||
typer = "^0.9.0"
|
||||
rich = "^13.7.1"
|
||||
toml = "^0.10.2"
|
||||
tomlkit = "^0.12.4"
|
||||
requests = "^2.26.0" # TODO: is this really needed?
|
||||
openai = "^1.36.0" # TODO: relax to an earlier version that still has what we need
|
||||
pyjwt = "^2.8.0"
|
||||
loguru = "^0.7.0"
|
||||
|
||||
|
||||
[tool.poetry.group.fastapi.dependencies]
|
||||
|
|
@ -115,7 +114,9 @@ ignore = [ # TODO work to remove these
|
|||
# raise from (cli specific)
|
||||
"B904",
|
||||
# long message exceptions
|
||||
"TRY003"
|
||||
"TRY003",
|
||||
# subprocess.Popen
|
||||
"S603",
|
||||
]
|
||||
|
||||
[tool.ruff.format]
|
||||
|
|
|
|||
|
|
@ -68,6 +68,18 @@ HEALTH_CHECK_UNHEALTHY_RESPONSE_DATA = {
|
|||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_sync_client():
|
||||
"""Test client."""
|
||||
return Arcade(base_url="http://arcade.example.com", api_key="fake_api_key")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_async_client():
|
||||
"""Test client."""
|
||||
return AsyncArcade(base_url="http://arcade.example.com", api_key="fake_api_key")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response():
|
||||
"""Mock Response object for testing."""
|
||||
|
|
@ -94,7 +106,7 @@ def mock_async_response():
|
|||
(500, InternalServerError),
|
||||
],
|
||||
)
|
||||
def test_handle_http_error(error_code, expected_error, mock_response):
|
||||
def test_handle_http_error(test_sync_client, error_code, expected_error, mock_response):
|
||||
"""Test _handle_http_error method for different error codes."""
|
||||
mock_response.status_code = error_code
|
||||
mock_response.json.return_value = {"error": "Test error message"}
|
||||
|
|
@ -103,16 +115,14 @@ def test_handle_http_error(error_code, expected_error, mock_response):
|
|||
mock_http_error = Mock(spec=HTTPStatusError)
|
||||
mock_http_error.response = mock_response
|
||||
|
||||
client = Arcade(api_key="fake_api_key") # Create an instance of Arcade
|
||||
with pytest.raises(expected_error):
|
||||
client._handle_http_error(mock_http_error) # Call the method on the instance
|
||||
test_sync_client._handle_http_error(mock_http_error) # Call the method on the instance
|
||||
|
||||
|
||||
def test_arcade_auth_authorize(mock_response, monkeypatch):
|
||||
def test_arcade_auth_authorize(test_sync_client, mock_response, monkeypatch):
|
||||
"""Test Arcade.auth.authorize method."""
|
||||
monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: AUTH_RESPONSE_DATA)
|
||||
client = Arcade(api_key="fake_api_key")
|
||||
auth_response = client.auth.authorize(
|
||||
auth_response = test_sync_client.auth.authorize(
|
||||
provider=AuthProvider.google,
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
user_id="sam@arcade-ai.com",
|
||||
|
|
@ -120,19 +130,17 @@ def test_arcade_auth_authorize(mock_response, monkeypatch):
|
|||
assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA)
|
||||
|
||||
|
||||
def test_arcade_auth_poll_authorization(mock_response, monkeypatch):
|
||||
def test_arcade_auth_poll_authorization(test_sync_client, mock_response, monkeypatch):
|
||||
"""Test Arcade.auth.poll_authorization method."""
|
||||
monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: AUTH_RESPONSE_DATA)
|
||||
client = Arcade(api_key="fake_api_key")
|
||||
auth_response = client.auth.status("auth_123")
|
||||
auth_response = test_sync_client.auth.status("auth_123")
|
||||
assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA)
|
||||
|
||||
|
||||
def test_arcade_tool_run(mock_response, monkeypatch):
|
||||
"""Test Arcade.tool.run method."""
|
||||
def test_arcade_tool_run(test_sync_client, mock_response, monkeypatch):
|
||||
"""Test Arcade.tools.run method."""
|
||||
monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: TOOL_RESPONSE_DATA)
|
||||
client = Arcade(api_key="fake_api_key")
|
||||
tool_response = client.tool.run(
|
||||
tool_response = test_sync_client.tools.run(
|
||||
tool_name="GetEmails",
|
||||
user_id="sam@arcade-ai.com",
|
||||
tool_version="0.1.0",
|
||||
|
|
@ -141,54 +149,51 @@ def test_arcade_tool_run(mock_response, monkeypatch):
|
|||
assert tool_response == ExecuteToolResponse(**TOOL_RESPONSE_DATA)
|
||||
|
||||
|
||||
def test_arcade_tool_get(mock_response, monkeypatch):
|
||||
"""Test Arcade.tool.get method."""
|
||||
def test_arcade_tool_get(test_sync_client, mock_response, monkeypatch):
|
||||
"""Test Arcade.tools.get method."""
|
||||
monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: TOOL_DEFINITION_DATA)
|
||||
client = Arcade(api_key="fake_api_key")
|
||||
tool_definition = client.tool.get(director_id="default", tool_id="GetEmails")
|
||||
tool_definition = test_sync_client.tools.get(director_id="default", tool_id="GetEmails")
|
||||
assert tool_definition == ToolDefinition(**TOOL_DEFINITION_DATA)
|
||||
|
||||
|
||||
def test_arcade_tool_authorize(mock_response, monkeypatch):
|
||||
"""Test Arcade.tool.authorize method."""
|
||||
def test_arcade_tool_authorize(test_sync_client, mock_response, monkeypatch):
|
||||
"""Test Arcade.tools.authorize method."""
|
||||
monkeypatch.setattr(
|
||||
Arcade, "_execute_request", lambda *args, **kwargs: TOOL_AUTHORIZE_RESPONSE_DATA
|
||||
)
|
||||
client = Arcade(api_key="fake_api_key")
|
||||
auth_response = client.tool.authorize(tool_name="GetEmails", user_id="sam@arcade-ai.com")
|
||||
auth_response = test_sync_client.tools.authorize(
|
||||
tool_name="GetEmails", user_id="sam@arcade-ai.com"
|
||||
)
|
||||
assert auth_response == AuthResponse(**TOOL_AUTHORIZE_RESPONSE_DATA)
|
||||
|
||||
|
||||
def test_arcade_health_check(mock_response, monkeypatch):
|
||||
def test_arcade_health_check(test_sync_client, mock_response, monkeypatch):
|
||||
"""Test Arcade.health.check method."""
|
||||
monkeypatch.setattr(
|
||||
Arcade, "_execute_request", lambda *args, **kwargs: HEALTH_CHECK_HEALTHY_RESPONSE_DATA
|
||||
)
|
||||
client = Arcade(api_key="fake_api_key")
|
||||
client.health.check()
|
||||
test_sync_client.health.check()
|
||||
assert True # If no exception is raised, the test passes
|
||||
|
||||
|
||||
def test_arcade_health_check_raises_error(mock_response, monkeypatch):
|
||||
def test_arcade_health_check_raises_error(test_sync_client, mock_response, monkeypatch):
|
||||
"""Test Arcade.health.check method."""
|
||||
monkeypatch.setattr(
|
||||
Arcade, "_execute_request", lambda *args, **kwargs: HEALTH_CHECK_UNHEALTHY_RESPONSE_DATA
|
||||
)
|
||||
client = Arcade(api_key="fake_api_key")
|
||||
with pytest.raises(EngineNotHealthyError):
|
||||
client.health.check()
|
||||
test_sync_client.health.check()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_arcade_auth_authorize(mock_async_response, monkeypatch):
|
||||
async def test_async_arcade_auth_authorize(test_async_client, mock_async_response, monkeypatch):
|
||||
"""Test AsyncArcade.auth.authorize method."""
|
||||
|
||||
async def mock_execute_request(*args, **kwargs):
|
||||
return AUTH_RESPONSE_DATA
|
||||
|
||||
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
|
||||
client = AsyncArcade(api_key="fake_api_key")
|
||||
auth_response = await client.auth.authorize(
|
||||
auth_response = await test_async_client.auth.authorize(
|
||||
provider=AuthProvider.google,
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
user_id="sam@arcade-ai.com",
|
||||
|
|
@ -197,28 +202,28 @@ async def test_async_arcade_auth_authorize(mock_async_response, monkeypatch):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_arcade_auth_poll_authorization(mock_async_response, monkeypatch):
|
||||
async def test_async_arcade_auth_poll_authorization(
|
||||
test_async_client, mock_async_response, monkeypatch
|
||||
):
|
||||
"""Test AsyncArcade.auth.poll_authorization method."""
|
||||
|
||||
async def mock_execute_request(*args, **kwargs):
|
||||
return AUTH_RESPONSE_DATA
|
||||
|
||||
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
|
||||
client = AsyncArcade(api_key="fake_api_key")
|
||||
auth_response = await client.auth.status("auth_123")
|
||||
auth_response = await test_async_client.auth.status("auth_123")
|
||||
assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_arcade_tool_run(mock_async_response, monkeypatch):
|
||||
"""Test AsyncArcade.tool.run method."""
|
||||
async def test_async_arcade_tool_run(test_async_client, mock_async_response, monkeypatch):
|
||||
"""Test AsyncArcade.tools.run method."""
|
||||
|
||||
async def mock_execute_request(*args, **kwargs):
|
||||
return TOOL_RESPONSE_DATA
|
||||
|
||||
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
|
||||
client = AsyncArcade(api_key="fake_api_key")
|
||||
tool_response = await client.tool.run(
|
||||
tool_response = await test_async_client.tools.run(
|
||||
tool_name="GetEmails",
|
||||
user_id="sam@arcade-ai.com",
|
||||
tool_version="0.1.0",
|
||||
|
|
@ -228,52 +233,52 @@ async def test_async_arcade_tool_run(mock_async_response, monkeypatch):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_arcade_tool_get(mock_async_response, monkeypatch):
|
||||
"""Test AsyncArcade.tool.get method."""
|
||||
async def test_async_arcade_tool_get(test_async_client, mock_async_response, monkeypatch):
|
||||
"""Test AsyncArcade.tools.get method."""
|
||||
|
||||
async def mock_execute_request(*args, **kwargs):
|
||||
return TOOL_DEFINITION_DATA
|
||||
|
||||
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
|
||||
client = AsyncArcade(api_key="fake_api_key")
|
||||
tool_definition = await client.tool.get(director_id="default", tool_id="GetEmails")
|
||||
tool_definition = await test_async_client.tools.get(director_id="default", tool_id="GetEmails")
|
||||
assert tool_definition == ToolDefinition(**TOOL_DEFINITION_DATA)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_arcade_tool_authorize(mock_async_response, monkeypatch):
|
||||
"""Test AsyncArcade.tool.authorize method."""
|
||||
async def test_async_arcade_tool_authorize(test_async_client, mock_async_response, monkeypatch):
|
||||
"""Test AsyncArcade.tools.authorize method."""
|
||||
|
||||
async def mock_execute_request(*args, **kwargs):
|
||||
return TOOL_AUTHORIZE_RESPONSE_DATA
|
||||
|
||||
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
|
||||
client = AsyncArcade(api_key="fake_api_key")
|
||||
auth_response = await client.tool.authorize(tool_name="GetEmails", user_id="sam@arcade-ai.com")
|
||||
auth_response = await test_async_client.tools.authorize(
|
||||
tool_name="GetEmails", user_id="sam@arcade-ai.com"
|
||||
)
|
||||
assert auth_response == AuthResponse(**TOOL_AUTHORIZE_RESPONSE_DATA)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_arcade_health_check(mock_async_response, monkeypatch):
|
||||
async def test_async_arcade_health_check(test_async_client, mock_async_response, monkeypatch):
|
||||
"""Test AsyncArcade.health.check method."""
|
||||
|
||||
async def mock_execute_request(*args, **kwargs):
|
||||
return HEALTH_CHECK_HEALTHY_RESPONSE_DATA
|
||||
|
||||
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
|
||||
client = AsyncArcade(api_key="fake_api_key")
|
||||
await client.health.check()
|
||||
await test_async_client.health.check()
|
||||
assert True # If no exception is raised, the test passes
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_arcade_health_check_raises_error(mock_async_response, monkeypatch):
|
||||
async def test_async_arcade_health_check_raises_error(
|
||||
test_async_client, mock_async_response, monkeypatch
|
||||
):
|
||||
"""Test AsyncArcade.health.check method."""
|
||||
|
||||
async def mock_execute_request(*args, **kwargs):
|
||||
return HEALTH_CHECK_UNHEALTHY_RESPONSE_DATA
|
||||
|
||||
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
|
||||
client = AsyncArcade(api_key="fake_api_key")
|
||||
with pytest.raises(EngineNotHealthyError):
|
||||
await client.health.check()
|
||||
await test_async_client.health.check()
|
||||
|
|
|
|||
|
|
@ -14,10 +14,10 @@ def sample_tool() -> str:
|
|||
return "Hello, world!"
|
||||
|
||||
|
||||
def test_add_tool_with_no_toolkit():
|
||||
def test_add_tool_with_empty_toolkit_name_raises():
|
||||
catalog = ToolCatalog()
|
||||
catalog.add_tool(sample_tool)
|
||||
assert catalog.get_tool(FullyQualifiedName("SampleTool", "Tools", None)).tool == sample_tool
|
||||
with pytest.raises(ValueError):
|
||||
catalog.add_tool(sample_tool, "")
|
||||
|
||||
|
||||
def test_add_tool_with_toolkit_name():
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ ARG HOST=0.0.0.0
|
|||
# Set environment variables using the build arguments
|
||||
ENV PORT=${PORT}
|
||||
ENV HOST=${HOST}
|
||||
ENV WORK_DIR=/app
|
||||
ENV ARCADE_WORK_DIR=/app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
|
|
@ -45,8 +45,8 @@ WORKDIR /app/toolkits
|
|||
# Install toolkits from the toolkits directory
|
||||
RUN set -e; \
|
||||
for toolkit in ./*; do \
|
||||
echo "Installing toolkit $toolkit"; \
|
||||
pip install $toolkit; \
|
||||
echo "Installing toolkit $toolkit"; \
|
||||
pip install $toolkit; \
|
||||
done
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,68 +0,0 @@
|
|||
import os
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from langchain_google_community import GmailToolkit
|
||||
from langchain_google_community.gmail.utils import (
|
||||
build_resource_service,
|
||||
)
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
|
||||
# Step 1: Install required packages
|
||||
# Run the following in your terminal:
|
||||
# %pip install -qU langchain-google-community[gmail]
|
||||
# %pip install -qU langchain-openai
|
||||
# %pip install -qU langgraph
|
||||
#
|
||||
# Step 2: Set environment variables for LangChain and OpenAI API keys
|
||||
# Uncomment the following lines if you have the LangSmith API key
|
||||
# os.environ["LANGCHAIN_TRACING_V2"] = "true"
|
||||
# os.environ["LANGCHAIN_API_KEY"] = getpass.getpass("Enter your LangSmith API key: ")
|
||||
#
|
||||
# Step 3 (Option 1) Manually authenticate with Gmail by creating your own google app, credentials, and handling tokens and Oauth
|
||||
# credentials = get_gmail_credentials(
|
||||
# token_file="token.json",
|
||||
# scopes=["https://mail.google.com/"],
|
||||
# client_secrets_file="credentials.json",
|
||||
# )
|
||||
#
|
||||
# ----------------- OR -----------------
|
||||
# Step 3 (Option 2) Use the Arcade SDK to authenticate with Gmail
|
||||
from arcade.client import Arcade, AuthProvider
|
||||
|
||||
client = Arcade(api_key=os.environ["ARCADE_API_KEY"])
|
||||
|
||||
challenge = client.auth.authorize(
|
||||
provider=AuthProvider.google,
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
user_id="example_user_id",
|
||||
)
|
||||
|
||||
if challenge.status != "completed":
|
||||
print(f"Please visit this URL to authorize: {challenge.auth_url}")
|
||||
input("Press Enter after you've completed the authorization...")
|
||||
challenge = client.auth.poll_authorization(challenge)
|
||||
if challenge.status != "completed":
|
||||
print("Authorization not completed. Please try again.")
|
||||
exit(1)
|
||||
|
||||
|
||||
creds = Credentials(challenge.context.token)
|
||||
api_resource = build_resource_service(credentials=creds)
|
||||
toolkit = GmailToolkit(api_resource=api_resource)
|
||||
|
||||
# Step 4: Get available tools
|
||||
tools = toolkit.get_tools()
|
||||
|
||||
# Step 5: Initialize the LLM and create an agent
|
||||
llm = ChatOpenAI(model="gpt-4o")
|
||||
agent_executor = create_react_agent(llm, tools)
|
||||
|
||||
# Step 6: Draft an email using the agent
|
||||
example_query = "Read my latest emails to me and summarize them."
|
||||
events = agent_executor.stream(
|
||||
{"messages": [("user", example_query)]},
|
||||
stream_mode="values",
|
||||
)
|
||||
for event in events:
|
||||
event["messages"][-1].pretty_print()
|
||||
60
examples/langchain/langgraph_auth.py
Normal file
60
examples/langchain/langgraph_auth.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
import time # Import time for polling delays
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from langchain_google_community import GmailToolkit
|
||||
from langchain_google_community.gmail.utils import (
|
||||
build_resource_service,
|
||||
)
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
|
||||
# Step 1: Install required packages
|
||||
# Run the following in your terminal:
|
||||
# %pip install -qU langchain-google-community[gmail]
|
||||
# %pip install -qU langchain-openai
|
||||
# %pip install -qU langgraph
|
||||
from arcade.client import Arcade, AuthProvider
|
||||
|
||||
client = Arcade()
|
||||
|
||||
# Start the authorization process for the tool "ListEmails"
|
||||
auth_response = client.auth.authorize(
|
||||
provider=AuthProvider.google,
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
user_id="sam@arcade-ai.com",
|
||||
)
|
||||
|
||||
# If authorization is not completed, prompt the user and poll for status
|
||||
if auth_response.status != "completed":
|
||||
print(
|
||||
"Please complete the authorization challenge in your browser before continuing:"
|
||||
)
|
||||
print(auth_response.auth_url)
|
||||
input("Press Enter to continue...")
|
||||
|
||||
# Poll for authorization status using the auth polling method
|
||||
while auth_response.status != "completed":
|
||||
# Wait before polling again to avoid spamming the server
|
||||
time.sleep(4)
|
||||
auth_response = client.auth.status(auth_response)
|
||||
|
||||
# Authorization is completed; proceed with obtaining credentials
|
||||
creds = Credentials(auth_response.context.token)
|
||||
api_resource = build_resource_service(credentials=creds)
|
||||
toolkit = GmailToolkit(api_resource=api_resource)
|
||||
|
||||
# Step 4: Get available tools
|
||||
tools = toolkit.get_tools()
|
||||
|
||||
# Step 5: Initialize the LLM and create an agent
|
||||
llm = ChatOpenAI(model="gpt-4o")
|
||||
agent_executor = create_react_agent(llm, tools)
|
||||
|
||||
# Step 6: Draft an email using the agent
|
||||
example_query = "Read my latest emails to me and summarize them."
|
||||
events = agent_executor.stream(
|
||||
{"messages": [("user", example_query)]},
|
||||
stream_mode="values",
|
||||
)
|
||||
for event in events:
|
||||
event["messages"][-1].pretty_print()
|
||||
63
examples/langchain/langgraph_with_tool_exec.py
Normal file
63
examples/langchain/langgraph_with_tool_exec.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
import json
|
||||
import os
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.errors import NodeInterrupt
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
|
||||
from arcade.client import Arcade
|
||||
|
||||
client = Arcade(api_key=os.environ["ARCADE_API_KEY"])
|
||||
|
||||
|
||||
class State(TypedDict):
|
||||
emails: Any
|
||||
|
||||
|
||||
def step_1(state: State, config) -> State:
|
||||
user_id = config["configurable"]["user_id"]
|
||||
|
||||
challenge = client.tools.authorize(
|
||||
tool_name="ListEmails",
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
if challenge.status != "completed":
|
||||
raise NodeInterrupt(f"Please visit this URL to authorize: {challenge.auth_url}")
|
||||
|
||||
result = client.tools.run(
|
||||
tool_name="ListEmails",
|
||||
user_id=user_id,
|
||||
tool_version="default",
|
||||
inputs=json.dumps({"n_emails": 5}),
|
||||
)
|
||||
return {"emails": result}
|
||||
|
||||
|
||||
builder = StateGraph(State)
|
||||
builder.add_node("step_1", step_1)
|
||||
builder.add_edge(START, "step_1")
|
||||
builder.add_edge("step_1", END)
|
||||
|
||||
# Set up memory
|
||||
memory = MemorySaver()
|
||||
|
||||
# Compile the graph with memory
|
||||
graph = builder.compile(checkpointer=memory)
|
||||
|
||||
config = {"configurable": {"thread_id": "2", "user_id": "sam@arcade-ai.com"}}
|
||||
result = graph.invoke({"emails": None}, config=config)
|
||||
state = graph.get_state({"configurable": {"thread_id": "2"}})
|
||||
print("interrupted state\n----------")
|
||||
print(state)
|
||||
print("----------")
|
||||
input()
|
||||
result = graph.invoke({"emails": None}, config=config)
|
||||
state = graph.get_state({"configurable": {"thread_id": "2"}})
|
||||
print("final state\n----------")
|
||||
print(state)
|
||||
print("----------")
|
||||
print("final result\n----------")
|
||||
print(result)
|
||||
print("----------")
|
||||
|
|
@ -2,7 +2,7 @@ import os
|
|||
|
||||
from modal import App, Image, asgi_app
|
||||
|
||||
os.environ["WORK_DIR"] = "/root"
|
||||
os.environ["ARCADE_WORK_DIR"] = "/root"
|
||||
|
||||
# Define the FastAPI app
|
||||
app = App("arcade-ai-actor")
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ authors = ["Sam Partee <sam@arcade-ai.com>", "Eric Gustin <eric@arcade-ai.com>"]
|
|||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
arcade-ai = "*"
|
||||
arcade-ai = "^0.1.0"
|
||||
google-api-core = "2.19.1"
|
||||
google-api-python-client = "2.137.0"
|
||||
google-auth = "2.32.0"
|
||||
|
|
@ -16,7 +16,7 @@ googleapis-common-protos = "1.63.2"
|
|||
beautifulsoup4 = "^4.10.0"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = "^7.4.0"
|
||||
pytest = "^8.3.0"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
from arcade.core.catalog import ToolCatalog
|
||||
from arcade.core.toolkit import Toolkit
|
||||
import arcade_math
|
||||
from arcade_math.tools.arithmetic import add, sqrt
|
||||
|
||||
from arcade.core.catalog import ToolCatalog
|
||||
from arcade.sdk.eval import (
|
||||
BinaryCritic,
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
ExpectedToolCall,
|
||||
tool_eval,
|
||||
)
|
||||
|
||||
|
|
@ -18,11 +17,11 @@ rubric = EvalRubric(
|
|||
|
||||
|
||||
catalog = ToolCatalog()
|
||||
catalog.add_toolkit(Toolkit.from_module(arcade_math))
|
||||
catalog.add_module(arcade_math)
|
||||
|
||||
|
||||
@tool_eval()
|
||||
def arithmetic_eval_suite():
|
||||
def math_eval_suite():
|
||||
suite = EvalSuite(
|
||||
name="Math Tools Evaluation",
|
||||
system_message="You are an AI assistant with access to math tools. Use them to help the user with their math-related tasks.",
|
||||
|
|
@ -34,9 +33,9 @@ def arithmetic_eval_suite():
|
|||
name="Add two large numbers",
|
||||
user_message="Add 12345 and 987654321",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
"Arithmetic_Add",
|
||||
args={
|
||||
(
|
||||
add,
|
||||
{
|
||||
"a": 12345,
|
||||
"b": 987654321,
|
||||
},
|
||||
|
|
@ -55,7 +54,12 @@ def arithmetic_eval_suite():
|
|||
name="Take the square root of a large number",
|
||||
user_message="What is the square root of 3224990521?",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall("Arithmetic_Sqrt", args={"a": 3224990521})
|
||||
(
|
||||
sqrt,
|
||||
{
|
||||
"a": 3224990521,
|
||||
},
|
||||
)
|
||||
],
|
||||
rubric=rubric,
|
||||
critics=[
|
||||
|
|
@ -7,10 +7,10 @@ authors = ["Nate <nate@arcade-ai.com>"]
|
|||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
arcade-ai = "*"
|
||||
arcade-ai = "^0.1.0"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = "^7.4"
|
||||
pytest = "^8.3.0"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
|
|
|
|||
239
toolkits/search/evals/eval_google_search.py
Normal file
239
toolkits/search/evals/eval_google_search.py
Normal file
|
|
@ -0,0 +1,239 @@
|
|||
import arcade_search
|
||||
from arcade_search.tools.google import search_google
|
||||
|
||||
from arcade.core.catalog import ToolCatalog
|
||||
from arcade.sdk.eval import (
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
NumericCritic,
|
||||
SimilarityCritic,
|
||||
tool_eval,
|
||||
)
|
||||
|
||||
# Evaluation rubric
|
||||
rubric = EvalRubric(
|
||||
fail_threshold=0.8,
|
||||
warn_threshold=0.9,
|
||||
)
|
||||
|
||||
catalog = ToolCatalog()
|
||||
# Register the Google Search tool
|
||||
catalog.add_module(arcade_search)
|
||||
|
||||
|
||||
@tool_eval()
|
||||
def google_search_eval_suite() -> EvalSuite:
|
||||
"""Create an evaluation suite for the Google Search tool."""
|
||||
suite = EvalSuite(
|
||||
name="Google Search Tool Evaluation",
|
||||
system_message="You are an AI assistant that can perform web searches using the provided tools.",
|
||||
catalog=catalog,
|
||||
rubric=rubric,
|
||||
)
|
||||
|
||||
# Simple search query with default results
|
||||
suite.add_case(
|
||||
name="Simple search query with default results",
|
||||
user_message="Search for 'Climate change effects on polar bears' on Google.",
|
||||
expected_tool_calls=[
|
||||
(
|
||||
search_google,
|
||||
{
|
||||
"query": "Climate change effects on polar bears",
|
||||
"n_results": 5,
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="query", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
# Search query with specific number of results
|
||||
suite.add_case(
|
||||
name="Search query with specific number of results",
|
||||
user_message="Find the top 3 articles about quantum computing.",
|
||||
expected_tool_calls=[
|
||||
(
|
||||
search_google,
|
||||
{
|
||||
"query": "articles about quantum computing",
|
||||
"n_results": 3,
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="query", weight=0.7),
|
||||
NumericCritic(
|
||||
critic_field="n_results",
|
||||
weight=0.3,
|
||||
value_range=(1, 100),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Search query with 'n' results specified in words
|
||||
suite.add_case(
|
||||
name="Search query with 'n' results specified in words",
|
||||
user_message="Give me five recipes for vegan lasagna.",
|
||||
expected_tool_calls=[
|
||||
(
|
||||
search_google,
|
||||
{
|
||||
"query": "recipes for vegan lasagna",
|
||||
"n_results": 5,
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="query", weight=0.7),
|
||||
NumericCritic(
|
||||
critic_field="n_results",
|
||||
weight=0.3,
|
||||
value_range=(1, 100),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Ambiguous number of results
|
||||
suite.add_case(
|
||||
name="Ambiguous number of results",
|
||||
user_message="Find articles about climate change impacts 10.",
|
||||
expected_tool_calls=[
|
||||
(
|
||||
search_google,
|
||||
{
|
||||
"query": "articles about climate change impacts 10",
|
||||
"n_results": 5,
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="query", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
# Search query with multiple instructions
|
||||
suite.add_case(
|
||||
name="Search query with multiple instructions",
|
||||
user_message="Search for the latest news on electric cars, and tell me about Tesla's new model.",
|
||||
expected_tool_calls=[
|
||||
(
|
||||
search_google,
|
||||
{
|
||||
"query": "latest news on electric cars",
|
||||
"n_results": 5,
|
||||
},
|
||||
),
|
||||
(
|
||||
search_google,
|
||||
{
|
||||
"query": "Tesla's new model",
|
||||
"n_results": 5,
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="query", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
# Search with stop words and filler words
|
||||
suite.add_case(
|
||||
name="Search with stop words and filler words",
|
||||
user_message="Could you please search for the best ways to learn French?",
|
||||
expected_tool_calls=[
|
||||
(
|
||||
search_google,
|
||||
{
|
||||
"query": "best ways to learn French",
|
||||
"n_results": 5,
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="query", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
# No clear query given
|
||||
suite.add_case(
|
||||
name="No clear query given",
|
||||
user_message="Find it for me.",
|
||||
expected_tool_calls=[],
|
||||
critics=[],
|
||||
)
|
||||
|
||||
# Search query with special characters
|
||||
suite.add_case(
|
||||
name="Search query with special characters",
|
||||
user_message="Find me '@OpenAI's latest research papers'",
|
||||
expected_tool_calls=[
|
||||
(
|
||||
search_google,
|
||||
{
|
||||
"query": "@OpenAI's latest research papers",
|
||||
"n_results": 5,
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="query", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
# Search query with complex instructions
|
||||
suite.add_case(
|
||||
name="Search query with complex instructions",
|
||||
user_message="I need information about the impact of deforestation in the Amazon over the past decade.",
|
||||
expected_tool_calls=[
|
||||
(
|
||||
search_google,
|
||||
{
|
||||
"query": "impact of deforestation in the Amazon over the past decade",
|
||||
"n_results": 5,
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="query", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
# Search query in a different language
|
||||
suite.add_case(
|
||||
name="Search query in a different language",
|
||||
user_message="Busca información sobre la economía de España.",
|
||||
expected_tool_calls=[
|
||||
(
|
||||
search_google,
|
||||
{
|
||||
"query": "economía de España",
|
||||
"n_results": 5,
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="query", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
# Search query with numeric data
|
||||
suite.add_case(
|
||||
name="Search query with numeric data",
|
||||
user_message="What was the population of Japan in 2020?",
|
||||
expected_tool_calls=[
|
||||
(
|
||||
search_google,
|
||||
{
|
||||
"query": "population of Japan in 2020",
|
||||
"n_results": 5,
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="query", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
return suite
|
||||
|
|
@ -6,11 +6,11 @@ authors = ["Sam Partee <sam@arcade-ai.com>"]
|
|||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
arcade-ai = "*"
|
||||
arcade-ai = "^0.1.0"
|
||||
serpapi = "^0.1.5"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = "^7.4.0"
|
||||
pytest = "^8.3.0"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import arcade_slack
|
||||
from arcade_slack.tools.chat import send_dm_to_user, send_message_to_channel
|
||||
|
||||
from arcade.core.catalog import ToolCatalog
|
||||
|
|
@ -5,7 +6,6 @@ from arcade.sdk.eval import (
|
|||
BinaryCritic,
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
ExpectedToolCall,
|
||||
SimilarityCritic,
|
||||
tool_eval,
|
||||
)
|
||||
|
|
@ -19,8 +19,7 @@ rubric = EvalRubric(
|
|||
|
||||
catalog = ToolCatalog()
|
||||
# Register the Slack tools
|
||||
catalog.add_tool(send_dm_to_user)
|
||||
catalog.add_tool(send_message_to_channel)
|
||||
catalog.add_module(arcade_slack)
|
||||
|
||||
|
||||
@tool_eval()
|
||||
|
|
@ -38,9 +37,9 @@ def slack_eval_suite() -> EvalSuite:
|
|||
name="Send DM to user with clear username",
|
||||
user_message="Send a direct message to johndoe saying 'Hello, can we meet at 3 PM?'",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
name="SendDmToUser",
|
||||
args={
|
||||
(
|
||||
send_dm_to_user,
|
||||
{
|
||||
"user_name": "johndoe",
|
||||
"message": "Hello, can we meet at 3 PM?",
|
||||
},
|
||||
|
|
@ -56,54 +55,54 @@ def slack_eval_suite() -> EvalSuite:
|
|||
name="Send DM with ambiguous username",
|
||||
user_message="Message John about the project deadline",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
name="SendDmToUser",
|
||||
args={
|
||||
(
|
||||
send_dm_to_user,
|
||||
{
|
||||
"user_name": "john",
|
||||
"message": "Hi John, I wanted to check about the project deadline. Can you provide an update?",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="user_name", weight=0.6),
|
||||
SimilarityCritic(critic_field="message", weight=0.4),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Send DM with username in different format",
|
||||
user_message="DM Jane.Doe to reschedule our meeting",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
name="SendDmToUser",
|
||||
args={
|
||||
"user_name": "jane.doe",
|
||||
"message": "Hi Jane, I need to reschedule our meeting. When are you available?",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="user_name", weight=0.6),
|
||||
SimilarityCritic(critic_field="message", weight=0.4),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Send DM with username in different format",
|
||||
user_message="DM Jane.Doe to reschedule our meeting",
|
||||
expected_tool_calls=[
|
||||
(
|
||||
send_dm_to_user,
|
||||
{
|
||||
"user_name": "jane.doe",
|
||||
"message": "Hi Jane, I need to reschedule our meeting. When are you available?",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="user_name", weight=0.5),
|
||||
SimilarityCritic(critic_field="message", weight=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
# Send Message to Channel Scenarios
|
||||
suite.add_case(
|
||||
name="Send message to channel with clear name",
|
||||
user_message="Post 'The new feature is now live!' in the #announcements channel",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
name="SendMessageToChannel",
|
||||
args={
|
||||
(
|
||||
send_message_to_channel,
|
||||
{
|
||||
"channel_name": "announcements",
|
||||
"message": "The new feature is now live!",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="channel_name", weight=0.6),
|
||||
SimilarityCritic(critic_field="message", weight=0.4),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.5),
|
||||
SimilarityCritic(critic_field="message", weight=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -111,9 +110,9 @@ def slack_eval_suite() -> EvalSuite:
|
|||
name="Send message to channel with ambiguous name",
|
||||
user_message="Inform the engineering team about the upcoming maintenance in the general channel",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
name="SendMessageToChannel",
|
||||
args={
|
||||
(
|
||||
send_message_to_channel,
|
||||
{
|
||||
"channel_name": "engineering",
|
||||
"message": "Attention team: There will be upcoming maintenance. Please save your work and expect some downtime.",
|
||||
},
|
||||
|
|
@ -130,9 +129,9 @@ def slack_eval_suite() -> EvalSuite:
|
|||
name="Ambiguous between DM and channel message",
|
||||
user_message="Send 'Great job on the presentation!' to the team",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
name="SendMessageToChannel",
|
||||
args={
|
||||
(
|
||||
send_message_to_channel,
|
||||
{
|
||||
"channel_name": "general",
|
||||
"message": "Great job on the presentation!",
|
||||
},
|
||||
|
|
@ -149,25 +148,25 @@ def slack_eval_suite() -> EvalSuite:
|
|||
name="Multiple recipients in DM request",
|
||||
user_message="Send a DM to Alice and Bob about pushing the meeting tomorrow. I have to much work to do.",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
name="SendDmToUser",
|
||||
args={
|
||||
(
|
||||
send_dm_to_user,
|
||||
{
|
||||
"user_name": "alice",
|
||||
"message": "Hi Alice, about our meeting tomorrow, let's reschedule? I am swamped with work.",
|
||||
},
|
||||
),
|
||||
ExpectedToolCall(
|
||||
name="SendDmToUser",
|
||||
args={
|
||||
(
|
||||
send_dm_to_user,
|
||||
{
|
||||
"user_name": "bob",
|
||||
"message": "Hi Bob, about our meeting tomorrow, let's reschedule? I am swamped with work.",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="user_name", weight=0.6),
|
||||
SimilarityCritic(critic_field="user_name", weight=0.7),
|
||||
SimilarityCritic(
|
||||
critic_field="message", weight=0.4, similarity_threshold=0.7
|
||||
critic_field="message", weight=0.3, similarity_threshold=0.6
|
||||
),
|
||||
],
|
||||
)
|
||||
|
|
@ -176,9 +175,9 @@ def slack_eval_suite() -> EvalSuite:
|
|||
name="Channel name similar to username",
|
||||
user_message="Post 'sounds great!' in john-project channel",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
name="SendMessageToChannel",
|
||||
args={
|
||||
(
|
||||
send_message_to_channel,
|
||||
{
|
||||
"channel_name": "john-project",
|
||||
"message": "Sounds great!",
|
||||
},
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ arcade-ai = "^0.1.0"
|
|||
slack-sdk = "^3.31.0"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = "^7.4.0"
|
||||
pytest = "^8.3.0"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
|
|
|
|||
|
|
@ -1,17 +1,16 @@
|
|||
import arcade_x
|
||||
from arcade_x.tools.tweets import post_tweet
|
||||
|
||||
# TODO
|
||||
# delete_tweet_by_id,
|
||||
# search_recent_tweets_by_keywords,
|
||||
# search_recent_tweets_by_username,
|
||||
# from arcade_x.tools.users import lookup_single_user_by_username
|
||||
from arcade.core.catalog import ToolCatalog
|
||||
from arcade_x.tools.tweets import (
|
||||
post_tweet,
|
||||
delete_tweet_by_id,
|
||||
# search_recent_tweets_by_query,
|
||||
search_recent_tweets_by_username,
|
||||
search_recent_tweets_by_keywords,
|
||||
)
|
||||
from arcade_x.tools.users import lookup_single_user_by_username
|
||||
from arcade.sdk.eval import (
|
||||
BinaryCritic,
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
ExpectedToolCall,
|
||||
SimilarityCritic,
|
||||
tool_eval,
|
||||
)
|
||||
|
||||
|
|
@ -22,11 +21,8 @@ rubric = EvalRubric(
|
|||
)
|
||||
|
||||
catalog = ToolCatalog()
|
||||
catalog.add_tool(search_recent_tweets_by_keywords)
|
||||
catalog.add_tool(lookup_single_user_by_username)
|
||||
catalog.add_tool(post_tweet)
|
||||
catalog.add_tool(delete_tweet_by_id)
|
||||
catalog.add_tool(search_recent_tweets_by_username)
|
||||
# Register the X tools
|
||||
catalog.add_module(arcade_x)
|
||||
|
||||
|
||||
@tool_eval()
|
||||
|
|
@ -45,17 +41,18 @@ def x_eval_suite() -> EvalSuite:
|
|||
name="Post a tweet",
|
||||
user_message="Send out a tweet that says 'Hello World! Exciting stuff is happening over at Arcade AI!'",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
name="PostTweet",
|
||||
args={
|
||||
(
|
||||
post_tweet,
|
||||
{
|
||||
"tweet_text": "Hello World! Exciting stuff is happening over at Arcade AI!"
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(
|
||||
SimilarityCritic(
|
||||
critic_field="tweet_text",
|
||||
weight=1.0,
|
||||
similarity_threshold=0.9,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue