Client Fixes and LangGraph Examples (#50)

This PR includes several improvements to the Arcade client and adds
LangGraph examples:

1. Enhanced error handling in the Arcade client:
   - Improved HTTP error handling in `BaseArcadeClient`
   - Simplified request methods in `SyncArcadeClient` and
     `AsyncArcadeClient`

2. Updated `ToolResource` class:
   - Changed base path from `/v1/tool` to `/v1/tools`
   - Added `tool_version` parameter to `authorize` method

3. Improved Toolkit discovery:
   - Updated `find_all_arcade_toolkits` to search only in the current
     Python interpreter's site-packages

4. Added LangGraph examples:
   - New `langgraph_auth.py` example demonstrating Gmail authentication
   - New `langgraph_with_tool_exec.py` example showing tool execution
     within a LangGraph

5. Minor updates:
   - Changed default `BASE_URL` to `https://api.arcade.com/`
   - Updated import error message for eval dependencies

---------

Co-authored-by: Nate Barbettini <nate@arcade-ai.com>
This commit is contained in:
Sam Partee 2024-09-24 10:13:45 -07:00 committed by GitHub
parent 8d66b52512
commit 2eb46a3a98
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 1291 additions and 403 deletions

6
.vscode/launch.json vendored
View file

@ -34,15 +34,15 @@
"cwd": "${workspaceFolder}"
},
{
"name": "Debug `arcade evals -d`",
"name": "Debug `arcade evals -d` on current file",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/arcade/run_cli.py",
"args": ["evals", "-d"],
"args": ["evals", "-d", "${fileDirname}", "-h", "localhost"],
"console": "integratedTerminal",
"jinja": true,
"justMyCode": true,
"cwd": "${workspaceFolder}"
"cwd": ""
}
]
}

View file

@ -73,7 +73,7 @@ class BaseActor(Actor):
"""
return [tool.definition for tool in self.catalog]
def register_tool(self, tool: Callable, toolkit_name: str | None = None) -> None:
def register_tool(self, tool: Callable, toolkit_name: str) -> None:
"""
Register a tool to the catalog.
"""

View file

@ -0,0 +1,378 @@
import io
import ipaddress
import logging
import os
import shutil
import signal
import subprocess
import sys
import threading
import time
from pathlib import Path
from typing import Callable
from rich.console import Console
console = Console(highlight=False)
logger = logging.getLogger(__name__)
def start_servers(
    host: str,
    port: int,
    engine_config: str | None,
) -> None:
    """
    Launch the actor server and the engine as supervised child processes.

    Args:
        host: Host for the actor server.
        port: Port for the actor server.
        engine_config: Path to the engine configuration file, or None to let
            `_get_engine_config` resolve it.
    """
    # Check every input before anything is spawned.
    checked_host = _validate_host(host)
    checked_port = _validate_port(port)
    config_path = _get_engine_config(engine_config)

    # Resolve the two command lines, then hand them to the supervisor loop.
    actor_cmd = _build_actor_command(checked_host, checked_port)
    engine_cmd = _build_engine_command(config_path)
    _manage_processes(actor_cmd, engine_cmd)
def _validate_host(host: str) -> str:
"""
Validates the host input.
Args:
host: Host for the actor server.
Returns:
The validated host as a string.
Raises:
ValueError: If the host is invalid.
"""
try:
# Validate IP address
ipaddress.ip_address(host)
except ValueError:
# Optionally, validate hostname
if not host.isalnum() and "-" not in host and "." not in host:
console.print(f"❌ Invalid host: {host}", style="bold red")
raise ValueError("Invalid host.")
return host
def _validate_port(port: int) -> int:
"""
Validates the port input.
Args:
port: Port for the actor server.
Returns:
The validated port as an integer.
Raises:
ValueError: If the port is out of the valid range.
"""
if not (1 <= port <= 65535):
console.print(f"❌ Invalid port: {port}", style="bold red")
raise ValueError("Invalid port.")
return port
def _get_engine_config(engine_config: str | None) -> str:
"""
Determines and validates the engine config file path.
Args:
engine_config: Optional path provided by the user.
Returns:
The resolved engine config file path.
Raises:
RuntimeError: If the config file is not found or invalid.
"""
if engine_config:
engine_config_path = Path(os.path.expanduser(engine_config)).resolve()
if not engine_config_path.is_file():
console.print(
f"❌ Engine config file not found at {engine_config_path}", style="bold red"
)
raise RuntimeError("Engine config file not found.")
else:
# Look for engine.yaml in the current directory
engine_config_path = Path(os.getcwd()) / "engine.yaml"
if not engine_config_path.is_file():
console.print(
"❌ Engine config file not specified and not found in current directory.",
style="bold red",
)
raise RuntimeError("Engine config file not specified.")
return str(engine_config_path)
def _build_actor_command(host: str, port: int) -> list[str]:
    """
    Assemble the argv used to launch the actor server.

    Exits the program with status 1 if the `arcade` executable cannot be
    found on PATH.

    Args:
        host: Host for the actor server.
        port: Port for the actor server.

    Returns:
        The command as a list.
    """
    # Use the absolute path of the "arcade" executable.
    arcade_bin = shutil.which("arcade")
    if arcade_bin:
        return [arcade_bin, "dev", "--host", host, "--port", str(port)]

    console.print(
        "❌ Arcade binary not found, please install with `pip install arcade-ai`",
        style="bold red",
    )
    sys.exit(1)
def _build_engine_command(engine_config: str) -> list[str]:
    """
    Assemble the argv used to launch the engine.

    Exits the program with status 1 if the `engine` executable cannot be
    found on PATH.

    Args:
        engine_config: Path to the engine configuration file.

    Returns:
        The command as a list.
    """
    engine_bin = shutil.which("engine")
    if engine_bin:
        return [engine_bin, "dev", "-c", engine_config]

    console.print(
        "❌ Engine binary not found, refer to the installation guide at "
        "https://docs.arcade-ai.com/docs/home/deployment for how to install the engine",
        style="bold red",
    )
    sys.exit(1)
def _manage_processes(actor_cmd: list[str], engine_cmd: list[str]) -> None:
    """
    Manages the lifecycle of the actor and engine processes.

    Starts the actor, then the engine, and watches both. When either one
    exits (or starting them raises), both are torn down and the pair is
    started again, up to a fixed number of retries. Once the retries are used
    up the program exits with status 1. This function does not return
    normally.

    Args:
        actor_cmd: The command to start the actor server.
        engine_cmd: The command to start the engine.
    """
    actor_process: subprocess.Popen | None = None
    engine_process: subprocess.Popen | None = None

    def terminate_processes(exit_program: bool = False) -> None:
        # Reads the enclosing variables at call time, so it always acts on
        # the most recently started pair.
        console.print("Terminating child processes...", style="bold yellow")
        _terminate_process(actor_process)
        _terminate_process(engine_process)
        if exit_program:
            sys.exit(0)

    _setup_signal_handlers(terminate_processes)

    max_retries = 3  # Maximum number of restarts after the first attempt
    retry_count = 0
    while True:
        failed_with_exception = False
        try:
            # Start the actor server
            console.print("Starting actor server...", style="bold green")
            actor_process = _start_process("Actor", actor_cmd)
            # Wait a bit to ensure actor is up
            time.sleep(2)
            # Start the engine
            console.print("Starting engine...", style="bold green")
            engine_process = _start_process("Engine", engine_cmd)
            # Blocks until one of the two processes has exited.
            _monitor_processes(actor_process, engine_process)
        except Exception as e:
            console.print(f"❌ Exception occurred: {e}", style="bold red")
            terminate_processes()
            failed_with_exception = True

        retry_count += 1
        if retry_count > max_retries:
            reason = " due to exceptions" if failed_with_exception else ""
            console.print(f"❌ Exiting after {max_retries} retries{reason}", style="bold red")
            if not failed_with_exception:
                # The exception path has already torn the children down.
                terminate_processes()
            # Giving up is a failure: report a non-zero status. (Previously
            # the crash path exited with status 0.)
            sys.exit(1)

        if not failed_with_exception:
            console.print(
                f"Processes exited. Retry {retry_count} of {max_retries}.", style="bold yellow"
            )
def _start_process(name: str, cmd: list[str]) -> subprocess.Popen:
    """
    Spawn a child process and attach output streaming to it.

    Args:
        name: Name of the process.
        cmd: Command to execute.

    Returns:
        The subprocess.Popen object.

    Raises:
        RuntimeError: If the process fails to start.
    """
    try:
        # Line-buffered text pipes so output can be relayed line by line.
        child = subprocess.Popen(  # noqa: S603, RUF100
            cmd,
            shell=False,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            bufsize=1,
            universal_newlines=True,
        )
        _stream_output(child, name)
    except Exception as e:
        console.print(f"❌ Failed to start {name}: {e}", style="bold red")
        raise RuntimeError(f"Failed to start {name}")
    return child
def _stream_output(process: subprocess.Popen, name: str) -> None:
    """
    Streams the output from a subprocess to the console.

    One daemon thread is started per pipe (stdout and stderr). Each line is
    printed with a coloured "<name>>" prefix. The child's text is escaped
    before printing so that bracketed content in its output (for example
    "[error]" or "[/INST]") is shown literally instead of being parsed as
    Rich markup, which could hide text or raise and stop the reader thread.

    Args:
        process: The subprocess.Popen object.
        name: Name of the process.
    """
    # Imported locally because this module only imports `rich.console` at
    # the top of the file.
    from rich.markup import escape

    stdout_style = "green" if name == "Actor" else "#87CEFA"

    def stream(pipe: io.TextIOWrapper | None, style: str) -> None:
        if pipe is None:
            return
        with pipe:
            # readline() returns "" at EOF, which ends the iteration.
            for line in iter(pipe.readline, ""):
                console.print(f"[{style}]{name}>[/{style}] {escape(line.rstrip())}")

    threading.Thread(target=stream, args=(process.stdout, stdout_style), daemon=True).start()
    threading.Thread(target=stream, args=(process.stderr, "red"), daemon=True).start()
def _monitor_processes(actor_process: subprocess.Popen, engine_process: subprocess.Popen) -> None:
    """
    Block until the actor or the engine process has exited.

    Both processes are polled once per second. As soon as either has an exit
    code, the exit is reported, both processes are terminated, and the
    function returns so the caller can decide whether to start them again.

    Args:
        actor_process: The actor subprocess.
        engine_process: The engine subprocess.
    """
    while True:
        exit_codes = {
            "Actor": actor_process.poll(),
            "Engine": engine_process.poll(),
        }
        if all(code is None for code in exit_codes.values()):
            # Both still running; check again shortly.
            time.sleep(1)
            continue

        for label, code in exit_codes.items():
            if code is not None:
                console.print(
                    f"{label} process exited with code {code}. Restarting both processes...",
                    style="bold red",
                )
        _terminate_process(actor_process)
        _terminate_process(engine_process)
        time.sleep(1)
        return  # Hand control back so both processes can be restarted
def _terminate_process(process: subprocess.Popen | None) -> None:
"""
Terminates a subprocess if it's running.
Args:
process: The subprocess.Popen object.
"""
if process and process.poll() is None:
process.terminate()
try:
process.wait(timeout=5)
except subprocess.TimeoutExpired:
process.kill()
def _setup_signal_handlers(terminate_processes: Callable[[bool], None]) -> None:
"""
Setup signal handlers to handle process termination signals.
Args:
terminate_processes: Function to call to terminate child processes.
"""
signals_to_handle = ["SIGINT", "SIGTERM", "SIGQUIT", "SIGHUP"]
for sig_name in signals_to_handle:
sig = getattr(signal, sig_name, None)
if sig is None:
continue # Signal not available on this platform
try:
# Use a lambda to pass the terminate_processes function
signal.signal(
sig,
lambda signum, frame: _handle_signal(signum, terminate_processes),
)
except (ValueError, RuntimeError):
# Signal handling not allowed in this thread or invalid signal
console.print(f"Warning: Cannot set handler for {sig_name}", style="bold yellow")
continue
def _handle_signal(signum: int, terminate_processes: Callable[[bool], None]) -> None:
    """
    Report the received signal, then terminate the children and exit.

    Args:
        signum: The signal number received.
        terminate_processes: Function to call to terminate child processes.
    """
    received = signal.Signals(signum)
    console.print(f"Received {received.name}. Shutting down...", style="bold yellow")
    terminate_processes(exit_program=True)  # type: ignore[call-arg]

View file

@ -15,6 +15,7 @@ from rich.table import Table
from rich.text import Text
from arcade.cli.authn import LocalAuthCallbackServer, check_existing_login
from arcade.cli.launcher import start_servers
from arcade.cli.utils import (
OrderCommands,
apply_config_overrides,
@ -28,14 +29,41 @@ from arcade.cli.utils import (
)
from arcade.client import Arcade
from arcade.client.errors import EngineNotHealthyError, EngineOfflineError
from arcade.core.config_model import Config
cli = typer.Typer(
cls=OrderCommands,
add_completion=False,
no_args_is_help=True,
pretty_exceptions_enable=False,
pretty_exceptions_show_locals=False,
pretty_exceptions_short=True,
)
console = Console()
@cli.command(help="Log in to Arcade Cloud")
def _get_config_with_overrides(
    force_tls: bool,
    force_no_tls: bool,
    host_input: str | None = None,
    port_input: int | None = None,
) -> Config:
    """
    Load the validated config and apply the CLI's optional overrides to it.

    The two TLS flags are folded into a single tri-state value: False when
    `force_no_tls` is set (it wins if both flags are given), True when only
    `force_tls` is set, and None when neither is set.
    """
    config = validate_and_get_config()

    tls_input: bool | None
    if force_no_tls:
        tls_input = False
    elif force_tls:
        tls_input = True
    else:
        tls_input = None

    apply_config_overrides(config, host_input, port_input, tls_input)
    return config
@cli.command(help="Log in to Arcade Cloud", rich_help_panel="User")
def login(
host: str = typer.Option(
"cloud.arcade-ai.com",
@ -74,7 +102,7 @@ def login(
server_thread.join() # Ensure the server thread completes and cleans up
@cli.command(help="Log out of Arcade Cloud")
@cli.command(help="Log out of Arcade Cloud", rich_help_panel="User")
def logout() -> None:
"""
Logs the user out of Arcade Cloud.
@ -89,7 +117,7 @@ def logout() -> None:
console.print("You're not logged in.", style="bold red")
@cli.command(help="Create a new toolkit package directory")
@cli.command(help="Create a new toolkit package directory", rich_help_panel="Tool Development")
def new(
directory: str = typer.Option(os.getcwd(), "--dir", help="tools directory path"),
) -> None:
@ -105,7 +133,10 @@ def new(
console.print(error_message, style="bold red")
@cli.command(help="Show the available tools in an actor or toolkit directory")
@cli.command(
help="Show the installed toolkits",
rich_help_panel="Tool Development",
)
def show(
toolkit: Optional[str] = typer.Option(
None, "-t", "--toolkit", help="The toolkit to show the tools of"
@ -139,12 +170,13 @@ def show(
console.print(error_message, style="bold red")
@cli.command(help="Chat with a language model")
@cli.command(help="Start Arcade Chat in the terminal", rich_help_panel="Launch")
def chat(
model: str = typer.Option("gpt-4o", "-m", help="The model to use for prediction."),
stream: bool = typer.Option(
False, "-s", "--stream", is_flag=True, help="Stream the tool output."
),
debug: bool = typer.Option(False, "--debug", "-d", help="Show debug information"),
host: str = typer.Option(
None,
"-h",
@ -167,20 +199,11 @@ def chat(
"--no-tls",
help="Whether to disable TLS for the connection to the Arcade Engine.",
),
debug: bool = typer.Option(False, "--debug", "-d", help="Show debug information"),
) -> None:
"""
Chat with a language model.
"""
config = validate_and_get_config()
if not force_tls and not force_no_tls:
tls_input = None
elif force_no_tls:
tls_input = False
else:
tls_input = True
apply_config_overrides(config, host, port, tls_input)
config = _get_config_with_overrides(force_tls, force_no_tls, host, port)
client = Arcade(api_key=config.api.key, base_url=config.engine_url)
user_email = config.user.email if config.user else None
@ -276,7 +299,7 @@ def chat(
raise typer.Exit()
@cli.command(help="Start an Actor server with specified configurations.")
@cli.command(help="Start a local Arcade Actor server", rich_help_panel="Launch")
def dev(
host: str = typer.Option(
"127.0.0.1", help="Host for the app, from settings by default.", show_default=True
@ -300,7 +323,6 @@ def dev(
try:
serve_default_actor(host, port, disable_auth)
except KeyboardInterrupt:
console.print("actor stopped by user.", style="bold red")
typer.Exit()
except Exception as e:
error_message = f"❌ Failed to start Arcade Actor: {escape(str(e))}"
@ -308,7 +330,7 @@ def dev(
raise typer.Exit(code=1)
@cli.command(help="Show/edit configuration details of the Arcade Engine")
@cli.command(help="Show/edit the local Arcade configuration", rich_help_panel="User")
def config(
action: str = typer.Argument("show", help="The action to take (show/edit)"),
key: str = typer.Option(
@ -396,7 +418,7 @@ def display_config_as_table(config) -> None: # type: ignore[no-untyped-def]
console.print(table)
@cli.command(help="Run evaluation suites in a directory")
@cli.command(help="Run tool calling evaluations", rich_help_panel="Tool Development")
def evals(
directory: str = typer.Argument(".", help="Directory containing evaluation files"),
show_details: bool = typer.Option(False, "--details", "-d", help="Show detailed results"),
@ -409,11 +431,35 @@ def evals(
models: str = typer.Option(
"gpt-4o", "--models", "-m", help="The models to use for evaluation (default: gpt-4o)"
),
host: str = typer.Option(
None,
"-h",
"--host",
help="The Arcade Engine address to send chat requests to.",
),
port: int = typer.Option(
None,
"-p",
"--port",
help="The port of the Arcade Engine.",
),
force_tls: bool = typer.Option(
False,
"--tls",
help="Whether to force TLS for the connection to the Arcade Engine. If not specified, the connection will use TLS if the engine URL uses a 'https' scheme.",
),
force_no_tls: bool = typer.Option(
False,
"--no-tls",
help="Whether to disable TLS for the connection to the Arcade Engine.",
),
) -> None:
"""
Find all files starting with 'eval_' in the given directory,
execute any functions decorated with @tool_eval, and display the results.
"""
config = _get_config_with_overrides(force_tls, force_no_tls, host, port)
models = models.split(",") # type: ignore[assignment]
eval_files = [f for f in os.listdir(directory) if f.startswith("eval_") and f.endswith(".py")]
@ -421,6 +467,18 @@ def evals(
console.print("No evaluation files found.", style="bold yellow")
return
if show_details:
console.print(
Text.assemble(
("\nRunning evaluations against Arcade Engine at ", "bold"),
(config.engine_url, "bold blue"),
)
)
# Try to hit /health endpoint on engine and warn if it is down
client = Arcade(api_key=config.api.key, base_url=config.engine_url)
log_engine_health(client)
for file in eval_files:
file_path = os.path.join(directory, file)
module_name = file[:-3] # Remove .py extension
@ -432,17 +490,47 @@ def evals(
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore[union-attr]
eval_functions = [
eval_suites = [
obj
for name, obj in module.__dict__.items()
if callable(obj) and hasattr(obj, "__tool_eval__")
]
if not eval_functions:
if not eval_suites:
console.print(f"No @tool_eval functions found in {file}", style="bold yellow")
continue
for func in eval_functions:
console.print(f"\nRunning evaluation from {file}: {func.__name__}", style="bold blue")
results = func(models=models, max_concurrency=max_concurrent)
if show_details:
suite_label = "suite" if len(eval_suites) == 1 else "suites"
console.print(f"\nFound {len(eval_suites)} {suite_label} in {file}", style="bold")
for suite_func in eval_suites:
console.print(
Text.assemble(
("\nRunning evaluations in ", "bold"),
(suite_func.__name__, "bold blue"),
)
)
results = suite_func(config=config, models=models, max_concurrency=max_concurrent)
display_eval_results(results, show_details=show_details)
@cli.command(help="Start an Arcade Cluster instance", rich_help_panel="Launch")
def up(
    host: str = typer.Option("127.0.0.1", help="Host for the actor server.", show_default=True),
    port: int = typer.Option(
        8002, "-p", "--port", help="Port for the actor server.", show_default=True
    ),
    engine_config: str = typer.Option(
        None, "-c", "--config", help="Path to the engine configuration file."
    ),
) -> None:
    """
    Start both the actor and engine servers.
    """
    try:
        start_servers(host, port, engine_config)
    except Exception as e:
        # Escape the exception text before it is printed through Rich.
        console.print(f"❌ Failed to start servers: {escape(str(e))}", style="bold red")
        raise typer.Exit(code=1)

View file

@ -1,7 +1,11 @@
import asyncio
import logging
import os
import sys
from contextlib import asynccontextmanager
from typing import Any
from rich.console import Console
from loguru import logger
try:
import fastapi
@ -18,29 +22,73 @@ except ImportError:
from arcade.actor.fastapi.actor import FastAPIActor
from arcade.core.toolkit import Toolkit
DEVELOPMENT_SECRET = "dev" # noqa: S105
logger = logging.getLogger(__name__)
console = Console()
class InterceptHandler(logging.Handler):
    """
    Logging handler that forwards standard-library log records to `logger`.

    NOTE(review): `logger` is presumably the loguru logger imported at the top
    of this module (`.level()` / `.opt()` are used below) — confirm.
    """

    def emit(self, record: logging.LogRecord) -> None:
        """Re-emit *record* through `logger`, preserving level and exception info."""
        # Get corresponding Loguru level if it exists
        try:
            level = logger.level(record.levelname).name
        except ValueError:
            # No level registered under that name: fall back to the numeric level.
            level = record.levelno  # type: ignore[assignment]
        # Find caller from where originated the logged message
        # Start six frames up from here, then keep walking while the frame
        # still belongs to the `logging` module itself.
        frame, depth = sys._getframe(6), 6
        while frame and frame.f_code.co_filename == logging.__file__:
            frame = frame.f_back  # type: ignore[assignment]
            depth += 1
        logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
def setup_logging(log_level: int = logging.INFO) -> None:
    """
    Route all standard-library logging through a single `logger` sink.

    The root logger gets an `InterceptHandler` as its only handler, every
    other known logger has its handlers removed and is set to propagate to
    the root, and `logger` is configured with one stdout sink.

    Args:
        log_level: Minimum level to emit. At DEBUG or lower, the source
            location (name, function, line) is appended to each message.
    """
    # Intercept everything at the root logger
    logging.root.handlers = [InterceptHandler()]
    logging.root.setLevel(log_level)

    # Remove every other logger's handlers
    # and propagate to root logger
    for name in logging.root.manager.loggerDict:
        named_logger = logging.getLogger(name)
        named_logger.handlers = []
        named_logger.propagate = True

    log_format = "{time:MM-DD HH:mm:ss} | {level: <8} | {message}"
    if log_level <= logging.DEBUG:
        log_format += " {name}:{function}:{line}"
    # The original also concatenated a fragment guarded by
    # `"{exception}" in "{message}"`; that compares two string literals, is
    # always False, and therefore always added "". It is dropped here with no
    # change to the resulting format string.
    # NOTE(review): loguru is understood to append exception text to string
    # formats on its own — confirm against the loguru documentation.

    # Configure loguru with custom format, no colors
    logger.configure(
        handlers=[
            {
                "sink": sys.stdout,
                "serialize": False,
                "level": log_level,
                "format": log_format,
            }
        ]
    )
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):  # type: ignore[no-untyped-def]
    """
    Lifespan context manager for the FastAPI app.

    It performs no startup or shutdown work of its own; it exists only to
    absorb the `asyncio.CancelledError` raised into it while the app is
    running.
    """
    try:
        # Control stays here for as long as the context is open.
        yield
    except asyncio.CancelledError:
        # This is necessary to prevent an unhandled error
        # when the user presses Ctrl+C
        logger.debug("Lifespan cancelled.")
def serve_default_actor(
host: str = "127.0.0.1", port: int = 8000, disable_auth: bool = False
host: str = "127.0.0.1",
port: int = 8002,
disable_auth: bool = False,
workers: int = 1,
timeout_keep_alive: int = 5,
**kwargs: Any,
) -> None:
"""
Get an instance of a FastAPI server with the Arcade Actor.
"""
# Use Uvicorn's default log config for Arcade logging,
# to ensure a nice consistent style for all logs.
logging_config = uvicorn.config.LOGGING_CONFIG
logging_config["loggers"]["arcade"] = {
"handlers": ["default"],
"level": "INFO",
"propagate": False,
}
# TODO: Pass in a logging config from the CLI, to set the log level.
logging.config.dictConfig(logging_config)
# Setup unified logging
setup_logging()
toolkits = Toolkit.find_all_arcade_toolkits()
if not toolkits:
@ -56,12 +104,13 @@ def serve_default_actor(
logger.warning(
"Warning: ARCADE_ACTOR_SECRET environment variable is not set. Using 'dev' as the actor secret.",
)
actor_secret = DEVELOPMENT_SECRET
actor_secret = actor_secret or "dev"
app = fastapi.FastAPI(
title="Arcade AI Actor",
description="Arcade AI default Actor implementation using FastAPI.",
version="0.1.0",
lifespan=lifespan, # Use custom lifespan to catch errors, notably KeyboardInterrupt (Ctrl+C)
)
actor = FastAPIActor(app, secret=actor_secret, disable_auth=disable_auth)
for toolkit in toolkits:
@ -69,9 +118,27 @@ def serve_default_actor(
logger.info("Starting FastAPI server...")
uvicorn.run(
class CustomUvicornServer(uvicorn.Server):
def install_signal_handlers(self) -> None:
pass # Disable Uvicorn's default signal handlers
config = uvicorn.Config(
app=app,
host=host,
port=port,
log_config=logging_config,
workers=workers,
timeout_keep_alive=timeout_keep_alive,
log_config=None,
**kwargs,
)
server = CustomUvicornServer(config=config)
async def serve() -> None:
await server.serve()
try:
asyncio.run(serve())
except KeyboardInterrupt:
logger.info("Server stopped by user.")
finally:
logger.debug("Server shutdown complete.")

View file

@ -249,18 +249,21 @@ def _format_evaluation(evaluation: "EvaluationResult") -> str:
A formatted string representation of the evaluation details.
"""
result_lines = []
for critic_result in evaluation.results:
match_color = "green" if critic_result["match"] else "red"
field = critic_result["field"]
score = critic_result["score"]
weight = critic_result["weight"]
expected = critic_result["expected"]
actual = critic_result["actual"]
result_lines.append(
f"[bold]{field}:[/bold] "
f"[{match_color}]Match: {critic_result['match']}, "
f"Score: {score:.2f}/{weight:.2f}[/{match_color}]"
f"\n Expected: {expected}"
f"\n Actual: {actual}"
)
if evaluation.failure_reason:
result_lines.append(f"[bold red]Failure Reason:[/bold red] {evaluation.failure_reason}")
else:
for critic_result in evaluation.results:
match_color = "green" if critic_result["match"] else "red"
field = critic_result["field"]
score = critic_result["score"]
weight = critic_result["weight"]
expected = critic_result["expected"]
actual = critic_result["actual"]
result_lines.append(
f"[bold]{field}:[/bold] "
f"[{match_color}]Match: {critic_result['match']}, "
f"Score: {score:.2f}/{weight:.2f}[/{match_color}]"
f"\n Expected: {expected}"
f"\n Actual: {actual}"
)
return "\n".join(result_lines)

View file

@ -1,4 +1,3 @@
import os
from typing import Any, Generic, TypeVar
from urllib.parse import urljoin
@ -13,19 +12,20 @@ from arcade.client.errors import (
RateLimitError,
UnauthorizedError,
)
from arcade.client.schema import OPENAI_API_VERSION
T = TypeVar("T")
ResponseT = TypeVar("ResponseT")
API_VERSION = "v1"
BASE_URL = "http://localhost:9099"
class BaseResource(Generic[T]):
"""Base class for all resources."""
def __init__(self, client: T):
_path: str
def __init__(self, client: T) -> None:
self._client = client
self._resource_path = self._client._base_url + self._path # type: ignore[attr-defined]
class BaseArcadeClient:
@ -33,7 +33,7 @@ class BaseArcadeClient:
def __init__(
self,
base_url: str = BASE_URL,
base_url: str | None = None,
api_key: str | None = None,
headers: dict[str, str] | None = None,
timeout: float | Timeout = 10.0,
@ -49,8 +49,14 @@ class BaseArcadeClient:
timeout: Request timeout in seconds.
retries: Number of retries for failed requests.
"""
if base_url is None or api_key is None:
from arcade.core.config import config
base_url = base_url or config.engine_url
api_key = api_key or config.api.key
self._base_url = base_url
self._api_key = api_key or os.environ.get("ARCADE_API_KEY")
self._api_key = api_key
self._headers = headers or {}
self._headers.setdefault("Authorization", f"Bearer {self._api_key}")
self._headers.setdefault("Content-Type", "application/json")
@ -65,8 +71,8 @@ class BaseArcadeClient:
def _chat_url(self, base_url: str) -> str:
chat_url = str(base_url)
if not base_url.endswith(API_VERSION):
chat_url = f"{base_url}/{API_VERSION}"
if not base_url.endswith(OPENAI_API_VERSION):
chat_url = f"{base_url}/{OPENAI_API_VERSION}"
return chat_url
def _handle_http_error(self, e: httpx.HTTPStatusError) -> None:
@ -80,7 +86,10 @@ class BaseArcadeClient:
}
status_code = e.response.status_code
error_class = error_map.get(status_code, InternalServerError)
raise error_class(str(e), response=e.response)
msg = e.response.json()
if isinstance(msg, dict) and "error" in msg:
raise error_class(msg["error"], response=e.response) from None
raise error_class(msg, response=e.response) from None
class SyncArcadeClient(BaseArcadeClient):
@ -94,7 +103,7 @@ class SyncArcadeClient(BaseArcadeClient):
timeout=self._timeout,
)
def _request(self, method: str, path: str, **kwargs: Any) -> httpx.Response:
def _request(self, method: str, path: str, **kwargs: Any) -> httpx.Response: # type: ignore[return]
"""
Make a synchronous HTTP request.
"""
@ -104,10 +113,9 @@ class SyncArcadeClient(BaseArcadeClient):
response = self._client.request(method, url, **kwargs)
response.raise_for_status()
return response # noqa: TRY300
except httpx.HTTPStatusError:
except httpx.HTTPStatusError as e:
if attempt == self._retries - 1:
raise
raise RuntimeError("This should never be reached")
self._handle_http_error(e)
def close(self) -> None:
"""Close the client session."""
@ -139,7 +147,7 @@ class AsyncArcadeClient(BaseArcadeClient):
)
return self._client
async def _request(self, method: str, path: str, **kwargs: Any) -> httpx.Response:
async def _request(self, method: str, path: str, **kwargs: Any) -> httpx.Response: # type: ignore[return]
"""
Make an asynchronous HTTP request.
"""
@ -150,10 +158,9 @@ class AsyncArcadeClient(BaseArcadeClient):
response = await client.request(method, url, **kwargs)
response.raise_for_status()
return response # noqa: TRY300
except httpx.HTTPStatusError:
except httpx.HTTPStatusError as e:
if attempt == self._retries - 1:
raise
raise RuntimeError("This should never be reached")
self._handle_http_error(e)
async def close(self) -> None:
"""Close the client session."""

View file

@ -1,11 +1,9 @@
from typing import Any, TypeVar, Union
import httpx
from openai import AsyncOpenAI, OpenAI
from openai.resources.chat import AsyncChat, Chat
from arcade.client.base import (
API_VERSION,
AsyncArcadeClient,
BaseResource,
SyncArcadeClient,
@ -27,7 +25,7 @@ ClientT = TypeVar("ClientT", SyncArcadeClient, AsyncArcadeClient)
class AuthResource(BaseResource[ClientT]):
"""Authentication resource."""
_base_path = f"/{API_VERSION}/auth"
_path = "/auth"
def authorize(
self,
@ -59,7 +57,7 @@ class AuthResource(BaseResource[ClientT]):
data = self._client._execute_request( # type: ignore[attr-defined]
"POST",
f"{self._base_path}/authorize",
f"{self._resource_path}/authorize",
json=body,
)
return AuthResponse(**data)
@ -85,7 +83,7 @@ class AuthResource(BaseResource[ClientT]):
data = self._client._execute_request( # type: ignore[attr-defined]
"GET",
f"{self._base_path}/status",
f"{self._resource_path}/status",
params={"authorizationId": auth_id, "scopes": " ".join(scopes) if scopes else None},
)
return AuthResponse(**data)
@ -94,7 +92,7 @@ class AuthResource(BaseResource[ClientT]):
class ToolResource(BaseResource[ClientT]):
"""Tool resource."""
_base_path = f"/{API_VERSION}/tool"
_path = "/tools"
def run(
self,
@ -119,7 +117,7 @@ class ToolResource(BaseResource[ClientT]):
"inputs": inputs,
}
data = self._client._execute_request( # type: ignore[attr-defined]
"POST", f"{self._base_path}/execute", json=request_data
"POST", f"{self._resource_path}/execute", json=request_data
)
return ExecuteToolResponse(**data)
@ -129,19 +127,21 @@ class ToolResource(BaseResource[ClientT]):
"""
data = self._client._execute_request( # type: ignore[attr-defined]
"GET",
f"{self._base_path}/definition",
f"{self._resource_path}/definition",
params={"directorId": director_id, "toolId": tool_id},
)
return ToolDefinition(**data)
def authorize(self, tool_name: str, user_id: str) -> AuthResponse:
def authorize(
self, tool_name: str, user_id: str, tool_version: str | None = None
) -> AuthResponse:
"""
Get the authorization status for a tool.
"""
data = self._client._execute_request( # type: ignore[attr-defined]
"POST",
f"{self._base_path}/authorize",
json={"tool_name": tool_name, "user_id": user_id},
f"{self._resource_path}/authorize",
json={"tool_name": tool_name, "tool_version": tool_version, "user_id": user_id},
)
return AuthResponse(**data)
@ -149,6 +149,8 @@ class ToolResource(BaseResource[ClientT]):
class HealthResource(BaseResource[ClientT]):
"""Health check resource."""
_path = "/health"
def check(self) -> None:
"""
Check the health of the Arcade Engine.
@ -158,7 +160,7 @@ class HealthResource(BaseResource[ClientT]):
try:
data = self._client._execute_request( # type: ignore[attr-defined]
"GET",
f"/{API_VERSION}/health",
f"{self._resource_path}",
timeout=5,
)
@ -184,7 +186,7 @@ class HealthResource(BaseResource[ClientT]):
class AsyncAuthResource(BaseResource[AsyncArcadeClient]):
"""Asynchronous Authentication resource."""
_base_path = f"/{API_VERSION}/auth"
_path = "/auth"
async def authorize(
self,
@ -210,7 +212,7 @@ class AsyncAuthResource(BaseResource[AsyncArcadeClient]):
data = await self._client._execute_request( # type: ignore[attr-defined]
"POST",
f"{self._base_path}/authorize",
f"{self._resource_path}/authorize",
json=body,
)
return AuthResponse(**data)
@ -236,7 +238,7 @@ class AsyncAuthResource(BaseResource[AsyncArcadeClient]):
data = await self._client._execute_request( # type: ignore[attr-defined]
"GET",
f"{self._base_path}/status",
f"{self._resource_path}/status",
params={"authorizationId": auth_id, "scopes": " ".join(scopes) if scopes else None},
)
return AuthResponse(**data)
@ -245,7 +247,7 @@ class AsyncAuthResource(BaseResource[AsyncArcadeClient]):
class AsyncToolResource(BaseResource[AsyncArcadeClient]):
"""Asynchronous Tool resource."""
_base_path = f"/{API_VERSION}/tools"
_path = "/tools"
async def run(
self,
@ -264,7 +266,7 @@ class AsyncToolResource(BaseResource[AsyncArcadeClient]):
"inputs": inputs,
}
data = await self._client._execute_request( # type: ignore[attr-defined]
"POST", f"{self._base_path}/execute", json=request_data
"POST", f"{self._resource_path}/execute", json=request_data
)
return ExecuteToolResponse(**data)
@ -274,19 +276,21 @@ class AsyncToolResource(BaseResource[AsyncArcadeClient]):
"""
data = await self._client._execute_request( # type: ignore[attr-defined]
"GET",
f"{self._base_path}/definition",
f"{self._resource_path}/definition",
params={"directorId": director_id, "toolId": tool_id},
)
return ToolDefinition(**data)
async def authorize(self, tool_name: str, user_id: str) -> AuthResponse:
async def authorize(
self, tool_name: str, user_id: str, tool_version: str | None = None
) -> AuthResponse:
"""
Get the authorization status for a tool.
"""
data = await self._client._execute_request( # type: ignore[attr-defined]
"POST",
f"{self._base_path}/authorize",
json={"tool_name": tool_name, "user_id": user_id},
f"{self._resource_path}/authorize",
json={"tool_name": tool_name, "tool_version": tool_version, "user_id": user_id},
)
return AuthResponse(**data)
@ -294,6 +298,8 @@ class AsyncToolResource(BaseResource[AsyncArcadeClient]):
class AsyncHealthResource(BaseResource[AsyncArcadeClient]):
"""Asynchronous Health check resource."""
_path = "/health"
async def check(self) -> None:
"""
Check the health of the Arcade Engine.
@ -303,7 +309,7 @@ class AsyncHealthResource(BaseResource[AsyncArcadeClient]):
try:
data = await self._client._execute_request( # type: ignore[attr-defined]
"GET",
f"/{API_VERSION}/health",
f"{self._resource_path}",
timeout=5,
)
@ -332,7 +338,7 @@ class Arcade(SyncArcadeClient):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.auth: AuthResource = AuthResource(self)
self.tool: ToolResource = ToolResource(self)
self.tools: ToolResource = ToolResource(self)
self.health: HealthResource = HealthResource(self)
chat_url = self._chat_url(self._base_url)
self._openai_client = OpenAI(base_url=chat_url, api_key=self._api_key)
@ -345,11 +351,8 @@ class Arcade(SyncArcadeClient):
"""
Execute a synchronous request.
"""
try:
response = self._request(method, url, **kwargs)
return response.json()
except httpx.HTTPStatusError as e:
self._handle_http_error(e)
response = self._request(method, url, **kwargs)
return response.json()
class AsyncArcade(AsyncArcadeClient):
@ -358,7 +361,7 @@ class AsyncArcade(AsyncArcadeClient):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.auth: AsyncAuthResource = AsyncAuthResource(self)
self.tool: AsyncToolResource = AsyncToolResource(self)
self.tools: AsyncToolResource = AsyncToolResource(self)
self.health: AsyncHealthResource = AsyncHealthResource(self)
chat_url = self._chat_url(self._base_url)
self._openai_client = AsyncOpenAI(base_url=chat_url, api_key=self._api_key)
@ -371,8 +374,5 @@ class AsyncArcade(AsyncArcadeClient):
"""
Execute an asynchronous request.
"""
try:
response = await self._request(method, url, **kwargs)
return response.json()
except httpx.HTTPStatusError as e:
self._handle_http_error(e)
response = await self._request(method, url, **kwargs)
return response.json()

View file

@ -1,9 +1,12 @@
import os
from enum import Enum
from pydantic import AnyUrl, BaseModel, Field
from arcade.core.schema import ToolAuthorizationContext, ToolCallOutput
OPENAI_API_VERSION = os.getenv("OPENAI_API_VERSION", "v1")
class AuthProvider(str, Enum):
"""The supported authorization providers."""

View file

@ -48,8 +48,6 @@ from arcade.core.utils import (
from arcade.sdk.annotations import Inferrable
from arcade.sdk.auth import BaseOAuth2, ToolAuthorization
DEFAULT_TOOLKIT_NAME = "Tools"
InnerWireType = Literal["string", "integer", "number", "boolean", "json"]
WireType = Union[InnerWireType, Literal["array"]]
@ -116,7 +114,7 @@ class ToolCatalog(BaseModel):
def add_tool(
self,
tool_func: Callable,
toolkit_or_name: Union[str | None, Toolkit] = None,
toolkit_or_name: Union[str, Toolkit],
module: ModuleType | None = None,
) -> None:
"""
@ -131,9 +129,6 @@ class ToolCatalog(BaseModel):
elif isinstance(toolkit_or_name, str):
toolkit = None
toolkit_name = toolkit_or_name
else:
toolkit = None
toolkit_name = DEFAULT_TOOLKIT_NAME
if not toolkit_name:
raise ValueError("A toolkit name or toolkit must be provided.")
@ -163,6 +158,13 @@ class ToolCatalog(BaseModel):
output_model=output_model,
)
def add_module(self, module: ModuleType) -> None:
    """
    Register every tool found in ``module`` with the catalog.

    The module is wrapped in a Toolkit via ``Toolkit.from_module`` and the
    resulting toolkit is handed to ``add_toolkit``.
    """
    self.add_toolkit(Toolkit.from_module(module))
def add_toolkit(self, toolkit: Toolkit) -> None:
"""
Add the tools from a loaded toolkit to the catalog.
@ -201,6 +203,15 @@ class ToolCatalog(BaseModel):
def get_tool_names(self) -> list[FullyQualifiedName]:
return [tool.definition.get_fully_qualified_name() for tool in self._tools.values()]
def find_tool_by_func(self, func: Callable) -> ToolDefinition:
    """
    Look up the definition of the catalog entry that wraps ``func``.

    Each entry's ``tool`` attribute is compared to ``func`` with ``==``; the
    first match in the catalog's iteration order is returned.

    Raises:
        ValueError: If no entry in the catalog wraps ``func``.
    """
    for materialized in self._tools.values():
        if materialized.tool == func:
            return materialized.definition
    raise ValueError(f"Tool {func} not found in the catalog.")
def get_tool(self, name: FullyQualifiedName) -> MaterializedTool:
"""
Get a tool from the catalog by fully-qualified name and version.

View file

@ -1,16 +1,19 @@
import ipaddress
import os
from pathlib import Path
from typing import Any
from urllib.parse import urlparse
import idna
import toml
from pydantic import BaseModel, ValidationError
from arcade.core.env import settings
from pydantic import BaseModel, ConfigDict, ValidationError
class ApiConfig(BaseModel):
class BaseConfig(BaseModel):
model_config = ConfigDict(extra="ignore")
class ApiConfig(BaseConfig):
"""
Arcade API configuration.
"""
@ -19,9 +22,13 @@ class ApiConfig(BaseModel):
"""
Arcade API key.
"""
version: str = "v1"
"""
Arcade API version.
"""
class UserConfig(BaseModel):
class UserConfig(BaseConfig):
"""
Arcade user configuration.
"""
@ -32,7 +39,7 @@ class UserConfig(BaseModel):
"""
class EngineConfig(BaseModel):
class EngineConfig(BaseConfig):
"""
Arcade Engine configuration.
"""
@ -51,7 +58,7 @@ class EngineConfig(BaseModel):
"""
class Config(BaseModel):
class Config(BaseConfig):
"""
Configuration for Arcade.
"""
@ -79,7 +86,8 @@ class Config(BaseModel):
"""
Get the path to the Arcade configuration directory.
"""
return settings.WORK_DIR if settings.WORK_DIR else Path.home() / ".arcade"
config_path = os.getenv("ARCADE_WORK_DIR") or Path.home() / ".arcade"
return Path(config_path).resolve()
@classmethod
def get_config_file_path(cls) -> Path:
@ -167,14 +175,14 @@ class Config(BaseModel):
if ":" in parsed_host.netloc and not is_ip:
host, existing_port = parsed_host.netloc.rsplit(":", 1)
if existing_port.isdigit():
return f"{protocol}://{parsed_host.netloc}/v1"
return f"{protocol}://{parsed_host.netloc}/{self.api.version}"
if is_fqdn and self.engine.port is None:
return f"{protocol}://{encoded_host}/v1"
return f"{protocol}://{encoded_host}/{self.api.version}"
elif self.engine.port is not None:
return f"{protocol}://{encoded_host}:{self.engine.port}/v1"
return f"{protocol}://{encoded_host}:{self.engine.port}/{self.api.version}"
else:
return f"{protocol}://{encoded_host}/v1"
return f"{protocol}://{encoded_host}/{self.api.version}"
@classmethod
def ensure_config_dir_exists(cls) -> None:

View file

@ -1,20 +0,0 @@
from functools import lru_cache
from pathlib import Path
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    """
    Process-level settings for Arcade.

    Built on pydantic-settings' BaseSettings, with ".env" configured as the
    env file.
    """

    model_config = SettingsConfigDict(env_file=".env")

    # Working directory; defaults to ".arcade" under the user's home directory.
    WORK_DIR: Path = Path.home() / ".arcade"
@lru_cache
def get_settings() -> Settings:
    """
    Return the Settings instance.

    The function takes no arguments and is wrapped in ``lru_cache``, so
    ``Settings()`` is constructed on the first call and the cached object is
    returned on later calls.
    """
    # env_file = os.getenv("ARCADE_ENV_FILE")
    # TODO allow env override
    return Settings()
settings = get_settings()

View file

@ -1,9 +1,11 @@
import os
from dataclasses import dataclass
from typing import Any, Literal, Optional, Union
from pydantic import AnyUrl, BaseModel, Field
TOOL_NAME_SEPARATOR = "."
# allow for custom tool name separator
TOOL_NAME_SEPARATOR = os.getenv("ARCADE_TOOL_NAME_SEPARATOR", ".")
class ValueSchema(BaseModel):

View file

@ -108,14 +108,19 @@ class Toolkit(BaseModel):
@classmethod
def find_all_arcade_toolkits(cls) -> list["Toolkit"]:
"""
Find all installed packages prefixed with 'arcade_' and load them as Toolkits.
Find all installed packages prefixed with 'arcade_' in the current
Python interpreter's environment and load them as Toolkits.
Returns:
List[Toolkit]: A list of Toolkit instances.
"""
import sysconfig
# Get the site-packages directory of the current interpreter
site_packages_dir = sysconfig.get_paths()["purelib"]
arcade_packages = [
dist.metadata["Name"]
for dist in importlib.metadata.distributions()
for dist in importlib.metadata.distributions(path=[site_packages_dir])
if dist.metadata["Name"].startswith("arcade_")
]
return [cls.from_package(package) for package in arcade_packages]

View file

@ -1,21 +1,5 @@
from .eval import (
BinaryCritic,
EvalRubric,
EvalSuite,
ExpectedToolCall,
NumericCritic,
SimilarityCritic,
tool_eval,
)
from .tool import tool
__all__ = [
"tool",
"EvalRubric",
"EvalSuite",
"ExpectedToolCall",
"tool_eval",
"BinaryCritic",
"SimilarityCritic",
"NumericCritic",
]

View file

@ -4,6 +4,7 @@ import json
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable
from arcade.core.config_model import Config
from arcade.core.schema import FullyQualifiedName
try:
@ -11,9 +12,10 @@ try:
from scipy.optimize import linear_sum_assignment
except ImportError:
raise ImportError(
"Use `pip install arcade[evals]` to install the required dependencies for evaluation."
"Use `pip install arcade-ai[evals]` to install the required dependencies for evaluation."
)
from arcade.client.client import Arcade, AsyncArcade
from arcade.sdk.error import WeightError
@ -69,12 +71,15 @@ class EvaluationResult:
passed: Whether the evaluation passed based on the fail_threshold.
warning: Whether the evaluation issued a warning based on the warn_threshold.
results: A list of dictionaries containing the results for each critic.
failure_reason: If the evaluation failed completely due to settings in the rubric,
this field contains the reason for failure.
"""
score: float = 0.0
passed: bool = False
warning: bool = False
results: list[dict[str, Any]] = field(default_factory=list)
failure_reason: str | None = None
@property
def fail(self) -> bool:
@ -120,10 +125,10 @@ class EvaluationResult:
Returns:
The score for the tool selection.
"""
score = weight if expected == actual else 0.0
score = weight if compare_tool_name(expected, actual) else 0.0
self.add(
"tool_selection",
{"match": expected == actual, "score": score},
{"match": compare_tool_name(expected, actual), "score": score},
weight,
expected,
actual,
@ -190,7 +195,10 @@ class EvalCase:
True if tool selection failure should occur, False otherwise.
"""
expected_tools = [tc.name for tc in self.expected_tool_calls]
return self.rubric.fail_on_tool_selection and set(expected_tools) != set(actual_tools)
return self.rubric.fail_on_tool_selection and not all(
compare_tool_name(expected, actual)
for expected, actual in zip(expected_tools, actual_tools)
)
def check_tool_call_quantity_failure(self, actual_count: int) -> bool:
"""
@ -218,17 +226,30 @@ class EvalCase:
evaluation_result = EvaluationResult()
actual_tools = [tool for tool, _ in actual_tool_calls]
if self.check_tool_selection_failure(actual_tools):
evaluation_result.score = 0.0
evaluation_result.passed = False
evaluation_result.warning = False
return evaluation_result
actual_count = len(actual_tool_calls)
if self.check_tool_call_quantity_failure(actual_count):
evaluation_result.score = 0.0
evaluation_result.passed = False
evaluation_result.warning = False
expected_count = len(self.expected_tool_calls)
evaluation_result.failure_reason = (
f"Expected {expected_count} tool call(s), but got {actual_count}"
)
return evaluation_result
# check if no tools should be called and none were called
if not self.expected_tool_calls and not actual_tools:
evaluation_result.score = 1.0
evaluation_result.passed = True
evaluation_result.warning = False
return evaluation_result
if self.check_tool_selection_failure(actual_tools):
evaluation_result.score = 0.0
evaluation_result.passed = False
evaluation_result.warning = False
expected_tools = [tc.name for tc in self.expected_tool_calls]
evaluation_result.failure_reason = f"Tool selection mismatch. Expected tools: {expected_tools}, but got: {actual_tools}"
return evaluation_result
# Create a cost matrix for the assignment problem
@ -422,12 +443,10 @@ class EvalSuite:
max_concurrent: int = 1 # Default to sequential execution
_client: AsyncArcade | Arcade | None = None
def initialize_client(self) -> None:
def initialize_client(self, config: Config) -> None:
"""
Initialize the client instance for the EvalSuite.
"""
from arcade.core.config import config
if self.max_concurrent > 1:
self._client = AsyncArcade(
api_key=config.api.key,
@ -443,7 +462,7 @@ class EvalSuite:
self,
name: str,
user_message: str,
expected_tool_calls: list[ExpectedToolCall],
expected_tool_calls: list[tuple[Callable, dict[str, Any]]],
critics: list["Critic"],
system_message: str | None = None,
rubric: EvalRubric | None = None,
@ -461,11 +480,18 @@ class EvalSuite:
rubric: The evaluation rubric for this case.
additional_messages: Optional list of additional messages for context.
"""
expected = [
ExpectedToolCall(
name=str(self.catalog.find_tool_by_func(func).get_fully_qualified_name()),
args=args,
)
for func, args in expected_tool_calls
]
case = EvalCase(
name=name,
system_message=system_message or self.system_message,
user_message=user_message,
expected_tool_calls=expected_tool_calls,
expected_tool_calls=expected,
rubric=rubric or self.rubric,
critics=critics,
additional_messages=additional_messages or [],
@ -477,7 +503,7 @@ class EvalSuite:
name: str,
user_message: str,
system_message: str | None = None,
expected_tool_calls: list[ExpectedToolCall] | None = None,
expected_tool_calls: list[tuple[Callable, dict[str, Any]]] | None = None,
rubric: EvalRubric | None = None,
critics: list["Critic"] | None = None,
additional_messages: list[dict[str, str]] | None = None,
@ -507,12 +533,22 @@ class EvalSuite:
if additional_messages:
new_additional_messages.extend(additional_messages)
expected = last_case.expected_tool_calls
if expected_tool_calls:
expected = [
ExpectedToolCall(
name=str(self.catalog.find_tool_by_func(func).get_fully_qualified_name()),
args=args,
)
for func, args in expected_tool_calls
]
# Create a new case, copying from the last one and updating fields
new_case = EvalCase(
name=name,
system_message=system_message or last_case.system_message,
user_message=user_message,
expected_tool_calls=expected_tool_calls or last_case.expected_tool_calls,
expected_tool_calls=expected,
rubric=rubric or self.rubric,
critics=critics or last_case.critics.copy(),
additional_messages=new_additional_messages,
@ -570,7 +606,7 @@ class EvalSuite:
return results
def run(self, model: str) -> dict[str, Any]:
def run(self, config: Config, model: str) -> dict[str, Any]:
"""
Run the evaluation suite.
@ -581,7 +617,7 @@ class EvalSuite:
A dictionary containing the evaluation results.
"""
if not self._client:
self.initialize_client()
self.initialize_client(config)
if self.max_concurrent > 1:
# Run asynchronously with concurrency
@ -614,10 +650,26 @@ def get_tool_args(chat_completion: Any) -> list[tuple[str, dict[str, Any]]]:
return tool_args_list
# Translation table that deletes every character treated as a name separator.
_TOOL_NAME_SEPARATOR_TABLE = str.maketrans("", "", "-_. ")


def compare_tool_name(expected: str, actual: str) -> bool:
    """
    Compare two tool names while ignoring separator characters.

    Hyphens, underscores, dots and spaces are stripped from both names before
    they are compared, so ``Google.ListEmails``, ``Google_ListEmails``,
    ``Google-ListEmails`` and ``Google ListEmails`` all refer to the same tool.
    Apart from separators the comparison is exact (case-sensitive).

    Args:
        expected: The tool name the evaluation expects.
        actual: The tool name that was actually produced.

    Returns:
        True if the names are identical once separators are removed.
    """
    # str.translate removes all separators in a single pass per name.
    return expected.translate(_TOOL_NAME_SEPARATOR_TABLE) == actual.translate(
        _TOOL_NAME_SEPARATOR_TABLE
    )
def tool_eval() -> Callable[[Callable], Callable]:
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(
config: Config,
models: list[str],
max_concurrency: int = 1,
) -> list[dict[str, Any]]:
@ -627,7 +679,7 @@ def tool_eval() -> Callable[[Callable], Callable]:
suite.max_concurrent = max_concurrency
results = []
for model in models:
result = suite.run(model)
result = suite.run(config, model)
results.append(result)
return results

View file

@ -15,14 +15,13 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry.dependencies]
python = ">=3.10,<4.0"
pydantic = "^2.7.0"
pydantic-settings = "^2.2.1"
typer = "^0.9.0"
rich = "^13.7.1"
toml = "^0.10.2"
tomlkit = "^0.12.4"
requests = "^2.26.0" # TODO: is this really needed?
openai = "^1.36.0" # TODO: relax to an earlier version that still has what we need
pyjwt = "^2.8.0"
loguru = "^0.7.0"
[tool.poetry.group.fastapi.dependencies]
@ -115,7 +114,9 @@ ignore = [ # TODO work to remove these
# raise from (cli specific)
"B904",
# long message exceptions
"TRY003"
"TRY003",
# subprocess.Popen
"S603",
]
[tool.ruff.format]

View file

@ -68,6 +68,18 @@ HEALTH_CHECK_UNHEALTHY_RESPONSE_DATA = {
}
@pytest.fixture
def test_sync_client():
    """Provide a synchronous Arcade client built with a placeholder base URL and a fake API key."""
    sync_client = Arcade(base_url="http://arcade.example.com", api_key="fake_api_key")
    return sync_client
@pytest.fixture
def test_async_client():
    """Provide an AsyncArcade client built with a placeholder base URL and a fake API key."""
    async_client = AsyncArcade(base_url="http://arcade.example.com", api_key="fake_api_key")
    return async_client
@pytest.fixture
def mock_response():
"""Mock Response object for testing."""
@ -94,7 +106,7 @@ def mock_async_response():
(500, InternalServerError),
],
)
def test_handle_http_error(error_code, expected_error, mock_response):
def test_handle_http_error(test_sync_client, error_code, expected_error, mock_response):
"""Test _handle_http_error method for different error codes."""
mock_response.status_code = error_code
mock_response.json.return_value = {"error": "Test error message"}
@ -103,16 +115,14 @@ def test_handle_http_error(error_code, expected_error, mock_response):
mock_http_error = Mock(spec=HTTPStatusError)
mock_http_error.response = mock_response
client = Arcade(api_key="fake_api_key") # Create an instance of Arcade
with pytest.raises(expected_error):
client._handle_http_error(mock_http_error) # Call the method on the instance
test_sync_client._handle_http_error(mock_http_error) # Call the method on the instance
def test_arcade_auth_authorize(mock_response, monkeypatch):
def test_arcade_auth_authorize(test_sync_client, mock_response, monkeypatch):
"""Test Arcade.auth.authorize method."""
monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: AUTH_RESPONSE_DATA)
client = Arcade(api_key="fake_api_key")
auth_response = client.auth.authorize(
auth_response = test_sync_client.auth.authorize(
provider=AuthProvider.google,
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
user_id="sam@arcade-ai.com",
@ -120,19 +130,17 @@ def test_arcade_auth_authorize(mock_response, monkeypatch):
assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA)
def test_arcade_auth_poll_authorization(mock_response, monkeypatch):
def test_arcade_auth_poll_authorization(test_sync_client, mock_response, monkeypatch):
"""Test Arcade.auth.poll_authorization method."""
monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: AUTH_RESPONSE_DATA)
client = Arcade(api_key="fake_api_key")
auth_response = client.auth.status("auth_123")
auth_response = test_sync_client.auth.status("auth_123")
assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA)
def test_arcade_tool_run(mock_response, monkeypatch):
"""Test Arcade.tool.run method."""
def test_arcade_tool_run(test_sync_client, mock_response, monkeypatch):
"""Test Arcade.tools.run method."""
monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: TOOL_RESPONSE_DATA)
client = Arcade(api_key="fake_api_key")
tool_response = client.tool.run(
tool_response = test_sync_client.tools.run(
tool_name="GetEmails",
user_id="sam@arcade-ai.com",
tool_version="0.1.0",
@ -141,54 +149,51 @@ def test_arcade_tool_run(mock_response, monkeypatch):
assert tool_response == ExecuteToolResponse(**TOOL_RESPONSE_DATA)
def test_arcade_tool_get(mock_response, monkeypatch):
"""Test Arcade.tool.get method."""
def test_arcade_tool_get(test_sync_client, mock_response, monkeypatch):
"""Test Arcade.tools.get method."""
monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: TOOL_DEFINITION_DATA)
client = Arcade(api_key="fake_api_key")
tool_definition = client.tool.get(director_id="default", tool_id="GetEmails")
tool_definition = test_sync_client.tools.get(director_id="default", tool_id="GetEmails")
assert tool_definition == ToolDefinition(**TOOL_DEFINITION_DATA)
def test_arcade_tool_authorize(mock_response, monkeypatch):
"""Test Arcade.tool.authorize method."""
def test_arcade_tool_authorize(test_sync_client, mock_response, monkeypatch):
"""Test Arcade.tools.authorize method."""
monkeypatch.setattr(
Arcade, "_execute_request", lambda *args, **kwargs: TOOL_AUTHORIZE_RESPONSE_DATA
)
client = Arcade(api_key="fake_api_key")
auth_response = client.tool.authorize(tool_name="GetEmails", user_id="sam@arcade-ai.com")
auth_response = test_sync_client.tools.authorize(
tool_name="GetEmails", user_id="sam@arcade-ai.com"
)
assert auth_response == AuthResponse(**TOOL_AUTHORIZE_RESPONSE_DATA)
def test_arcade_health_check(mock_response, monkeypatch):
def test_arcade_health_check(test_sync_client, mock_response, monkeypatch):
"""Test Arcade.health.check method."""
monkeypatch.setattr(
Arcade, "_execute_request", lambda *args, **kwargs: HEALTH_CHECK_HEALTHY_RESPONSE_DATA
)
client = Arcade(api_key="fake_api_key")
client.health.check()
test_sync_client.health.check()
assert True # If no exception is raised, the test passes
def test_arcade_health_check_raises_error(mock_response, monkeypatch):
def test_arcade_health_check_raises_error(test_sync_client, mock_response, monkeypatch):
"""Test Arcade.health.check method."""
monkeypatch.setattr(
Arcade, "_execute_request", lambda *args, **kwargs: HEALTH_CHECK_UNHEALTHY_RESPONSE_DATA
)
client = Arcade(api_key="fake_api_key")
with pytest.raises(EngineNotHealthyError):
client.health.check()
test_sync_client.health.check()
@pytest.mark.asyncio
async def test_async_arcade_auth_authorize(mock_async_response, monkeypatch):
async def test_async_arcade_auth_authorize(test_async_client, mock_async_response, monkeypatch):
"""Test AsyncArcade.auth.authorize method."""
async def mock_execute_request(*args, **kwargs):
return AUTH_RESPONSE_DATA
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
client = AsyncArcade(api_key="fake_api_key")
auth_response = await client.auth.authorize(
auth_response = await test_async_client.auth.authorize(
provider=AuthProvider.google,
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
user_id="sam@arcade-ai.com",
@ -197,28 +202,28 @@ async def test_async_arcade_auth_authorize(mock_async_response, monkeypatch):
@pytest.mark.asyncio
async def test_async_arcade_auth_poll_authorization(mock_async_response, monkeypatch):
async def test_async_arcade_auth_poll_authorization(
test_async_client, mock_async_response, monkeypatch
):
"""Test AsyncArcade.auth.poll_authorization method."""
async def mock_execute_request(*args, **kwargs):
return AUTH_RESPONSE_DATA
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
client = AsyncArcade(api_key="fake_api_key")
auth_response = await client.auth.status("auth_123")
auth_response = await test_async_client.auth.status("auth_123")
assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA)
@pytest.mark.asyncio
async def test_async_arcade_tool_run(mock_async_response, monkeypatch):
"""Test AsyncArcade.tool.run method."""
async def test_async_arcade_tool_run(test_async_client, mock_async_response, monkeypatch):
"""Test AsyncArcade.tools.run method."""
async def mock_execute_request(*args, **kwargs):
return TOOL_RESPONSE_DATA
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
client = AsyncArcade(api_key="fake_api_key")
tool_response = await client.tool.run(
tool_response = await test_async_client.tools.run(
tool_name="GetEmails",
user_id="sam@arcade-ai.com",
tool_version="0.1.0",
@ -228,52 +233,52 @@ async def test_async_arcade_tool_run(mock_async_response, monkeypatch):
@pytest.mark.asyncio
async def test_async_arcade_tool_get(mock_async_response, monkeypatch):
"""Test AsyncArcade.tool.get method."""
async def test_async_arcade_tool_get(test_async_client, mock_async_response, monkeypatch):
"""Test AsyncArcade.tools.get method."""
async def mock_execute_request(*args, **kwargs):
return TOOL_DEFINITION_DATA
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
client = AsyncArcade(api_key="fake_api_key")
tool_definition = await client.tool.get(director_id="default", tool_id="GetEmails")
tool_definition = await test_async_client.tools.get(director_id="default", tool_id="GetEmails")
assert tool_definition == ToolDefinition(**TOOL_DEFINITION_DATA)
@pytest.mark.asyncio
async def test_async_arcade_tool_authorize(mock_async_response, monkeypatch):
"""Test AsyncArcade.tool.authorize method."""
async def test_async_arcade_tool_authorize(test_async_client, mock_async_response, monkeypatch):
"""Test AsyncArcade.tools.authorize method."""
async def mock_execute_request(*args, **kwargs):
return TOOL_AUTHORIZE_RESPONSE_DATA
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
client = AsyncArcade(api_key="fake_api_key")
auth_response = await client.tool.authorize(tool_name="GetEmails", user_id="sam@arcade-ai.com")
auth_response = await test_async_client.tools.authorize(
tool_name="GetEmails", user_id="sam@arcade-ai.com"
)
assert auth_response == AuthResponse(**TOOL_AUTHORIZE_RESPONSE_DATA)
@pytest.mark.asyncio
async def test_async_arcade_health_check(mock_async_response, monkeypatch):
async def test_async_arcade_health_check(test_async_client, mock_async_response, monkeypatch):
"""Test AsyncArcade.health.check method."""
async def mock_execute_request(*args, **kwargs):
return HEALTH_CHECK_HEALTHY_RESPONSE_DATA
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
client = AsyncArcade(api_key="fake_api_key")
await client.health.check()
await test_async_client.health.check()
assert True # If no exception is raised, the test passes
@pytest.mark.asyncio
async def test_async_arcade_health_check_raises_error(mock_async_response, monkeypatch):
async def test_async_arcade_health_check_raises_error(
test_async_client, mock_async_response, monkeypatch
):
"""Test AsyncArcade.health.check method."""
async def mock_execute_request(*args, **kwargs):
return HEALTH_CHECK_UNHEALTHY_RESPONSE_DATA
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
client = AsyncArcade(api_key="fake_api_key")
with pytest.raises(EngineNotHealthyError):
await client.health.check()
await test_async_client.health.check()

View file

@ -14,10 +14,10 @@ def sample_tool() -> str:
return "Hello, world!"
def test_add_tool_with_no_toolkit():
def test_add_tool_with_empty_toolkit_name_raises():
catalog = ToolCatalog()
catalog.add_tool(sample_tool)
assert catalog.get_tool(FullyQualifiedName("SampleTool", "Tools", None)).tool == sample_tool
with pytest.raises(ValueError):
catalog.add_tool(sample_tool, "")
def test_add_tool_with_toolkit_name():

View file

@ -8,7 +8,7 @@ ARG HOST=0.0.0.0
# Set environment variables using the build arguments
ENV PORT=${PORT}
ENV HOST=${HOST}
ENV WORK_DIR=/app
ENV ARCADE_WORK_DIR=/app
# Install system dependencies
RUN apt-get update && apt-get install -y \
@ -45,8 +45,8 @@ WORKDIR /app/toolkits
# Install toolkits from the toolkits directory
RUN set -e; \
for toolkit in ./*; do \
echo "Installing toolkit $toolkit"; \
pip install $toolkit; \
echo "Installing toolkit $toolkit"; \
pip install $toolkit; \
done

View file

@ -1,68 +0,0 @@
import os
from google.oauth2.credentials import Credentials
from langchain_google_community import GmailToolkit
from langchain_google_community.gmail.utils import (
build_resource_service,
)
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
# Step 1: Install required packages
# Run the following in your terminal:
# %pip install -qU langchain-google-community[gmail]
# %pip install -qU langchain-openai
# %pip install -qU langgraph
#
# Step 2: Set environment variables for LangChain and OpenAI API keys
# Uncomment the following lines if you have the LangSmith API key
# os.environ["LANGCHAIN_TRACING_V2"] = "true"
# os.environ["LANGCHAIN_API_KEY"] = getpass.getpass("Enter your LangSmith API key: ")
#
# Step 3 (Option 1): Manually authenticate with Gmail by creating your own Google app and credentials, and handling tokens and OAuth yourself
# credentials = get_gmail_credentials(
# token_file="token.json",
# scopes=["https://mail.google.com/"],
# client_secrets_file="credentials.json",
# )
#
# ----------------- OR -----------------
# Step 3 (Option 2) Use the Arcade SDK to authenticate with Gmail
from arcade.client import Arcade, AuthProvider
client = Arcade(api_key=os.environ["ARCADE_API_KEY"])
challenge = client.auth.authorize(
provider=AuthProvider.google,
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
user_id="example_user_id",
)
if challenge.status != "completed":
print(f"Please visit this URL to authorize: {challenge.auth_url}")
input("Press Enter after you've completed the authorization...")
challenge = client.auth.poll_authorization(challenge)
if challenge.status != "completed":
print("Authorization not completed. Please try again.")
exit(1)
creds = Credentials(challenge.context.token)
api_resource = build_resource_service(credentials=creds)
toolkit = GmailToolkit(api_resource=api_resource)
# Step 4: Get available tools
tools = toolkit.get_tools()
# Step 5: Initialize the LLM and create an agent
llm = ChatOpenAI(model="gpt-4o")
agent_executor = create_react_agent(llm, tools)
# Step 6: Ask the agent to read and summarize the latest emails
example_query = "Read my latest emails to me and summarize them."
events = agent_executor.stream(
{"messages": [("user", example_query)]},
stream_mode="values",
)
for event in events:
event["messages"][-1].pretty_print()

View file

@ -0,0 +1,60 @@
import time # Import time for polling delays
from google.oauth2.credentials import Credentials
from langchain_google_community import GmailToolkit
from langchain_google_community.gmail.utils import (
build_resource_service,
)
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
# Step 1: Install required packages
# Run the following in your terminal:
# %pip install -qU langchain-google-community[gmail]
# %pip install -qU langchain-openai
# %pip install -qU langgraph
from arcade.client import Arcade, AuthProvider
client = Arcade()
# Start the Google authorization flow, requesting read-only Gmail access
auth_response = client.auth.authorize(
provider=AuthProvider.google,
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
user_id="sam@arcade-ai.com",
)
# If authorization is not completed, prompt the user and poll for status
if auth_response.status != "completed":
print(
"Please complete the authorization challenge in your browser before continuing:"
)
print(auth_response.auth_url)
input("Press Enter to continue...")
# Poll for authorization status using the auth polling method
while auth_response.status != "completed":
# Wait before polling again to avoid spamming the server
time.sleep(4)
auth_response = client.auth.status(auth_response)
# Authorization is completed; proceed with obtaining credentials
creds = Credentials(auth_response.context.token)
api_resource = build_resource_service(credentials=creds)
toolkit = GmailToolkit(api_resource=api_resource)
# Step 4: Get available tools
tools = toolkit.get_tools()
# Step 5: Initialize the LLM and create an agent
llm = ChatOpenAI(model="gpt-4o")
agent_executor = create_react_agent(llm, tools)
# Step 6: Ask the agent to read and summarize the latest emails
example_query = "Read my latest emails to me and summarize them."
events = agent_executor.stream(
{"messages": [("user", example_query)]},
stream_mode="values",
)
for event in events:
event["messages"][-1].pretty_print()

View file

@ -0,0 +1,63 @@
import json
import os
from typing import Any, TypedDict
from langgraph.checkpoint.memory import MemorySaver
from langgraph.errors import NodeInterrupt
from langgraph.graph import END, START, StateGraph
from arcade.client import Arcade
client = Arcade(api_key=os.environ["ARCADE_API_KEY"])
class State(TypedDict):
    """Graph state schema: a single ``emails`` entry."""

    # Passed as None when the graph is invoked; step_1 returns the ListEmails
    # run result under this key.
    emails: Any
def step_1(state: State, config) -> State:
    """Authorize the ListEmails tool for the configured user, then run it.

    Args:
        state: Current graph state (not read by this step).
        config: Run configuration; ``config["configurable"]["user_id"]``
            names the user the tool is authorized and run for.

    Returns:
        A state dict whose ``emails`` entry is the result of the tool run.

    Raises:
        NodeInterrupt: If the authorization status is not ``"completed"``;
            the message carries the authorization URL.
    """
    user = config["configurable"]["user_id"]

    # Guard: stop here until the user has finished authorizing the tool.
    auth = client.tools.authorize(tool_name="ListEmails", user_id=user)
    if auth.status != "completed":
        raise NodeInterrupt(f"Please visit this URL to authorize: {auth.auth_url}")

    # Tool inputs are passed as a JSON-encoded string.
    payload = json.dumps({"n_emails": 5})
    emails = client.tools.run(
        tool_name="ListEmails",
        user_id=user,
        tool_version="default",
        inputs=payload,
    )
    return {"emails": emails}
# Single-node graph: START -> step_1 -> END.
builder = StateGraph(State)
builder.add_node("step_1", step_1)
for source, target in ((START, "step_1"), ("step_1", END)):
    builder.add_edge(source, target)

# Compile with an in-memory checkpointer.
graph = builder.compile(checkpointer=MemorySaver())

# Run configuration (thread id plus the user step_1 acts for), and the
# thread-only configuration used to look the thread's state up afterwards.
config = {"configurable": {"thread_id": "2", "user_id": "sam@arcade-ai.com"}}
thread_lookup = {"configurable": {"thread_id": "2"}}

# First pass: step_1 raises NodeInterrupt while authorization is pending.
result = graph.invoke({"emails": None}, config=config)
state = graph.get_state(thread_lookup)
print("interrupted state\n----------")
print(state)
print("----------")

# Block until the user presses Enter, then invoke the graph again.
input()

result = graph.invoke({"emails": None}, config=config)
state = graph.get_state(thread_lookup)
print("final state\n----------")
print(state)
print("----------")

print("final result\n----------")
print(result)
print("----------")

View file

@ -2,7 +2,7 @@ import os
from modal import App, Image, asgi_app
os.environ["WORK_DIR"] = "/root"
os.environ["ARCADE_WORK_DIR"] = "/root"
# Define the FastAPI app
app = App("arcade-ai-actor")

View file

@ -6,7 +6,7 @@ authors = ["Sam Partee <sam@arcade-ai.com>", "Eric Gustin <eric@arcade-ai.com>"]
[tool.poetry.dependencies]
python = "^3.10"
arcade-ai = "*"
arcade-ai = "^0.1.0"
google-api-core = "2.19.1"
google-api-python-client = "2.137.0"
google-auth = "2.32.0"
@ -16,7 +16,7 @@ googleapis-common-protos = "1.63.2"
beautifulsoup4 = "^4.10.0"
[tool.poetry.dev-dependencies]
pytest = "^7.4.0"
pytest = "^8.3.0"
[build-system]
requires = ["poetry-core>=1.0.0"]

View file

@ -1,12 +1,11 @@
from arcade.core.catalog import ToolCatalog
from arcade.core.toolkit import Toolkit
import arcade_math
from arcade_math.tools.arithmetic import add, sqrt
from arcade.core.catalog import ToolCatalog
from arcade.sdk.eval import (
BinaryCritic,
EvalRubric,
EvalSuite,
ExpectedToolCall,
tool_eval,
)
@ -18,11 +17,11 @@ rubric = EvalRubric(
catalog = ToolCatalog()
catalog.add_toolkit(Toolkit.from_module(arcade_math))
catalog.add_module(arcade_math)
@tool_eval()
def arithmetic_eval_suite():
def math_eval_suite():
suite = EvalSuite(
name="Math Tools Evaluation",
system_message="You are an AI assistant with access to math tools. Use them to help the user with their math-related tasks.",
@ -34,9 +33,9 @@ def arithmetic_eval_suite():
name="Add two large numbers",
user_message="Add 12345 and 987654321",
expected_tool_calls=[
ExpectedToolCall(
"Arithmetic_Add",
args={
(
add,
{
"a": 12345,
"b": 987654321,
},
@ -55,7 +54,12 @@ def arithmetic_eval_suite():
name="Take the square root of a large number",
user_message="What is the square root of 3224990521?",
expected_tool_calls=[
ExpectedToolCall("Arithmetic_Sqrt", args={"a": 3224990521})
(
sqrt,
{
"a": 3224990521,
},
)
],
rubric=rubric,
critics=[

View file

@ -7,10 +7,10 @@ authors = ["Nate <nate@arcade-ai.com>"]
[tool.poetry.dependencies]
python = "^3.10"
arcade-ai = "*"
arcade-ai = "^0.1.0"
[tool.poetry.dev-dependencies]
pytest = "^7.4"
pytest = "^8.3.0"
[build-system]
requires = ["poetry-core>=1.0.0"]

View file

@ -0,0 +1,239 @@
import arcade_search
from arcade_search.tools.google import search_google
from arcade.core.catalog import ToolCatalog
from arcade.sdk.eval import (
EvalRubric,
EvalSuite,
NumericCritic,
SimilarityCritic,
tool_eval,
)
# Tool catalog for the suite; the whole arcade_search module is added to it.
catalog = ToolCatalog()
catalog.add_module(arcade_search)

# Rubric handed to the EvalSuite (fail_threshold=0.8, warn_threshold=0.9).
rubric = EvalRubric(fail_threshold=0.8, warn_threshold=0.9)
@tool_eval()
def google_search_eval_suite() -> EvalSuite:
    """Build the evaluation suite for the Google Search tool.

    Every scenario is described once in a table and registered in order with
    ``suite.add_case``.

    Returns:
        The populated ``EvalSuite``.
    """

    def expect(query: str, n_results: int = 5):
        # One expected ``search_google`` call with the given arguments.
        return (search_google, {"query": query, "n_results": n_results})

    def query_only():
        # Critics for cases that are scored on the query text alone.
        return [SimilarityCritic(critic_field="query", weight=1.0)]

    def query_and_count():
        # Critics for cases that are scored on the query and on n_results.
        return [
            SimilarityCritic(critic_field="query", weight=0.7),
            NumericCritic(
                critic_field="n_results",
                weight=0.3,
                value_range=(1, 100),
            ),
        ]

    suite = EvalSuite(
        name="Google Search Tool Evaluation",
        system_message="You are an AI assistant that can perform web searches using the provided tools.",
        catalog=catalog,
        rubric=rubric,
    )

    # (case name, user message, expected tool calls, critic factory).
    # The factory is called once per case so each case gets fresh critics;
    # ``list`` yields the empty critic list.
    scenarios = [
        (
            "Simple search query with default results",
            "Search for 'Climate change effects on polar bears' on Google.",
            [expect("Climate change effects on polar bears")],
            query_only,
        ),
        (
            "Search query with specific number of results",
            "Find the top 3 articles about quantum computing.",
            [expect("articles about quantum computing", 3)],
            query_and_count,
        ),
        (
            "Search query with 'n' results specified in words",
            "Give me five recipes for vegan lasagna.",
            [expect("recipes for vegan lasagna", 5)],
            query_and_count,
        ),
        (
            "Ambiguous number of results",
            "Find articles about climate change impacts 10.",
            [expect("articles about climate change impacts 10")],
            query_only,
        ),
        (
            "Search query with multiple instructions",
            "Search for the latest news on electric cars, and tell me about Tesla's new model.",
            [
                expect("latest news on electric cars"),
                expect("Tesla's new model"),
            ],
            query_only,
        ),
        (
            "Search with stop words and filler words",
            "Could you please search for the best ways to learn French?",
            [expect("best ways to learn French")],
            query_only,
        ),
        (
            "No clear query given",
            "Find it for me.",
            [],
            list,
        ),
        (
            "Search query with special characters",
            "Find me '@OpenAI's latest research papers'",
            [expect("@OpenAI's latest research papers")],
            query_only,
        ),
        (
            "Search query with complex instructions",
            "I need information about the impact of deforestation in the Amazon over the past decade.",
            [expect("impact of deforestation in the Amazon over the past decade")],
            query_only,
        ),
        (
            "Search query in a different language",
            "Busca información sobre la economía de España.",
            [expect("economía de España")],
            query_only,
        ),
        (
            "Search query with numeric data",
            "What was the population of Japan in 2020?",
            [expect("population of Japan in 2020")],
            query_only,
        ),
    ]

    for case_name, message, expected_calls, make_critics in scenarios:
        suite.add_case(
            name=case_name,
            user_message=message,
            expected_tool_calls=expected_calls,
            critics=make_critics(),
        )

    return suite

View file

@ -6,11 +6,11 @@ authors = ["Sam Partee <sam@arcade-ai.com>"]
[tool.poetry.dependencies]
python = "^3.10"
arcade-ai = "*"
arcade-ai = "^0.1.0"
serpapi = "^0.1.5"
[tool.poetry.dev-dependencies]
pytest = "^7.4.0"
pytest = "^8.3.0"
[build-system]
requires = ["poetry-core>=1.0.0"]

View file

@ -1,3 +1,4 @@
import arcade_slack
from arcade_slack.tools.chat import send_dm_to_user, send_message_to_channel
from arcade.core.catalog import ToolCatalog
@ -5,7 +6,6 @@ from arcade.sdk.eval import (
BinaryCritic,
EvalRubric,
EvalSuite,
ExpectedToolCall,
SimilarityCritic,
tool_eval,
)
@ -19,8 +19,7 @@ rubric = EvalRubric(
catalog = ToolCatalog()
# Register the Slack tools
catalog.add_tool(send_dm_to_user)
catalog.add_tool(send_message_to_channel)
catalog.add_module(arcade_slack)
@tool_eval()
@ -38,9 +37,9 @@ def slack_eval_suite() -> EvalSuite:
name="Send DM to user with clear username",
user_message="Send a direct message to johndoe saying 'Hello, can we meet at 3 PM?'",
expected_tool_calls=[
ExpectedToolCall(
name="SendDmToUser",
args={
(
send_dm_to_user,
{
"user_name": "johndoe",
"message": "Hello, can we meet at 3 PM?",
},
@ -56,54 +55,54 @@ def slack_eval_suite() -> EvalSuite:
name="Send DM with ambiguous username",
user_message="Message John about the project deadline",
expected_tool_calls=[
ExpectedToolCall(
name="SendDmToUser",
args={
(
send_dm_to_user,
{
"user_name": "john",
"message": "Hi John, I wanted to check about the project deadline. Can you provide an update?",
},
)
],
critics=[
SimilarityCritic(critic_field="user_name", weight=0.6),
SimilarityCritic(critic_field="message", weight=0.4),
],
)
suite.add_case(
name="Send DM with username in different format",
user_message="DM Jane.Doe to reschedule our meeting",
expected_tool_calls=[
ExpectedToolCall(
name="SendDmToUser",
args={
"user_name": "jane.doe",
"message": "Hi Jane, I need to reschedule our meeting. When are you available?",
},
)
],
critics=[
BinaryCritic(critic_field="user_name", weight=0.6),
SimilarityCritic(critic_field="message", weight=0.4),
],
)
suite.add_case(
name="Send DM with username in different format",
user_message="DM Jane.Doe to reschedule our meeting",
expected_tool_calls=[
(
send_dm_to_user,
{
"user_name": "jane.doe",
"message": "Hi Jane, I need to reschedule our meeting. When are you available?",
},
)
],
critics=[
BinaryCritic(critic_field="user_name", weight=0.5),
SimilarityCritic(critic_field="message", weight=0.5),
],
)
# Send Message to Channel Scenarios
suite.add_case(
name="Send message to channel with clear name",
user_message="Post 'The new feature is now live!' in the #announcements channel",
expected_tool_calls=[
ExpectedToolCall(
name="SendMessageToChannel",
args={
(
send_message_to_channel,
{
"channel_name": "announcements",
"message": "The new feature is now live!",
},
)
],
critics=[
BinaryCritic(critic_field="channel_name", weight=0.6),
SimilarityCritic(critic_field="message", weight=0.4),
BinaryCritic(critic_field="channel_name", weight=0.5),
SimilarityCritic(critic_field="message", weight=0.5),
],
)
@ -111,9 +110,9 @@ def slack_eval_suite() -> EvalSuite:
name="Send message to channel with ambiguous name",
user_message="Inform the engineering team about the upcoming maintenance in the general channel",
expected_tool_calls=[
ExpectedToolCall(
name="SendMessageToChannel",
args={
(
send_message_to_channel,
{
"channel_name": "engineering",
"message": "Attention team: There will be upcoming maintenance. Please save your work and expect some downtime.",
},
@ -130,9 +129,9 @@ def slack_eval_suite() -> EvalSuite:
name="Ambiguous between DM and channel message",
user_message="Send 'Great job on the presentation!' to the team",
expected_tool_calls=[
ExpectedToolCall(
name="SendMessageToChannel",
args={
(
send_message_to_channel,
{
"channel_name": "general",
"message": "Great job on the presentation!",
},
@ -149,25 +148,25 @@ def slack_eval_suite() -> EvalSuite:
name="Multiple recipients in DM request",
user_message="Send a DM to Alice and Bob about pushing the meeting tomorrow. I have to much work to do.",
expected_tool_calls=[
ExpectedToolCall(
name="SendDmToUser",
args={
(
send_dm_to_user,
{
"user_name": "alice",
"message": "Hi Alice, about our meeting tomorrow, let's reschedule? I am swamped with work.",
},
),
ExpectedToolCall(
name="SendDmToUser",
args={
(
send_dm_to_user,
{
"user_name": "bob",
"message": "Hi Bob, about our meeting tomorrow, let's reschedule? I am swamped with work.",
},
),
],
critics=[
SimilarityCritic(critic_field="user_name", weight=0.6),
SimilarityCritic(critic_field="user_name", weight=0.7),
SimilarityCritic(
critic_field="message", weight=0.4, similarity_threshold=0.7
critic_field="message", weight=0.3, similarity_threshold=0.6
),
],
)
@ -176,9 +175,9 @@ def slack_eval_suite() -> EvalSuite:
name="Channel name similar to username",
user_message="Post 'sounds great!' in john-project channel",
expected_tool_calls=[
ExpectedToolCall(
name="SendMessageToChannel",
args={
(
send_message_to_channel,
{
"channel_name": "john-project",
"message": "Sounds great!",
},

View file

@ -10,7 +10,7 @@ arcade-ai = "^0.1.0"
slack-sdk = "^3.31.0"
[tool.poetry.dev-dependencies]
pytest = "^7.4.0"
pytest = "^8.3.0"
[build-system]
requires = ["poetry-core>=1.0.0"]

View file

@ -1,17 +1,16 @@
import arcade_x
from arcade_x.tools.tweets import post_tweet
# TODO
# delete_tweet_by_id,
# search_recent_tweets_by_keywords,
# search_recent_tweets_by_username,
# from arcade_x.tools.users import lookup_single_user_by_username
from arcade.core.catalog import ToolCatalog
from arcade_x.tools.tweets import (
post_tweet,
delete_tweet_by_id,
# search_recent_tweets_by_query,
search_recent_tweets_by_username,
search_recent_tweets_by_keywords,
)
from arcade_x.tools.users import lookup_single_user_by_username
from arcade.sdk.eval import (
BinaryCritic,
EvalRubric,
EvalSuite,
ExpectedToolCall,
SimilarityCritic,
tool_eval,
)
@ -22,11 +21,8 @@ rubric = EvalRubric(
)
catalog = ToolCatalog()
catalog.add_tool(search_recent_tweets_by_keywords)
catalog.add_tool(lookup_single_user_by_username)
catalog.add_tool(post_tweet)
catalog.add_tool(delete_tweet_by_id)
catalog.add_tool(search_recent_tweets_by_username)
# Register the X tools
catalog.add_module(arcade_x)
@tool_eval()
@ -45,17 +41,18 @@ def x_eval_suite() -> EvalSuite:
name="Post a tweet",
user_message="Send out a tweet that says 'Hello World! Exciting stuff is happening over at Arcade AI!'",
expected_tool_calls=[
ExpectedToolCall(
name="PostTweet",
args={
(
post_tweet,
{
"tweet_text": "Hello World! Exciting stuff is happening over at Arcade AI!"
},
)
],
critics=[
BinaryCritic(
SimilarityCritic(
critic_field="tweet_text",
weight=1.0,
similarity_threshold=0.9,
),
],
)