Address alpha release tweaks and bugs (#62)
# Address Alpha Release Tweaks and Bugs This PR addresses several issues and tweaks identified during the alpha release: - **Ensure `~/.arcade` directory exists before writing the config file** In `arcade/cli/authn.py`, added code to create the `~/.arcade` directory if it doesn't exist. This prevents errors when writing the configuration file during the login process. - **Fix retry logic in process management** In `arcade/cli/launcher.py`, corrected an off-by-one error in the retry logic within the `_manage_processes` function. This ensures that the process management behaves as expected when retries are exhausted. - **Allow passing environment variables to the engine process** (technically this option isn't exposed yet) Updated the `start_servers`, `_manage_processes`, and `_start_process` functions in `arcade/cli/launcher.py` to accept an `engine_env` parameter. This allows custom environment variables to be set for the engine process. Also, set `GIN_MODE` to `"release"` by default. - **Handle cases with no critics in evaluations** Modified the `EvalCase` class in `arcade/sdk/eval/eval.py` to handle scenarios where no critics are provided. This avoids potential errors during the evaluation process when critics are absent. Should add a test for this. - **Adjust dependencies in `pyproject.toml`** - Moved `uvicorn` to be an optional dependency and included it in the `fastapi` extra. - Removed unnecessary development dependencies (`mkdocs`, `mkdocs-material`, `mkdocstrings`). - Ensured that `uvicorn` is updated to version `^0.30.0`. --------- Co-authored-by: Nate Barbettini <nate@arcade-ai.com>
This commit is contained in:
parent
6b7562f6a2
commit
7d9354b4b4
5 changed files with 58 additions and 22 deletions
|
|
@ -55,6 +55,11 @@ class LoginCallbackHandler(BaseHTTPRequestHandler):
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# ensure the ~/.arcade directory exists
|
||||||
|
# TODO: this should use WORK_DIR from env if set
|
||||||
|
if not os.path.exists(os.path.expanduser("~/.arcade")):
|
||||||
|
os.makedirs(os.path.expanduser("~/.arcade"), exist_ok=True)
|
||||||
|
|
||||||
# TODO don't overwrite existing config
|
# TODO don't overwrite existing config
|
||||||
config_file_path = os.path.expanduser("~/.arcade/arcade.toml")
|
config_file_path = os.path.expanduser("~/.arcade/arcade.toml")
|
||||||
new_config = {"api": {"key": api_key}, "user": {"email": email}}
|
new_config = {"api": {"key": api_key}, "user": {"email": email}}
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ def start_servers(
|
||||||
host: str,
|
host: str,
|
||||||
port: int,
|
port: int,
|
||||||
engine_config: str | None,
|
engine_config: str | None,
|
||||||
|
engine_env: dict[str, str] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Start the actor and engine servers.
|
Start the actor and engine servers.
|
||||||
|
|
@ -42,7 +43,7 @@ def start_servers(
|
||||||
engine_cmd = _build_engine_command(engine_config)
|
engine_cmd = _build_engine_command(engine_config)
|
||||||
|
|
||||||
# Start and manage the processes
|
# Start and manage the processes
|
||||||
_manage_processes(actor_cmd, engine_cmd)
|
_manage_processes(actor_cmd, engine_cmd, engine_env)
|
||||||
|
|
||||||
|
|
||||||
def _validate_host(host: str) -> str:
|
def _validate_host(host: str) -> str:
|
||||||
|
|
@ -177,13 +178,16 @@ def _build_engine_command(engine_config: str) -> list[str]:
|
||||||
return cmd
|
return cmd
|
||||||
|
|
||||||
|
|
||||||
def _manage_processes(actor_cmd: list[str], engine_cmd: list[str]) -> None:
|
def _manage_processes(
|
||||||
|
actor_cmd: list[str], engine_cmd: list[str], engine_env: dict[str, str] | None = None
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Manages the lifecycle of the actor and engine processes.
|
Manages the lifecycle of the actor and engine processes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
actor_cmd: The command to start the actor server.
|
actor_cmd: The command to start the actor server.
|
||||||
engine_cmd: The command to start the engine.
|
engine_cmd: The command to start the engine.
|
||||||
|
engine_env: Environment variables to set for the engine.
|
||||||
"""
|
"""
|
||||||
actor_process: subprocess.Popen | None = None
|
actor_process: subprocess.Popen | None = None
|
||||||
engine_process: subprocess.Popen | None = None
|
engine_process: subprocess.Popen | None = None
|
||||||
|
|
@ -211,7 +215,7 @@ def _manage_processes(actor_cmd: list[str], engine_cmd: list[str]) -> None:
|
||||||
|
|
||||||
# Start the engine
|
# Start the engine
|
||||||
console.print("Starting engine...", style="bold green")
|
console.print("Starting engine...", style="bold green")
|
||||||
engine_process = _start_process("Engine", engine_cmd)
|
engine_process = _start_process("Engine", engine_cmd, engine_env)
|
||||||
|
|
||||||
# Monitor processes
|
# Monitor processes
|
||||||
_monitor_processes(actor_process, engine_process)
|
_monitor_processes(actor_process, engine_process)
|
||||||
|
|
@ -222,8 +226,8 @@ def _manage_processes(actor_cmd: list[str], engine_cmd: list[str]) -> None:
|
||||||
f"Processes exited. Retry {retry_count} of {max_retries}.", style="bold yellow"
|
f"Processes exited. Retry {retry_count} of {max_retries}.", style="bold yellow"
|
||||||
)
|
)
|
||||||
|
|
||||||
if retry_count > max_retries:
|
if retry_count >= max_retries:
|
||||||
console.print(f"❌ Exiting after {retry_count - 1} retries", style="bold red")
|
console.print(f"❌ Exiting after {max_retries} retries", style="bold red")
|
||||||
terminate_processes(exit_program=True)
|
terminate_processes(exit_program=True)
|
||||||
break # Exit the loop
|
break # Exit the loop
|
||||||
|
|
||||||
|
|
@ -243,13 +247,16 @@ def _manage_processes(actor_cmd: list[str], engine_cmd: list[str]) -> None:
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def _start_process(name: str, cmd: list[str]) -> subprocess.Popen:
|
def _start_process(
|
||||||
|
name: str, cmd: list[str], env: dict[str, str] | None = None
|
||||||
|
) -> subprocess.Popen:
|
||||||
"""
|
"""
|
||||||
Starts a subprocess and begins streaming its output.
|
Starts a subprocess and begins streaming its output.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: Name of the process.
|
name: Name of the process.
|
||||||
cmd: Command to execute.
|
cmd: Command to execute.
|
||||||
|
env: Environment variables to set for the process.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The subprocess.Popen object.
|
The subprocess.Popen object.
|
||||||
|
|
@ -257,9 +264,17 @@ def _start_process(name: str, cmd: list[str]) -> subprocess.Popen:
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If the process fails to start.
|
RuntimeError: If the process fails to start.
|
||||||
"""
|
"""
|
||||||
|
_env = os.environ.copy()
|
||||||
|
if env:
|
||||||
|
_env.update(env)
|
||||||
|
|
||||||
|
# TODO temporary fix for GIN_MODE
|
||||||
|
_env["GIN_MODE"] = "release"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
process = subprocess.Popen( # noqa: S603, RUF100
|
process = subprocess.Popen( # noqa: S603, RUF100
|
||||||
cmd,
|
cmd,
|
||||||
|
env=_env,
|
||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
stderr=subprocess.PIPE,
|
stderr=subprocess.PIPE,
|
||||||
universal_newlines=True,
|
universal_newlines=True,
|
||||||
|
|
@ -302,6 +317,7 @@ def _monitor_processes(actor_process: subprocess.Popen, engine_process: subproce
|
||||||
actor_process: The actor subprocess.
|
actor_process: The actor subprocess.
|
||||||
engine_process: The engine subprocess.
|
engine_process: The engine subprocess.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
actor_status = actor_process.poll()
|
actor_status = actor_process.poll()
|
||||||
engine_status = engine_process.poll()
|
engine_status = engine_process.poll()
|
||||||
|
|
|
||||||
|
|
@ -529,8 +529,9 @@ def up(
|
||||||
Start both the actor and engine servers.
|
Start both the actor and engine servers.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# TODO: pass Engine env vars from here
|
||||||
start_servers(host, port, engine_config)
|
start_servers(host, port, engine_config)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_message = f"❌ Failed to start servers: {escape(str(e))}"
|
error_message = f"❌ Failed to start servers: {escape(str(e))}"
|
||||||
console.print(error_message, style="bold red")
|
console.print(error_message, style="bold red")
|
||||||
raise typer.Exit(code=1)
|
typer.Exit(code=1)
|
||||||
|
|
|
||||||
|
|
@ -162,12 +162,16 @@ class EvalCase:
|
||||||
system_message: str
|
system_message: str
|
||||||
user_message: str
|
user_message: str
|
||||||
expected_tool_calls: list[ExpectedToolCall]
|
expected_tool_calls: list[ExpectedToolCall]
|
||||||
critics: list["Critic"]
|
critics: list["Critic"] | None = None
|
||||||
additional_messages: list[dict[str, str]] = field(default_factory=list)
|
additional_messages: list[dict[str, str]] = field(default_factory=list)
|
||||||
rubric: EvalRubric = field(default_factory=EvalRubric)
|
rubric: EvalRubric = field(default_factory=EvalRubric)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
self._validate_critics()
|
if self.critics is not None:
|
||||||
|
self._validate_critics()
|
||||||
|
else:
|
||||||
|
# if no critics are provided, set to empty list
|
||||||
|
self.critics = []
|
||||||
|
|
||||||
def _validate_critics(self) -> None:
|
def _validate_critics(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -176,6 +180,9 @@ class EvalCase:
|
||||||
Raises:
|
Raises:
|
||||||
WeightError: If the sum of critic weights exceeds 1.0.
|
WeightError: If the sum of critic weights exceeds 1.0.
|
||||||
"""
|
"""
|
||||||
|
if not self.critics:
|
||||||
|
return
|
||||||
|
|
||||||
total_weight = sum(critic.weight for critic in self.critics)
|
total_weight = sum(critic.weight for critic in self.critics)
|
||||||
if total_weight > 1.0:
|
if total_weight > 1.0:
|
||||||
raise WeightError(f"Sum of critic weights must not exceed 1.0, got {total_weight}")
|
raise WeightError(f"Sum of critic weights must not exceed 1.0, got {total_weight}")
|
||||||
|
|
@ -252,6 +259,15 @@ class EvalCase:
|
||||||
evaluation_result.failure_reason = f"Tool selection mismatch. Expected tools: {expected_tools}, but got: {actual_tools}"
|
evaluation_result.failure_reason = f"Tool selection mismatch. Expected tools: {expected_tools}, but got: {actual_tools}"
|
||||||
return evaluation_result
|
return evaluation_result
|
||||||
|
|
||||||
|
# if no critics for tool call arguments, then return
|
||||||
|
# passing score as only tool selection and quantity is checked
|
||||||
|
if not self.critics or len(self.critics) == 0:
|
||||||
|
evaluation_result.score = 1.0
|
||||||
|
evaluation_result.passed = True
|
||||||
|
evaluation_result.warning = False
|
||||||
|
# TODO passing reason should be added
|
||||||
|
return evaluation_result
|
||||||
|
|
||||||
# Create a cost matrix for the assignment problem
|
# Create a cost matrix for the assignment problem
|
||||||
cost_matrix = self._create_cost_matrix(actual_tool_calls)
|
cost_matrix = self._create_cost_matrix(actual_tool_calls)
|
||||||
|
|
||||||
|
|
@ -322,12 +338,13 @@ class EvalCase:
|
||||||
if expected.name == actual_tool:
|
if expected.name == actual_tool:
|
||||||
score += self.rubric.tool_selection_weight
|
score += self.rubric.tool_selection_weight
|
||||||
|
|
||||||
for critic in self.critics:
|
if self.critics:
|
||||||
expected_value = expected.args.get(critic.critic_field)
|
for critic in self.critics:
|
||||||
actual_value = actual_args.get(critic.critic_field)
|
expected_value = expected.args.get(critic.critic_field)
|
||||||
if expected_value is not None and actual_value is not None:
|
actual_value = actual_args.get(critic.critic_field)
|
||||||
result = critic.evaluate(expected_value, actual_value)
|
if expected_value is not None and actual_value is not None:
|
||||||
score += result["score"]
|
result = critic.evaluate(expected_value, actual_value)
|
||||||
|
score += result["score"]
|
||||||
cost_matrix[i, j] = score
|
cost_matrix[i, j] = score
|
||||||
|
|
||||||
return cost_matrix
|
return cost_matrix
|
||||||
|
|
@ -463,7 +480,7 @@ class EvalSuite:
|
||||||
name: str,
|
name: str,
|
||||||
user_message: str,
|
user_message: str,
|
||||||
expected_tool_calls: list[tuple[Callable, dict[str, Any]]],
|
expected_tool_calls: list[tuple[Callable, dict[str, Any]]],
|
||||||
critics: list["Critic"],
|
critics: list["Critic"] | None = None,
|
||||||
system_message: str | None = None,
|
system_message: str | None = None,
|
||||||
rubric: EvalRubric | None = None,
|
rubric: EvalRubric | None = None,
|
||||||
additional_messages: list[dict[str, str]] | None = None,
|
additional_messages: list[dict[str, str]] | None = None,
|
||||||
|
|
@ -550,7 +567,7 @@ class EvalSuite:
|
||||||
user_message=user_message,
|
user_message=user_message,
|
||||||
expected_tool_calls=expected,
|
expected_tool_calls=expected,
|
||||||
rubric=rubric or self.rubric,
|
rubric=rubric or self.rubric,
|
||||||
critics=critics or last_case.critics.copy(),
|
critics=critics or (last_case.critics.copy() if last_case.critics else None),
|
||||||
additional_messages=new_additional_messages,
|
additional_messages=new_additional_messages,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,12 +23,13 @@ openai = "^1.36.0" # TODO: relax to an earlier version that still has what we ne
|
||||||
pyjwt = "^2.8.0"
|
pyjwt = "^2.8.0"
|
||||||
loguru = "^0.7.0"
|
loguru = "^0.7.0"
|
||||||
fastapi = {version = "^0.110.0", optional = true}
|
fastapi = {version = "^0.110.0", optional = true}
|
||||||
|
uvicorn = {version = "^0.30.0", optional = true}
|
||||||
scipy = {version = "^1.14.0", optional = true}
|
scipy = {version = "^1.14.0", optional = true}
|
||||||
numpy = {version = "^2.0.0", optional = true}
|
numpy = {version = "^2.0.0", optional = true}
|
||||||
scikit-learn = {version = "^1.5.0", optional = true}
|
scikit-learn = {version = "^1.5.0", optional = true}
|
||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
fastapi = ["fastapi"]
|
fastapi = ["fastapi", "uvicorn"]
|
||||||
evals = ["scipy", "numpy", "scikit-learn"]
|
evals = ["scipy", "numpy", "scikit-learn"]
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
|
|
@ -39,10 +40,6 @@ pre-commit = "^3.4.0"
|
||||||
tox = "^4.11.1"
|
tox = "^4.11.1"
|
||||||
pytest-asyncio = "^0.23.7"
|
pytest-asyncio = "^0.23.7"
|
||||||
types-toml = "^0.10.8"
|
types-toml = "^0.10.8"
|
||||||
uvicorn = "^0.22.0"
|
|
||||||
mkdocs = ">=1.5.2"
|
|
||||||
mkdocs-material = ">=9.3.0"
|
|
||||||
mkdocstrings = {extras = ["python"], version = ">=0.23.1"}
|
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.scripts]
|
[tool.poetry.scripts]
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue