diff --git a/Makefile b/Makefile index 4854e506..3f4cb0c1 100644 --- a/Makefile +++ b/Makefile @@ -25,6 +25,19 @@ check: ## Run code quality tools. @echo "🚀 Static type checking: Running mypy" @cd arcade && poetry run mypy $(git ls-files '*.py') + +.PHONY: check-toolkits +check-toolkits: ## Run code quality tools for each toolkit that has a Makefile + @echo "🚀 Running 'make check' in each toolkit with a Makefile" + @for dir in toolkits/*/ ; do \ + if [ -f "$$dir/Makefile" ]; then \ + echo "🛠️ Checking toolkit $$dir"; \ + (cd "$$dir" && make check); \ + else \ + echo "🛠️ Skipping toolkit $$dir (no Makefile found)"; \ + fi; \ + done + .PHONY: test test: ## Test the code with pytest @echo "🚀 Testing code: Running pytest" @@ -144,4 +157,6 @@ help: @echo "🛠️ Arcade AI Dev Commands:\n" @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + + .DEFAULT_GOAL := help diff --git a/toolkits/x/tests/test_x.py b/arcade/arcade/py.typed similarity index 100% rename from toolkits/x/tests/test_x.py rename to arcade/arcade/py.typed diff --git a/toolkits/x/.pre-commit-config.yaml b/toolkits/x/.pre-commit-config.yaml new file mode 100644 index 00000000..3953e996 --- /dev/null +++ b/toolkits/x/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +files: ^./ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.4.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/toolkits/x/.ruff.toml b/toolkits/x/.ruff.toml new file mode 100644 index 00000000..36a7a4ed --- /dev/null +++ b/toolkits/x/.ruff.toml @@ -0,0 +1,44 @@ +target-version = "py39" +line-length = 100 +fix = true + +[lint] +select = [ + # flake8-2020 + "YTT", + # flake8-bandit + "S", + # 
flake8-bugbear + "B", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-simplify + "SIM", + # isort + "I", + # mccabe + "C90", + # pycodestyle + "E", "W", + # pyflakes + "F", + # pygrep-hooks + "PGH", + # pyupgrade + "UP", + # ruff + "RUF", + # tryceratops + "TRY", +] + +[lint.per-file-ignores] +"**/tests/*" = ["S101"] + +[format] +preview = true +skip-magic-trailing-comma = false diff --git a/toolkits/x/LICENSE b/toolkits/x/LICENSE new file mode 100644 index 00000000..52cd087e --- /dev/null +++ b/toolkits/x/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024, arcadeai + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/toolkits/x/Makefile b/toolkits/x/Makefile new file mode 100644 index 00000000..8602ccaf --- /dev/null +++ b/toolkits/x/Makefile @@ -0,0 +1,53 @@ +.PHONY: help + +help: + @echo "🛠️ x Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the poetry environment and install the pre-commit hooks + @echo "📦 Checking if Poetry is installed" + @if ! command -v poetry &> /dev/null; then \ + echo "📦 Installing Poetry with pip"; \ + pip install poetry; \ + else \ + echo "📦 Poetry is already installed"; \ + fi + @echo "🚀 Installing package in development mode with all extras" + poetry install --all-extras + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + poetry build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @poetry run pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + coverage report + @echo "Generating coverage report" + coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file + @echo "🚀 Bumping version in pyproject.toml" + poetry version patch + +.PHONY: check +check: ## Run code quality tools. 
+ @echo "🚀 Checking Poetry lock file consistency with 'pyproject.toml': Running poetry check --lock" + @poetry check --lock + @echo "🚀 Linting code: Running pre-commit" + @poetry run pre-commit run -a + @echo "🚀 Static type checking: Running mypy" + @poetry run mypy --config-file=pyproject.toml diff --git a/toolkits/x/arcade_x/tools/tweets.py b/toolkits/x/arcade_x/tools/tweets.py index 93d8e149..cc2155b8 100644 --- a/toolkits/x/arcade_x/tools/tweets.py +++ b/toolkits/x/arcade_x/tools/tweets.py @@ -1,12 +1,13 @@ -from typing import Annotated +from typing import Annotated, Any import httpx - from arcade.sdk import ToolContext, tool from arcade.sdk.auth import X -from arcade.sdk.errors import ToolExecutionError +from arcade.sdk.errors import RetryableToolError + from arcade_x.tools.utils import ( expand_urls_in_tweets, + get_headers_with_token, get_tweet_url, parse_search_recent_tweets_response, ) @@ -14,7 +15,10 @@ from arcade_x.tools.utils import ( TWEETS_URL = "https://api.x.com/2/tweets" -# Manage Tweets Tools. See developer docs for additional available parameters: https://developer.x.com/en/docs/x-api/tweets/manage-tweets/api-reference +# Manage Tweets Tools. See developer docs for additional available parameters: +# https://developer.x.com/en/docs/x-api/tweets/manage-tweets/api-reference + + @tool( requires_auth=X( scopes=["tweet.read", "tweet.write", "users.read"], @@ -26,19 +30,12 @@ async def post_tweet( ) -> Annotated[str, "Success string and the URL of the tweet"]: """Post a tweet to X (Twitter).""" - headers = { - "Authorization": f"Bearer {context.authorization.token}", - "Content-Type": "application/json", - } + headers = get_headers_with_token(context) payload = {"text": tweet_text} async with httpx.AsyncClient() as client: response = await client.post(TWEETS_URL, headers=headers, json=payload, timeout=10) - - if response.status_code != 201: - raise ToolExecutionError( - f"Failed to post a tweet during execution of '{post_tweet.__name__}' tool. 
Request returned an error: {response.status_code} {response.text}" - ) + response.raise_for_status() tweet_id = response.json()["data"]["id"] return f"Tweet with id {tweet_id} posted successfully. URL: {get_tweet_url(tweet_id)}" @@ -51,16 +48,12 @@ async def delete_tweet_by_id( ) -> Annotated[str, "Success string confirming the tweet deletion"]: """Delete a tweet on X (Twitter).""" - headers = {"Authorization": f"Bearer {context.authorization.token}"} + headers = get_headers_with_token(context) url = f"{TWEETS_URL}/{tweet_id}" async with httpx.AsyncClient() as client: response = await client.delete(url, headers=headers, timeout=10) - - if response.status_code != 200: - raise ToolExecutionError( - f"Failed to delete the tweet during execution of '{delete_tweet_by_id.__name__}' tool. Request returned an error: {response.status_code} {response.text}" - ) + response.raise_for_status() return f"Tweet with id {tweet_id} deleted successfully." @@ -72,33 +65,33 @@ async def search_recent_tweets_by_username( max_results: Annotated[ int, "The maximum number of results to return. Cannot be less than 10" ] = 10, -) -> Annotated[dict, "Dictionary containing the search results"]: - """Search for recent tweets (last 7 days) on X (Twitter) by username. Includes replies and reposts.""" +) -> Annotated[dict[str, Any], "Dictionary containing the search results"]: + """Search for recent tweets (last 7 days) on X (Twitter) by username. + Includes replies and reposts.""" - headers = { - "Authorization": f"Bearer {context.authorization.token}", - "Content-Type": "application/json", - } - params = { + headers = get_headers_with_token(context) + params: dict[str, int | str] = { "query": f"from:{username}", "max_results": max(max_results, 10), # X API does not allow 'max_results' less than 10 } - url = "https://api.x.com/2/tweets/search/recent?expansions=author_id&user.fields=id,name,username,entities&tweet.fields=entities" + url = ( + "https://api.x.com/2/tweets/search/recent?" 
+ "expansions=author_id&user.fields=id,name,username,entities&tweet.fields=entities" + ) async with httpx.AsyncClient() as client: response = await client.get(url, headers=headers, params=params, timeout=10) + response.raise_for_status() - if response.status_code != 200: - raise ToolExecutionError( - f"Failed to search recent tweets during execution of '{search_recent_tweets_by_username.__name__}' tool. Request returned an error: {response.status_code} {response.text}" - ) + response_data: dict[str, Any] = response.json() - response_data = response.json() + # Expand the URLs that are in the tweets + response_data["data"] = expand_urls_in_tweets( + response_data.get("data", []), delete_entities=True + ) - # Expand the urls that are in the tweets - expand_urls_in_tweets(response_data.get("data", []), delete_entities=True) - - parse_search_recent_tweets_response(response_data) + # Parse the response data + response_data = parse_search_recent_tweets_response(response_data) return response_data @@ -115,44 +108,81 @@ async def search_recent_tweets_by_keywords( max_results: Annotated[ int, "The maximum number of results to return. Cannot be less than 10" ] = 10, -) -> Annotated[dict, "Dictionary containing the search results"]: +) -> Annotated[dict[str, Any], "Dictionary containing the search results"]: """ - Search for recent tweets (last 7 days) on X (Twitter) by required keywords and phrases. Includes replies and reposts - One of the following input parametersMUST be provided: keywords, phrases + Search for recent tweets (last 7 days) on X (Twitter) by required keywords and phrases. + Includes replies and reposts. + One of the following input parameters MUST be provided: keywords, phrases """ if not any([keywords, phrases]): - raise ValueError( - "At least one of keywords or phrases must be provided to the '{search_recent_tweets_by_keywords.__name__}' tool." 
+ raise RetryableToolError( # noqa: TRY003 + "No keywords or phrases provided", + developer_message="Predicted inputs didn't contain any keywords or phrases", + additional_prompt_content="Please provide at least one keyword or phrase for search", + retry_after_ms=500, # Play nice with X API rate limits ) - headers = { - "Authorization": f"Bearer {context.authorization.token}", - "Content-Type": "application/json", - } + headers = get_headers_with_token(context) + query = "".join([f'"{phrase}" ' for phrase in (phrases or [])]) if keywords: query += " ".join(keywords or []) - params = { - "query": query, + params: dict[str, int | str] = { + "query": query.strip(), "max_results": max(max_results, 10), # X API does not allow 'max_results' less than 10 } - url = "https://api.x.com/2/tweets/search/recent?expansions=author_id&user.fields=id,name,username,entities&tweet.fields=entities" + url = ( + "https://api.x.com/2/tweets/search/recent?" + "expansions=author_id&user.fields=id,name,username,entities&tweet.fields=entities" + ) async with httpx.AsyncClient() as client: response = await client.get(url, headers=headers, params=params, timeout=10) + response.raise_for_status() - if response.status_code != 200: - raise ToolExecutionError( - f"Failed to search recent tweets during execution of '{search_recent_tweets_by_keywords.__name__}' tool. 
Request returned an error: {response.status_code} {response.text}" - ) + response_data: dict[str, Any] = response.json() - response_data = response.json() + # Expand the URLs that are in the tweets + response_data["data"] = expand_urls_in_tweets( + response_data.get("data", []), delete_entities=True + ) - # Expand the urls that are in the tweets - expand_urls_in_tweets(response_data.get("data", []), delete_entities=True) - - parse_search_recent_tweets_response(response_data) + # Parse the response data + response_data = parse_search_recent_tweets_response(response_data) + + return response_data + + +@tool(requires_auth=X(scopes=["tweet.read", "users.read"])) +async def lookup_tweet_by_id( + context: ToolContext, + tweet_id: Annotated[str, "The ID of the tweet you want to look up"], +) -> Annotated[dict[str, Any], "Dictionary containing the tweet data"]: + """Look up a tweet on X (Twitter) by tweet ID.""" + + headers = get_headers_with_token(context) + params = { + "expansions": "author_id", + "user.fields": "id,name,username,entities", + "tweet.fields": "entities", + } + url = f"{TWEETS_URL}/{tweet_id}" + + async with httpx.AsyncClient() as client: + response = await client.get(url, headers=headers, params=params, timeout=10) + response.raise_for_status() + + response_data: dict[str, Any] = response.json() + + # Get the tweet data + tweet_data = response_data.get("data") + if tweet_data: + # Expand the URLs that are in the tweet + expanded_tweet_list = expand_urls_in_tweets([tweet_data], delete_entities=True) + response_data["data"] = expanded_tweet_list[0] + else: + response_data["data"] = {} return response_data diff --git a/toolkits/x/arcade_x/tools/users.py b/toolkits/x/arcade_x/tools/users.py index b235f046..82071d72 100644 --- a/toolkits/x/arcade_x/tools/users.py +++ b/toolkits/x/arcade_x/tools/users.py @@ -1,14 +1,20 @@ from typing import Annotated import httpx - from arcade.sdk import ToolContext, tool from arcade.sdk.auth import X -from arcade.sdk.errors 
import ToolExecutionError -from arcade_x.tools.utils import expand_urls_in_user_description, expand_urls_in_user_url +from arcade.sdk.errors import RetryableToolError + +from arcade_x.tools.utils import ( + expand_urls_in_user_description, + expand_urls_in_user_url, + get_headers_with_token, +) + +# Users Lookup Tools. See developer docs for additional available query parameters: +# https://developer.x.com/en/docs/x-api/users/lookup/api-reference -# Users Lookup Tools. See developer docs for additional available query parameters: https://developer.x.com/en/docs/x-api/users/lookup/api-reference @tool(requires_auth=X(scopes=["users.read", "tweet.read"])) async def lookup_single_user_by_username( context: ToolContext, @@ -16,49 +22,43 @@ async def lookup_single_user_by_username( ) -> Annotated[dict, "User information including id, name, username, and description"]: """Look up a user on X (Twitter) by their username.""" - headers = { - "Authorization": f"Bearer {context.authorization.token}", - } - url = f"https://api.x.com/2/users/by/username/{username}?user.fields=created_at,description,id,location,most_recent_tweet_id,name,pinned_tweet_id,profile_image_url,protected,public_metrics,url,username,verified,verified_type,withheld,entities" + headers = get_headers_with_token(context) + + user_fields = ",".join([ + "created_at", + "description", + "id", + "location", + "most_recent_tweet_id", + "name", + "pinned_tweet_id", + "profile_image_url", + "protected", + "public_metrics", + "url", + "username", + "verified", + "verified_type", + "withheld", + "entities", + ]) + url = f"https://api.x.com/2/users/by/username/{username}?user.fields={user_fields}" async with httpx.AsyncClient() as client: response = await client.get(url, headers=headers, timeout=10) - - if response.status_code != 200: - raise ToolExecutionError( - f"Failed to look up user during execution of '{lookup_single_user_by_username.__name__}' tool. 
Request returned an error: {response.status_code} {response.text}" - ) - + if response.status_code == 404: + # User not found + raise RetryableToolError( # noqa: TRY003 + "User not found", + developer_message=f"User with username '{username}' not found.", + additional_prompt_content="Please check the username and try again.", + retry_after_ms=500, # Play nice with X API rate limits + ) + response.raise_for_status() # Parse the response JSON user_data = response.json()["data"] - expand_urls_in_user_description(user_data, delete_entities=False) - expand_urls_in_user_url(user_data, delete_entities=True) + user_data = expand_urls_in_user_description(user_data, delete_entities=False) + user_data = expand_urls_in_user_url(user_data, delete_entities=True) - """ - Example response["data"] structure: - { - "data": { - "verified_type": str, - "public_metrics": { - "followers_count": int, - "following_count": int, - "tweet_count": int, - "listed_count": int, - "like_count": int - }, - "id": str, - "most_recent_tweet_id": str, - "url": str, - "verified": bool, - "location": str, - "description": str, - "name": str, - "username": str, - "profile_image_url": str, - "created_at": str, - "protected": bool - } - } - """ return {"data": user_data} diff --git a/toolkits/x/arcade_x/tools/utils.py b/toolkits/x/arcade_x/tools/utils.py index fb33ead7..7b2f4342 100644 --- a/toolkits/x/arcade_x/tools/utils.py +++ b/toolkits/x/arcade_x/tools/utils.py @@ -1,38 +1,40 @@ from typing import Any +from arcade.sdk import ToolContext +from arcade.sdk.errors import ToolExecutionError + def get_tweet_url(tweet_id: str) -> str: """Get the URL of a tweet given its ID.""" return f"https://x.com/x/status/{tweet_id}" -def parse_search_recent_tweets_response(response_data: Any) -> dict: +def get_headers_with_token(context: ToolContext) -> dict[str, str]: + """Get the headers for a request to the X API.""" + if context.authorization is None or context.authorization.token is None: + raise ToolExecutionError( 
# noqa: TRY003 + "Missing Token. Authorization is required to post a tweet.", + developer_message="Token is not set in the ToolContext.", + ) + return { + "Authorization": f"Bearer {context.authorization.token}", + "Content-Type": "application/json", + } + + +def parse_search_recent_tweets_response(response_data: dict[str, Any]) -> dict[str, Any]: """ Parses response from the X API search recent tweets endpoint. - Returns a JSON string with the tweets data. - - Example parsed response: - "tweets": [ - { - "author_id": "558248927", - "id": "1838272933141319832", - "edit_history_tweet_ids": [ - "1838272933141319832" - ], - "text": "PR pending on @LangChainAI, will be integrated there soon! https://t.co/DPWd4lccQo", - "tweet_url": "https://x.com/x/status/1838272933141319832", - "author_username": "tomas_hk", - "author_name": "Tomas Hernando Kofman" - }, - ] + Returns the modified response data with added 'tweet_url', 'author_username', and 'author_name'. """ - if not sanity_check_tweets_data(response_data): return {"data": []} + # Add 'tweet_url' to each tweet for tweet in response_data["data"]: tweet["tweet_url"] = get_tweet_url(tweet["id"]) + # Add 'author_username' and 'author_name' to each tweet for tweet_data, user_data in zip(response_data["data"], response_data["includes"]["users"]): tweet_data["author_username"] = user_data["username"] tweet_data["author_name"] = user_data["name"] @@ -40,68 +42,71 @@ def parse_search_recent_tweets_response(response_data: Any) -> dict: return response_data -def sanity_check_tweets_data(tweets_data: dict) -> bool: +def sanity_check_tweets_data(tweets_data: dict[str, Any]) -> bool: """ Sanity check the tweets data. Returns True if the tweets data is valid and contains tweets, False otherwise. 
""" - if not tweets_data.get("data", []): + if not tweets_data.get("data"): return False - return tweets_data.get("includes", {}).get("users", []) + # prefer clarity over appeasing linter here + if not tweets_data.get("includes", {}).get("users"): # noqa: SIM103 + return False + return True -def expand_urls_in_tweets(tweets_data: list[dict], delete_entities: bool = True) -> None: +def expand_urls_in_tweets( + tweets_data: list[dict[str, Any]], delete_entities: bool = True +) -> list[dict[str, Any]]: """ - Expands the urls in the test of the provided tweets. - X shortens urls, and consequently, this can cause language models to hallucinate. - See more about X's link shortner at https://help.x.com/en/using-x/url-shortener + Returns a new list of tweets with expanded URLs. """ + new_tweets = [] for tweet_data in tweets_data: - if "entities" in tweet_data and "urls" in tweet_data["entities"]: - for url_entity in tweet_data["entities"]["urls"]: + new_tweet = tweet_data.copy() + if "entities" in new_tweet and "urls" in new_tweet["entities"]: + for url_entity in new_tweet["entities"]["urls"]: short_url = url_entity["url"] expanded_url = url_entity["expanded_url"] - tweet_data["text"] = tweet_data["text"].replace(short_url, expanded_url) + new_tweet["text"] = new_tweet["text"].replace(short_url, expanded_url) if delete_entities: - tweet_data.pop( - "entities", None - ) # Now that we've expanded the urls in the tweet, we no longer need the entities + new_tweet.pop("entities", None) + new_tweets.append(new_tweet) + return new_tweets -def expand_urls_in_user_description(user_data: dict, delete_entities: bool = True) -> None: +def expand_urls_in_user_description(user_data: dict, delete_entities: bool = True) -> dict: """ - Expands the urls in the description of the provided user. - X shortens urls, and consequently, this can cause language models to hallucinate. 
- See more about X's link shortner at https://help.x.com/en/using-x/url-shortener + Returns a new user data dict with expanded URLs in the description. """ - description_urls = user_data.get("entities", {}).get("description", {}).get("urls", []) - description = user_data.get("description", "") + new_user_data = user_data.copy() + description_urls = new_user_data.get("entities", {}).get("description", {}).get("urls", []) + description = new_user_data.get("description", "") for url_info in description_urls: t_co_link = url_info["url"] expanded_url = url_info["expanded_url"] description = description.replace(t_co_link, expanded_url) - user_data["description"] = description + new_user_data["description"] = description if delete_entities: - # Entities is no longer needed now that we have expanded the t.co links - user_data.pop("entities", None) + new_user_data.pop("entities", None) + return new_user_data -def expand_urls_in_user_url(user_data: dict, delete_entities: bool = True) -> None: +def expand_urls_in_user_url(user_data: dict, delete_entities: bool = True) -> dict: """ - Expands the urls in the url section of the provided user. - X shortens urls, and consequently, this can cause language models to hallucinate. - See more about X's link shortner at https://help.x.com/en/using-x/url-shortener + Returns a new user data dict with expanded URLs in the URL field. 
""" - url_urls = user_data.get("entities", {}).get("url", {}).get("urls", []) - url = user_data.get("url", "") + new_user_data = user_data.copy() + url_urls = new_user_data.get("entities", {}).get("url", {}).get("urls", []) + url = new_user_data.get("url", "") for url_info in url_urls: t_co_link = url_info["url"] expanded_url = url_info["expanded_url"] url = url.replace(t_co_link, expanded_url) - user_data["url"] = url + new_user_data["url"] = url if delete_entities: - # Entities is no longer needed now that we have expanded the t.co links - user_data.pop("entities", None) + new_user_data.pop("entities", None) + return new_user_data diff --git a/toolkits/x/conftest.py b/toolkits/x/conftest.py new file mode 100644 index 00000000..0a842143 --- /dev/null +++ b/toolkits/x/conftest.py @@ -0,0 +1,17 @@ +import pytest +from arcade.sdk import ToolContext + + +@pytest.fixture +def tool_context(): + """Fixture for the ToolContext with mock authorization.""" + return ToolContext(authorization={"token": "test_token", "user_id": "test_user"}) + + +@pytest.fixture +def mock_httpx_client(mocker): + """Fixture to mock the httpx.AsyncClient.""" + # Mock the AsyncClient context manager + mock_client = mocker.patch("httpx.AsyncClient", autospec=True) + async_mock_client = mock_client.return_value.__aenter__.return_value + return async_mock_client diff --git a/toolkits/x/evals/eval_x_tools.py b/toolkits/x/evals/eval_x_tools.py index 31a03f28..1eda2506 100644 --- a/toolkits/x/evals/eval_x_tools.py +++ b/toolkits/x/evals/eval_x_tools.py @@ -1,19 +1,21 @@ -import arcade_x -from arcade_x.tools.tweets import post_tweet - -# TODO -# delete_tweet_by_id, -# search_recent_tweets_by_keywords, -# search_recent_tweets_by_username, -# from arcade_x.tools.users import lookup_single_user_by_username from arcade.sdk import ToolCatalog from arcade.sdk.eval import ( + BinaryCritic, EvalRubric, EvalSuite, - SimilarityCritic, tool_eval, ) +import arcade_x +from arcade_x.tools.tweets import ( + 
delete_tweet_by_id, + lookup_tweet_by_id, + post_tweet, + search_recent_tweets_by_keywords, + search_recent_tweets_by_username, +) +from arcade_x.tools.users import lookup_single_user_by_username + # Evaluation rubric rubric = EvalRubric( fail_threshold=0.7, @@ -31,7 +33,10 @@ def x_eval_suite() -> EvalSuite: suite = EvalSuite( name="X Tools Evaluation Suite", - system_message="You are an AI assistant with access to the X (Twitter) tools. Use them to help answer the user's X-related tasks/questions.", + system_message=( + "You are an AI assistant with access to the X (Twitter) tools. Use them to " + "help answer the user's X-related tasks/questions." + ), catalog=catalog, rubric=rubric, ) @@ -39,7 +44,10 @@ def x_eval_suite() -> EvalSuite: # Add cases suite.add_case( name="Post a tweet", - user_message="Send out a tweet that says 'Hello World! Exciting stuff is happening over at Arcade AI!'", + user_message=( + "Send out a tweet that says 'Hello World! Exciting stuff is happening over " + "at Arcade AI!'" + ), expected_tool_calls=[ ( post_tweet, @@ -47,11 +55,108 @@ def x_eval_suite() -> EvalSuite: ) ], critics=[ - SimilarityCritic( + BinaryCritic( critic_field="tweet_text", weight=1.0, - similarity_threshold=0.9, ), ], ) + + suite.add_case( + name="Delete a tweet by ID", + user_message="Please delete the tweet with ID '148975632'.", + expected_tool_calls=[ + ( + delete_tweet_by_id, + {"tweet_id": "148975632"}, + ) + ], + critics=[ + BinaryCritic( + critic_field="tweet_id", + weight=1.0, + ), + ], + ) + + suite.add_case( + name="Search recent tweets by username", + user_message="Show me the recent tweets from 'elonmusk'.", + expected_tool_calls=[ + ( + search_recent_tweets_by_username, + {"username": "elonmusk", "max_results": 10}, + ) + ], + critics=[ + BinaryCritic( + critic_field="username", + weight=1.0, + ), + ], + ) + + suite.add_case( + name="Lookup user by username", + user_message="Can you get information about the user '@jack'?", + expected_tool_calls=[ + 
( + lookup_single_user_by_username, + {"username": "jack"}, + ) + ], + critics=[ + BinaryCritic( + critic_field="username", + weight=1.0, + ), + ], + ) + + # Add a case for searching recent tweets by keywords + suite.add_case( + name="Search recent tweets by keywords", + user_message="Find recent tweets containing 'Arcade AI'.", + expected_tool_calls=[ + ( + search_recent_tweets_by_keywords, + { + "keywords": [], + "phrases": ["Arcade AI"], + "max_results": 10, + }, + ) + ], + critics=[ + BinaryCritic( + critic_field="keywords", + weight=0.1, + ), + BinaryCritic( + critic_field="phrases", + weight=0.9, + ), + ], + ) + + # Extend the case to test lookup_tweet_by_id + suite.extend_case( + name="Lookup tweet by ID", + user_message="Can you provide details about the tweet with ID '123456789'?", + expected_tool_calls=[ + ( + lookup_tweet_by_id, + { + "tweet_id": "123456789", + }, + ) + ], + critics=[ + BinaryCritic( + critic_field="tweet_id", + weight=1.0, + ), + ], + ) + return suite diff --git a/toolkits/x/pyproject.toml b/toolkits/x/pyproject.toml index 5e511422..5263024c 100644 --- a/toolkits/x/pyproject.toml +++ b/toolkits/x/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "arcade_x" -version = "0.1.0" +version = "0.1.3" description = "LLM tools for interacting with X (Twitter)" authors = ["Arcade AI "] @@ -8,10 +8,35 @@ authors = ["Arcade AI "] python = "^3.10" arcade-ai = "0.1.*" httpx = "^0.27.2" - [tool.poetry.dev-dependencies] pytest = "^8.3.0" +pytest-cov = "^4.0.0" +pytest-asyncio = "^0.24.0" +pytest-mock = "^3.11.1" +mypy = "^1.5.1" +pre-commit = "^3.4.0" +tox = "^4.11.1" +ruff = "^0.7.4" [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" + +[tool.mypy] +files = ["arcade_x/**/*.py"] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes = "True" 
+ignore_missing_imports = "True"
+
+
+[tool.pytest.ini_options]
+testpaths = ["tests"]
+
+[tool.coverage.report]
+skip_empty = true
diff --git a/toolkits/x/tests/test_tweets.py b/toolkits/x/tests/test_tweets.py
new file mode 100644
index 00000000..8db5533c
--- /dev/null
+++ b/toolkits/x/tests/test_tweets.py
@@ -0,0 +1,190 @@
+from unittest.mock import MagicMock
+
+import httpx
+import pytest
+from arcade.sdk.errors import RetryableToolError, ToolExecutionError
+
+from arcade_x.tools.tweets import (
+    delete_tweet_by_id,
+    lookup_tweet_by_id,
+    post_tweet,
+    search_recent_tweets_by_keywords,
+    search_recent_tweets_by_username,
+)
+from arcade_x.tools.utils import get_tweet_url
+
+
+@pytest.mark.asyncio
+async def test_post_tweet_success(tool_context, mock_httpx_client):
+    """Test successful posting of a tweet."""
+    # Mock response for a successful tweet post
+    mock_response = MagicMock()
+    mock_response.status_code = 201
+    mock_response.json.return_value = {"data": {"id": "1234567890"}}
+    mock_httpx_client.post.return_value = mock_response
+
+    tweet_text = "Hello, world!"
+    result = await post_tweet(tool_context, tweet_text)
+
+    expected_url = get_tweet_url("1234567890")
+    assert result == f"Tweet with id 1234567890 posted successfully. URL: {expected_url}"
+    mock_httpx_client.post.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_post_tweet_failure(tool_context, mock_httpx_client):
+    """Test failure when posting a tweet due to API error."""
+    # Make the mocked client raise an HTTP 400 error when the tweet is posted
+    mock_error = httpx.HTTPStatusError(
+        "Bad Request", request=MagicMock(), response=MagicMock(status_code=400)
+    )
+    mock_httpx_client.post.side_effect = mock_error
+
+    tweet_text = "Hello, world!"
+    with pytest.raises(ToolExecutionError):
+        await post_tweet(tool_context, tweet_text)
+
+    mock_httpx_client.post.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_delete_tweet_by_id_success(tool_context, mock_httpx_client):
+    """Test successful deletion of a tweet by ID."""
+    # Mock response for a successful tweet deletion
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_httpx_client.delete.return_value = mock_response
+
+    tweet_id = "1234567890"
+    result = await delete_tweet_by_id(tool_context, tweet_id)
+
+    assert result == f"Tweet with id {tweet_id} deleted successfully."
+    mock_httpx_client.delete.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_delete_tweet_by_id_failure(tool_context, mock_httpx_client):
+    """Test failure when deleting a tweet due to API error."""
+    # Make the mocked client raise an HTTP 404 error when the tweet is deleted
+    mock_error = httpx.HTTPStatusError(
+        "Not Found", request=MagicMock(), response=MagicMock(status_code=404)
+    )
+    mock_httpx_client.delete.side_effect = mock_error
+
+    tweet_id = "1234567890"
+    with pytest.raises(ToolExecutionError):
+        await delete_tweet_by_id(tool_context, tweet_id)
+
+    mock_httpx_client.delete.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_search_recent_tweets_by_username_success(tool_context, mock_httpx_client):
+    """Test successful search of recent tweets by username."""
+    # Mock response for a successful tweet search
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.json.return_value = {
+        "data": [
+            {
+                "id": "1234567890",
+                "text": "Test tweet",
+                "entities": {
+                    "urls": [
+                        {"url": "https://t.co/short", "expanded_url": "https://example.com/long"}
+                    ]
+                },
+            }
+        ],
+        "includes": {"users": [{"id": "0987654321", "name": "Test User", "username": "testuser"}]},
+    }
+    mock_httpx_client.get.return_value = mock_response
+
+    username = "testuser"
+    result = await search_recent_tweets_by_username(tool_context, username)
+
+    assert "data" in result
+    assert len(result["data"]) == 1
+    assert result["data"][0]["text"] == "Test tweet"
+    mock_httpx_client.get.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_search_recent_tweets_by_username_failure(tool_context, mock_httpx_client):
+    """Test failure when searching tweets due to API error."""
+    # Make the mocked client raise an HTTP 500 error when tweets are searched
+    mock_error = httpx.HTTPStatusError(
+        "Internal Server Error", request=MagicMock(), response=MagicMock(status_code=500)
+    )
+    mock_httpx_client.get.side_effect = mock_error
+
+    username = "testuser"
+    with pytest.raises(ToolExecutionError):
+        await search_recent_tweets_by_username(tool_context, username)
+
+    mock_httpx_client.get.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_search_recent_tweets_by_keywords_success(tool_context, mock_httpx_client):
+    """Test successful search of recent tweets by keywords."""
+    # Mock response for a successful keyword search
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.json.return_value = {
+        "data": [{"id": "1234567890", "text": "Keyword tweet", "entities": {}}],
+        "includes": {"users": [{"id": "0987654321", "name": "Test User", "username": "testuser"}]},
+    }
+    mock_httpx_client.get.return_value = mock_response
+
+    keywords = ["test", "keyword"]
+    result = await search_recent_tweets_by_keywords(tool_context, keywords=keywords)
+
+    assert "data" in result
+    assert len(result["data"]) == 1
+    assert result["data"][0]["text"] == "Keyword tweet"
+    mock_httpx_client.get.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_search_recent_tweets_by_keywords_no_input(tool_context):
+    """Test error when no keywords or phrases are provided."""
+    with pytest.raises(RetryableToolError) as exc_info:
+        await search_recent_tweets_by_keywords(tool_context)
+
+    assert "No keywords or phrases provided" in str(exc_info.value)
+
+
+@pytest.mark.asyncio
+async def test_lookup_tweet_by_id_success(tool_context, mock_httpx_client):
+    """Test successful lookup of a tweet by ID."""
+    # Use MagicMock for the response
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.json.return_value = {
+        "data": {"id": "1234567890", "text": "Lookup tweet", "entities": {}}
+    }
+    mock_httpx_client.get.return_value = mock_response
+
+    tweet_id = "1234567890"
+    result = await lookup_tweet_by_id(tool_context, tweet_id)
+
+    assert "data" in result
+    assert result["data"]["text"] == "Lookup tweet"
+    mock_httpx_client.get.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_lookup_tweet_by_id_failure(tool_context, mock_httpx_client):
+    """Test failure when looking up a tweet due to API error."""
+    # Make the mocked client raise an HTTP 404 error when the tweet is looked up
+    mock_error = httpx.HTTPStatusError(
+        "Not Found", request=MagicMock(), response=MagicMock(status_code=404)
+    )
+    mock_httpx_client.get.side_effect = mock_error
+
+    tweet_id = "1234567890"
+    with pytest.raises(ToolExecutionError):
+        await lookup_tweet_by_id(tool_context, tweet_id)
+
+    mock_httpx_client.get.assert_called_once()
diff --git a/toolkits/x/tests/test_users.py b/toolkits/x/tests/test_users.py
new file mode 100644
index 00000000..6115ff91
--- /dev/null
+++ b/toolkits/x/tests/test_users.py
@@ -0,0 +1,78 @@
+from unittest.mock import MagicMock
+
+import httpx
+import pytest
+from arcade.sdk.errors import ToolExecutionError
+
+from arcade_x.tools.users import lookup_single_user_by_username
+
+
+@pytest.mark.asyncio
+async def test_lookup_single_user_by_username_success(tool_context, mock_httpx_client):
+    """Test successful lookup of a user by username."""
+    # Mock response for a successful user lookup
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.json.return_value = {
+        "data": {
+            "id": "1234567890",
+            "name": "Test User",
+            "username": "testuser",
+            "description": "This is a test user",
+            # Additional fields can be added here as needed
+        }
+    }
+    mock_httpx_client.get.return_value = mock_response
+
+    username = "testuser"
+    result = await lookup_single_user_by_username(tool_context, username)
+
+    assert "data" in result
+    assert result["data"]["username"] == "testuser"
+    assert result["data"]["name"] == "Test User"
+    mock_httpx_client.get.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_lookup_single_user_by_username_user_not_found(tool_context, mock_httpx_client):
+    """Test behavior when the user lookup fails with a 404 (user not found) API error."""
+    # Make the mocked client raise an HTTP 404 error for the unknown user
+    mock_error = httpx.HTTPStatusError(
+        "Not Found", request=MagicMock(), response=MagicMock(status_code=404)
+    )
+    mock_httpx_client.get.side_effect = mock_error
+
+    username = "nonexistentuser"
+    with pytest.raises(ToolExecutionError):
+        await lookup_single_user_by_username(tool_context, username)
+
+    mock_httpx_client.get.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_lookup_single_user_by_username_api_error(tool_context, mock_httpx_client):
+    """Test behavior when API returns an error other than 404."""
+    # Make the mocked client raise an HTTP 500 error
+    mock_error = httpx.HTTPStatusError(
+        "Internal Server Error", request=MagicMock(), response=MagicMock(status_code=500)
+    )
+    mock_httpx_client.get.side_effect = mock_error
+
+    username = "testuser"
+    with pytest.raises(ToolExecutionError):
+        await lookup_single_user_by_username(tool_context, username)
+
+    mock_httpx_client.get.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_lookup_single_user_by_username_network_error(tool_context, mock_httpx_client):
+    """Test behavior when there is a network error during the request."""
+    # Mock client.get to raise an HTTPError
+    mock_httpx_client.get.side_effect = httpx.HTTPError("Network Error")
+
+    username = "testuser"
+    with pytest.raises(ToolExecutionError):
+        await lookup_single_user_by_username(tool_context, username)
+
+    mock_httpx_client.get.assert_called_once()