From 22f2422aff045211b21c29e50b2a035b87a74cfc Mon Sep 17 00:00:00 2001 From: Eric Gustin <34000337+EricGustin@users.noreply.github.com> Date: Wed, 8 Jan 2025 22:40:45 -0800 Subject: [PATCH] Add progress bar to Evals CLI (#185) Adds a progress bar to the arcade evals CLI command. Displays progress on the number of `@tool_eval` functions that have completed. --- arcade/arcade/cli/main.py | 13 +++++++++---- arcade/pyproject.toml | 1 + 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/arcade/arcade/cli/main.py b/arcade/arcade/cli/main.py index 0ae81956..e92ca210 100644 --- a/arcade/arcade/cli/main.py +++ b/arcade/arcade/cli/main.py @@ -12,6 +12,7 @@ from openai import OpenAI, OpenAIError from rich.console import Console from rich.markup import escape from rich.text import Text +from tqdm import tqdm from arcade.cli.authn import LocalAuthCallbackServer, check_existing_login from arcade.cli.constants import DEFAULT_CLOUD_HOST, DEFAULT_ENGINE_HOST, LOCALHOST @@ -310,7 +311,7 @@ def evals( "gpt-4o", "--models", "-m", - help="The models to use for evaluation (default: gpt-4o)", + help="The models to use for evaluation (default: gpt-4o). Use commas to separate multiple models.", ), host: str = typer.Option( LOCALHOST, @@ -401,10 +402,14 @@ def evals( ) tasks.append(task) - # TODO add a progress bar here + # Track progress and results as suite functions complete + with tqdm(total=len(tasks), desc="Evaluations Progress") as pbar: + results = [] + for f in asyncio.as_completed(tasks): + results.append(await f) + pbar.update(1) + # TODO error handling on each eval - # Wait for all suite functions to complete - results = await asyncio.gather(*tasks) all_evaluations.extend(results) display_eval_results(all_evaluations, show_details=show_details) diff --git a/arcade/pyproject.toml b/arcade/pyproject.toml index a93529ec..5383d304 100644 --- a/arcade/pyproject.toml +++ b/arcade/pyproject.toml @@ -26,6 +26,7 @@ openai = "^1.36.0" # TODO: relax to an earlier version that still has what we ne arcadepy = "~0.2.0" pyjwt = "^2.8.0" loguru = "^0.7.0" +tqdm = "^4.1.0" types-python-dateutil = "2.9.0.20241003" types-pytz = "2024.2.0.20241003" opentelemetry-instrumentation-fastapi = {version = "0.48b0", optional = true}