Fix Github.CountStargazers and add tests (#92)
## Problem I found a bug with `Github.CountStargazers` where a stargazer count of `0` was interpreted as a null result. In other words, 0 wasn't passed back to the Engine correctly. Separately, the tool function was also not authorized correctly. ## Fix - Don't use a falsy comparison when evaluating `result` inside the `ToolOutputFactory` - Add unit tests for `ToolOutputFactory` to give us confidence in the business logic - Added `ToolContext` to pass in the authorization token correctly. Before ``` User (nate@arcade-ai.com): how many stars does the ArcadeAI/Docs repo have on github? Assistant (gpt-4o): I successfully checked the repository, but unfortunately, I cannot provide the number of stars for the ArcadeAI/Docs repository. Please try checking directly on GitHub for the most accurate information. Called tool 'Github_CountStargazers' Parameters:{"owner":"ArcadeAI","name":"Docs"} 'Github_CountStargazers' tool returned:Github.CountStargazers called successfully ``` After ``` User (nate@arcade-ai.com): how many stars does the ArcadeAI/Docs repo have on github? Assistant (gpt-4o): The ArcadeAI/Docs repository on GitHub has 0 stars. Called tool 'Github_CountStargazers' Parameters:{"owner":"ArcadeAI","name":"Docs"} 'Github_CountStargazers' tool returned:0
This commit is contained in:
parent
844403906d
commit
56fc83bf3e
4 changed files with 90 additions and 7 deletions
12
.vscode/launch.json
vendored
12
.vscode/launch.json
vendored
|
|
@ -6,17 +6,23 @@
|
|||
"type": "python",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"args": ["main:app", "--app-dir", "${workspaceFolder}/examples/fastapi/arcade_example_fastapi", "--port", "8002"],
|
||||
"args": [
|
||||
"main:app",
|
||||
"--app-dir",
|
||||
"${workspaceFolder}/examples/fastapi/arcade_example_fastapi",
|
||||
"--port",
|
||||
"8002"
|
||||
],
|
||||
"jinja": true,
|
||||
"justMyCode": true,
|
||||
"cwd": "${workspaceFolder}/examples/fastapi/arcade_example_fastapi"
|
||||
},
|
||||
{
|
||||
"name": "Debug `arcade dev --no-auth`",
|
||||
"name": "Debug `arcade actorup --no-auth`",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/arcade/run_cli.py",
|
||||
"args": ["dev", "--no-auth"],
|
||||
"args": ["actorup", "--no-auth"],
|
||||
"console": "integratedTerminal",
|
||||
"jinja": true,
|
||||
"justMyCode": true,
|
||||
|
|
|
|||
|
|
@ -15,8 +15,7 @@ class ToolOutputFactory:
|
|||
*,
|
||||
data: T | None = None,
|
||||
) -> ToolCallOutput:
|
||||
value = data.result if data and hasattr(data, "result") and data.result else ""
|
||||
|
||||
value = getattr(data, "result", "") if data else ""
|
||||
return ToolCallOutput(value=value)
|
||||
|
||||
def fail(self, *, message: str, developer_message: str | None = None) -> ToolCallOutput:
|
||||
|
|
|
|||
75
arcade/tests/core/test_output.py
Normal file
75
arcade/tests/core/test_output.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from arcade.core.output import ToolOutputFactory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def output_factory():
|
||||
return ToolOutputFactory()
|
||||
|
||||
|
||||
class SampleOutputModel(BaseModel):
|
||||
result: Any
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data, expected_value",
|
||||
[
|
||||
(None, ""),
|
||||
("success", "success"),
|
||||
("", ""),
|
||||
(None, ""),
|
||||
(123, 123),
|
||||
(0, 0),
|
||||
(123.45, 123.45),
|
||||
(True, True),
|
||||
(False, False),
|
||||
],
|
||||
)
|
||||
def test_success(output_factory, data, expected_value):
|
||||
data_obj = SampleOutputModel(result=data) if data is not None else None
|
||||
output = output_factory.success(data=data_obj)
|
||||
assert output.value == expected_value
|
||||
assert output.error is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"message, developer_message",
|
||||
[
|
||||
("Error occurred", None),
|
||||
("Error occurred", "Detailed error message"),
|
||||
],
|
||||
)
|
||||
def test_fail(output_factory, message, developer_message):
|
||||
output = output_factory.fail(message=message, developer_message=developer_message)
|
||||
assert output.error is not None
|
||||
assert output.error.message == message
|
||||
assert output.error.developer_message == developer_message
|
||||
assert output.error.can_retry is False
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"message, developer_message, additional_prompt_content, retry_after_ms",
|
||||
[
|
||||
("Retry error", None, None, None),
|
||||
("Retry error", "Retrying", "Please try again with this additional data: foobar", 1000),
|
||||
],
|
||||
)
|
||||
def test_fail_retry(
|
||||
output_factory, message, developer_message, additional_prompt_content, retry_after_ms
|
||||
):
|
||||
output = output_factory.fail_retry(
|
||||
message=message,
|
||||
developer_message=developer_message,
|
||||
additional_prompt_content=additional_prompt_content,
|
||||
retry_after_ms=retry_after_ms,
|
||||
)
|
||||
assert output.error is not None
|
||||
assert output.error.message == message
|
||||
assert output.error.developer_message == developer_message
|
||||
assert output.error.can_retry is True
|
||||
assert output.error.additional_prompt_content == additional_prompt_content
|
||||
assert output.error.retry_after_ms == retry_after_ms
|
||||
|
|
@ -26,6 +26,7 @@ from arcade_github.tools.utils import (
|
|||
# Example arcade chat usage: "How many stargazers does the <OWNER>/<REPO> repo have?"
|
||||
@tool(requires_auth=GitHub())
|
||||
async def count_stargazers(
|
||||
context: ToolContext,
|
||||
owner: Annotated[str, "The owner of the repository"],
|
||||
name: Annotated[str, "The name of the repository"],
|
||||
) -> Annotated[int, "The number of stargazers (stars) for the specified repository"]:
|
||||
|
|
@ -36,15 +37,17 @@ async def count_stargazers(
|
|||
```
|
||||
"""
|
||||
|
||||
headers = get_github_json_headers(context.authorization.token)
|
||||
|
||||
url = get_url("repo", owner=owner, repo=name)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url)
|
||||
response = await client.get(url, headers=headers)
|
||||
|
||||
handle_github_response(response, url)
|
||||
|
||||
data = response.json()
|
||||
stargazers_count = data.get("stargazers_count", 0)
|
||||
return f"The repository {owner}/{name} has {stargazers_count} stargazers."
|
||||
return stargazers_count
|
||||
|
||||
|
||||
# Implements https://docs.github.com/en/rest/repos/repos?apiVersion=2022-11-28#list-organization-repositories
|
||||
|
|
|
|||
Loading…
Reference in a new issue