Address alpha release tweaks and bugs (#62)
# Address Alpha Release Tweaks and Bugs

This PR addresses several issues and tweaks identified during the alpha release:

- **Ensure `~/.arcade` directory exists before writing the config file**
  In `arcade/cli/authn.py`, added code to create the `~/.arcade` directory if it doesn't exist. This prevents errors when writing the configuration file during the login process.
- **Fix retry logic in process management**
  In `arcade/cli/launcher.py`, corrected an off-by-one error in the retry logic within the `_manage_processes` function. This ensures that the process management behaves as expected when retries are exhausted.
- **Allow passing environment variables to the engine process** (technically this option isn't exposed yet)
  Updated the `start_servers`, `_manage_processes`, and `_start_process` functions in `arcade/cli/launcher.py` to accept an `engine_env` parameter. This allows custom environment variables to be set for the engine process. Also, set `GIN_MODE` to `"release"` by default.
- **Handle cases with no critics in evaluations**
  Modified the `EvalCase` class in `arcade/sdk/eval/eval.py` to handle scenarios where no critics are provided. This avoids potential errors during the evaluation process when critics are absent. A test for this case should still be added.
- **Adjust dependencies in `pyproject.toml`**
  - Moved `uvicorn` to be an optional dependency and included it in the `fastapi` extra.
  - Removed unnecessary development dependencies (`mkdocs`, `mkdocs-material`, `mkdocstrings`).
  - Ensured that `uvicorn` is updated to version `^0.30.0`.

---------

Co-authored-by: Nate Barbettini <nate@arcade-ai.com>
This commit is contained in:
parent
6b7562f6a2
commit
7d9354b4b4
5 changed files with 58 additions and 22 deletions
|
|
@ -55,6 +55,11 @@ class LoginCallbackHandler(BaseHTTPRequestHandler):
|
|||
)
|
||||
return False
|
||||
|
||||
# ensure the ~/.arcade directory exists
|
||||
# TODO: this should use WORK_DIR from env if set
|
||||
if not os.path.exists(os.path.expanduser("~/.arcade")):
|
||||
os.makedirs(os.path.expanduser("~/.arcade"), exist_ok=True)
|
||||
|
||||
# TODO don't overwrite existing config
|
||||
config_file_path = os.path.expanduser("~/.arcade/arcade.toml")
|
||||
new_config = {"api": {"key": api_key}, "user": {"email": email}}
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ def start_servers(
|
|||
host: str,
|
||||
port: int,
|
||||
engine_config: str | None,
|
||||
engine_env: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Start the actor and engine servers.
|
||||
|
|
@ -42,7 +43,7 @@ def start_servers(
|
|||
engine_cmd = _build_engine_command(engine_config)
|
||||
|
||||
# Start and manage the processes
|
||||
_manage_processes(actor_cmd, engine_cmd)
|
||||
_manage_processes(actor_cmd, engine_cmd, engine_env)
|
||||
|
||||
|
||||
def _validate_host(host: str) -> str:
|
||||
|
|
@ -177,13 +178,16 @@ def _build_engine_command(engine_config: str) -> list[str]:
|
|||
return cmd
|
||||
|
||||
|
||||
def _manage_processes(actor_cmd: list[str], engine_cmd: list[str]) -> None:
|
||||
def _manage_processes(
|
||||
actor_cmd: list[str], engine_cmd: list[str], engine_env: dict[str, str] | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Manages the lifecycle of the actor and engine processes.
|
||||
|
||||
Args:
|
||||
actor_cmd: The command to start the actor server.
|
||||
engine_cmd: The command to start the engine.
|
||||
engine_env: Environment variables to set for the engine.
|
||||
"""
|
||||
actor_process: subprocess.Popen | None = None
|
||||
engine_process: subprocess.Popen | None = None
|
||||
|
|
@ -211,7 +215,7 @@ def _manage_processes(actor_cmd: list[str], engine_cmd: list[str]) -> None:
|
|||
|
||||
# Start the engine
|
||||
console.print("Starting engine...", style="bold green")
|
||||
engine_process = _start_process("Engine", engine_cmd)
|
||||
engine_process = _start_process("Engine", engine_cmd, engine_env)
|
||||
|
||||
# Monitor processes
|
||||
_monitor_processes(actor_process, engine_process)
|
||||
|
|
@ -222,8 +226,8 @@ def _manage_processes(actor_cmd: list[str], engine_cmd: list[str]) -> None:
|
|||
f"Processes exited. Retry {retry_count} of {max_retries}.", style="bold yellow"
|
||||
)
|
||||
|
||||
if retry_count > max_retries:
|
||||
console.print(f"❌ Exiting after {retry_count - 1} retries", style="bold red")
|
||||
if retry_count >= max_retries:
|
||||
console.print(f"❌ Exiting after {max_retries} retries", style="bold red")
|
||||
terminate_processes(exit_program=True)
|
||||
break # Exit the loop
|
||||
|
||||
|
|
@ -243,13 +247,16 @@ def _manage_processes(actor_cmd: list[str], engine_cmd: list[str]) -> None:
|
|||
sys.exit(1)
|
||||
|
||||
|
||||
def _start_process(name: str, cmd: list[str]) -> subprocess.Popen:
|
||||
def _start_process(
|
||||
name: str, cmd: list[str], env: dict[str, str] | None = None
|
||||
) -> subprocess.Popen:
|
||||
"""
|
||||
Starts a subprocess and begins streaming its output.
|
||||
|
||||
Args:
|
||||
name: Name of the process.
|
||||
cmd: Command to execute.
|
||||
env: Environment variables to set for the process.
|
||||
|
||||
Returns:
|
||||
The subprocess.Popen object.
|
||||
|
|
@ -257,9 +264,17 @@ def _start_process(name: str, cmd: list[str]) -> subprocess.Popen:
|
|||
Raises:
|
||||
RuntimeError: If the process fails to start.
|
||||
"""
|
||||
_env = os.environ.copy()
|
||||
if env:
|
||||
_env.update(env)
|
||||
|
||||
# TODO temporary fix for GIN_MODE
|
||||
_env["GIN_MODE"] = "release"
|
||||
|
||||
try:
|
||||
process = subprocess.Popen( # noqa: S603, RUF100
|
||||
cmd,
|
||||
env=_env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
universal_newlines=True,
|
||||
|
|
@ -302,6 +317,7 @@ def _monitor_processes(actor_process: subprocess.Popen, engine_process: subproce
|
|||
actor_process: The actor subprocess.
|
||||
engine_process: The engine subprocess.
|
||||
"""
|
||||
|
||||
while True:
|
||||
actor_status = actor_process.poll()
|
||||
engine_status = engine_process.poll()
|
||||
|
|
|
|||
|
|
@ -529,8 +529,9 @@ def up(
|
|||
Start both the actor and engine servers.
|
||||
"""
|
||||
try:
|
||||
# TODO: pass Engine env vars from here
|
||||
start_servers(host, port, engine_config)
|
||||
except Exception as e:
|
||||
error_message = f"❌ Failed to start servers: {escape(str(e))}"
|
||||
console.print(error_message, style="bold red")
|
||||
raise typer.Exit(code=1)
|
||||
typer.Exit(code=1)
|
||||
|
|
|
|||
|
|
@ -162,12 +162,16 @@ class EvalCase:
|
|||
system_message: str
|
||||
user_message: str
|
||||
expected_tool_calls: list[ExpectedToolCall]
|
||||
critics: list["Critic"]
|
||||
critics: list["Critic"] | None = None
|
||||
additional_messages: list[dict[str, str]] = field(default_factory=list)
|
||||
rubric: EvalRubric = field(default_factory=EvalRubric)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._validate_critics()
|
||||
if self.critics is not None:
|
||||
self._validate_critics()
|
||||
else:
|
||||
# if no critics are provided, set to empty list
|
||||
self.critics = []
|
||||
|
||||
def _validate_critics(self) -> None:
|
||||
"""
|
||||
|
|
@ -176,6 +180,9 @@ class EvalCase:
|
|||
Raises:
|
||||
WeightError: If the sum of critic weights exceeds 1.0.
|
||||
"""
|
||||
if not self.critics:
|
||||
return
|
||||
|
||||
total_weight = sum(critic.weight for critic in self.critics)
|
||||
if total_weight > 1.0:
|
||||
raise WeightError(f"Sum of critic weights must not exceed 1.0, got {total_weight}")
|
||||
|
|
@ -252,6 +259,15 @@ class EvalCase:
|
|||
evaluation_result.failure_reason = f"Tool selection mismatch. Expected tools: {expected_tools}, but got: {actual_tools}"
|
||||
return evaluation_result
|
||||
|
||||
# if no critics for tool call arguments, then return
|
||||
# passing score as only tool selection and quantity is checked
|
||||
if not self.critics or len(self.critics) == 0:
|
||||
evaluation_result.score = 1.0
|
||||
evaluation_result.passed = True
|
||||
evaluation_result.warning = False
|
||||
# TODO passing reason should be added
|
||||
return evaluation_result
|
||||
|
||||
# Create a cost matrix for the assignment problem
|
||||
cost_matrix = self._create_cost_matrix(actual_tool_calls)
|
||||
|
||||
|
|
@ -322,12 +338,13 @@ class EvalCase:
|
|||
if expected.name == actual_tool:
|
||||
score += self.rubric.tool_selection_weight
|
||||
|
||||
for critic in self.critics:
|
||||
expected_value = expected.args.get(critic.critic_field)
|
||||
actual_value = actual_args.get(critic.critic_field)
|
||||
if expected_value is not None and actual_value is not None:
|
||||
result = critic.evaluate(expected_value, actual_value)
|
||||
score += result["score"]
|
||||
if self.critics:
|
||||
for critic in self.critics:
|
||||
expected_value = expected.args.get(critic.critic_field)
|
||||
actual_value = actual_args.get(critic.critic_field)
|
||||
if expected_value is not None and actual_value is not None:
|
||||
result = critic.evaluate(expected_value, actual_value)
|
||||
score += result["score"]
|
||||
cost_matrix[i, j] = score
|
||||
|
||||
return cost_matrix
|
||||
|
|
@ -463,7 +480,7 @@ class EvalSuite:
|
|||
name: str,
|
||||
user_message: str,
|
||||
expected_tool_calls: list[tuple[Callable, dict[str, Any]]],
|
||||
critics: list["Critic"],
|
||||
critics: list["Critic"] | None = None,
|
||||
system_message: str | None = None,
|
||||
rubric: EvalRubric | None = None,
|
||||
additional_messages: list[dict[str, str]] | None = None,
|
||||
|
|
@ -550,7 +567,7 @@ class EvalSuite:
|
|||
user_message=user_message,
|
||||
expected_tool_calls=expected,
|
||||
rubric=rubric or self.rubric,
|
||||
critics=critics or last_case.critics.copy(),
|
||||
critics=critics or (last_case.critics.copy() if last_case.critics else None),
|
||||
additional_messages=new_additional_messages,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -23,12 +23,13 @@ openai = "^1.36.0" # TODO: relax to an earlier version that still has what we ne
|
|||
pyjwt = "^2.8.0"
|
||||
loguru = "^0.7.0"
|
||||
fastapi = {version = "^0.110.0", optional = true}
|
||||
uvicorn = {version = "^0.30.0", optional = true}
|
||||
scipy = {version = "^1.14.0", optional = true}
|
||||
numpy = {version = "^2.0.0", optional = true}
|
||||
scikit-learn = {version = "^1.5.0", optional = true}
|
||||
|
||||
[tool.poetry.extras]
|
||||
fastapi = ["fastapi"]
|
||||
fastapi = ["fastapi", "uvicorn"]
|
||||
evals = ["scipy", "numpy", "scikit-learn"]
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
|
@ -39,10 +40,6 @@ pre-commit = "^3.4.0"
|
|||
tox = "^4.11.1"
|
||||
pytest-asyncio = "^0.23.7"
|
||||
types-toml = "^0.10.8"
|
||||
uvicorn = "^0.22.0"
|
||||
mkdocs = ">=1.5.2"
|
||||
mkdocs-material = ">=9.3.0"
|
||||
mkdocstrings = {extras = ["python"], version = ">=0.23.1"}
|
||||
|
||||
|
||||
[tool.poetry.scripts]
|
||||
|
|
|
|||
Loading…
Reference in a new issue