Add lookup_tweet_by_id to X Toolkit (#165)
This PR introduces the `lookup_tweet_by_id` tool to the X toolkit,
enabling users to retrieve tweet details by tweet ID. This enhancement
extends the toolkit's capabilities, allowing for more comprehensive
interactions with the X (Twitter) API.
**Key Changes:**
- **Added `lookup_tweet_by_id` Tool:**
- Implemented the `lookup_tweet_by_id` function in `tools/tweets.py`,
which allows users to fetch tweet information using a tweet ID.
- Included error handling for API response codes and expanded URLs in
tweets to assist language models in avoiding hallucinations due to
shortened URLs.
- **Enhanced Toolkit Structure:**
- Added several configuration files to the X toolkit to establish a
standardized project structure, which in the future will be generated by
`arcade new`. These include:
- `.pre-commit-config.yaml`: Defines pre-commit hooks for code quality
checks.
- `.ruff.toml`: Configuration for the Ruff linter.
- `LICENSE`: MIT License file for the toolkit.
- `Makefile`: Contains common commands for building, testing, and
linting the toolkit.
- **Updated Makefile:**
- Added `make check-toolkits` command to the top-level `Makefile`. This
command runs code quality tools for each toolkit that contains a
`Makefile`.
**Additional Notes:**
- **Tests:**
- Added unit tests for the new `lookup_tweet_by_id` tool in
`tests/test_tweets.py`.
- Included tests for the user lookup functionality in
`tests/test_users.py`.
- **Linting and Code Quality:**
- Configured pre-commit hooks and Ruff linter to enforce code standards.
- Updated the `pyproject.toml` file with development dependencies for
testing and linting.
-
---------
Co-authored-by: Eric Gustin <eric@arcade-ai.com>
This commit is contained in:
parent
cf6a2969bf
commit
bebfcab1e9
14 changed files with 764 additions and 163 deletions
15
Makefile
15
Makefile
|
|
@ -25,6 +25,19 @@ check: ## Run code quality tools.
|
|||
@echo "🚀 Static type checking: Running mypy"
|
||||
@cd arcade && poetry run mypy $(git ls-files '*.py')
|
||||
|
||||
|
||||
.PHONY: check-toolkits
|
||||
check-toolkits: ## Run code quality tools for each toolkit that has a Makefile
|
||||
@echo "🚀 Running 'make check' in each toolkit with a Makefile"
|
||||
@for dir in toolkits/*/ ; do \
|
||||
if [ -f "$$dir/Makefile" ]; then \
|
||||
echo "🛠️ Checking toolkit $$dir"; \
|
||||
(cd "$$dir" && make check); \
|
||||
else \
|
||||
echo "🛠️ Skipping toolkit $$dir (no Makefile found)"; \
|
||||
fi; \
|
||||
done
|
||||
|
||||
.PHONY: test
|
||||
test: ## Test the code with pytest
|
||||
@echo "🚀 Testing code: Running pytest"
|
||||
|
|
@ -144,4 +157,6 @@ help:
|
|||
@echo "🛠️ Arcade AI Dev Commands:\n"
|
||||
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
||||
|
||||
|
||||
|
||||
.DEFAULT_GOAL := help
|
||||
|
|
|
|||
18
toolkits/x/.pre-commit-config.yaml
Normal file
18
toolkits/x/.pre-commit-config.yaml
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
files: ^./
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: "v4.4.0"
|
||||
hooks:
|
||||
- id: check-case-conflict
|
||||
- id: check-merge-conflict
|
||||
- id: check-toml
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.6.7
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
44
toolkits/x/.ruff.toml
Normal file
44
toolkits/x/.ruff.toml
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
target-version = "py39"
|
||||
line-length = 100
|
||||
fix = true
|
||||
|
||||
[lint]
|
||||
select = [
|
||||
# flake8-2020
|
||||
"YTT",
|
||||
# flake8-bandit
|
||||
"S",
|
||||
# flake8-bugbear
|
||||
"B",
|
||||
# flake8-builtins
|
||||
"A",
|
||||
# flake8-comprehensions
|
||||
"C4",
|
||||
# flake8-debugger
|
||||
"T10",
|
||||
# flake8-simplify
|
||||
"SIM",
|
||||
# isort
|
||||
"I",
|
||||
# mccabe
|
||||
"C90",
|
||||
# pycodestyle
|
||||
"E", "W",
|
||||
# pyflakes
|
||||
"F",
|
||||
# pygrep-hooks
|
||||
"PGH",
|
||||
# pyupgrade
|
||||
"UP",
|
||||
# ruff
|
||||
"RUF",
|
||||
# tryceratops
|
||||
"TRY",
|
||||
]
|
||||
|
||||
[lint.per-file-ignores]
|
||||
"**/tests/*" = ["S101"]
|
||||
|
||||
[format]
|
||||
preview = true
|
||||
skip-magic-trailing-comma = false
|
||||
21
toolkits/x/LICENSE
Normal file
21
toolkits/x/LICENSE
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2024, arcadeai
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
53
toolkits/x/Makefile
Normal file
53
toolkits/x/Makefile
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
.PHONY: help
|
||||
|
||||
help:
|
||||
@echo "🛠️ x Commands:\n"
|
||||
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
||||
|
||||
.PHONY: install
|
||||
install: ## Install the poetry environment and install the pre-commit hooks
|
||||
@echo "📦 Checking if Poetry is installed"
|
||||
@if ! command -v poetry &> /dev/null; then \
|
||||
echo "📦 Installing Poetry with pip"; \
|
||||
pip install poetry; \
|
||||
else \
|
||||
echo "📦 Poetry is already installed"; \
|
||||
fi
|
||||
@echo "🚀 Installing package in development mode with all extras"
|
||||
poetry install --all-extras
|
||||
|
||||
.PHONY: build
|
||||
build: clean-build ## Build wheel file using poetry
|
||||
@echo "🚀 Creating wheel file"
|
||||
poetry build
|
||||
|
||||
.PHONY: clean-build
|
||||
clean-build: ## clean build artifacts
|
||||
@echo "🗑️ Cleaning dist directory"
|
||||
rm -rf dist
|
||||
|
||||
.PHONY: test
|
||||
test: ## Test the code with pytest
|
||||
@echo "🚀 Testing code: Running pytest"
|
||||
@poetry run pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml
|
||||
|
||||
.PHONY: coverage
|
||||
coverage: ## Generate coverage report
|
||||
@echo "coverage report"
|
||||
coverage report
|
||||
@echo "Generating coverage report"
|
||||
coverage html
|
||||
|
||||
.PHONY: bump-version
|
||||
bump-version: ## Bump the version in the pyproject.toml file
|
||||
@echo "🚀 Bumping version in pyproject.toml"
|
||||
poetry version patch
|
||||
|
||||
.PHONY: check
|
||||
check: ## Run code quality tools.
|
||||
@echo "🚀 Checking Poetry lock file consistency with 'pyproject.toml': Running poetry check --lock"
|
||||
@poetry check --lock
|
||||
@echo "🚀 Linting code: Running pre-commit"
|
||||
@poetry run pre-commit run -a
|
||||
@echo "🚀 Static type checking: Running mypy"
|
||||
@poetry run mypy --config-file=pyproject.toml
|
||||
|
|
@ -1,12 +1,13 @@
|
|||
from typing import Annotated
|
||||
from typing import Annotated, Any
|
||||
|
||||
import httpx
|
||||
|
||||
from arcade.sdk import ToolContext, tool
|
||||
from arcade.sdk.auth import X
|
||||
from arcade.sdk.errors import ToolExecutionError
|
||||
from arcade.sdk.errors import RetryableToolError
|
||||
|
||||
from arcade_x.tools.utils import (
|
||||
expand_urls_in_tweets,
|
||||
get_headers_with_token,
|
||||
get_tweet_url,
|
||||
parse_search_recent_tweets_response,
|
||||
)
|
||||
|
|
@ -14,7 +15,10 @@ from arcade_x.tools.utils import (
|
|||
TWEETS_URL = "https://api.x.com/2/tweets"
|
||||
|
||||
|
||||
# Manage Tweets Tools. See developer docs for additional available parameters: https://developer.x.com/en/docs/x-api/tweets/manage-tweets/api-reference
|
||||
# Manage Tweets Tools. See developer docs for additional available parameters:
|
||||
# https://developer.x.com/en/docs/x-api/tweets/manage-tweets/api-reference
|
||||
|
||||
|
||||
@tool(
|
||||
requires_auth=X(
|
||||
scopes=["tweet.read", "tweet.write", "users.read"],
|
||||
|
|
@ -26,19 +30,12 @@ async def post_tweet(
|
|||
) -> Annotated[str, "Success string and the URL of the tweet"]:
|
||||
"""Post a tweet to X (Twitter)."""
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {context.authorization.token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
headers = get_headers_with_token(context)
|
||||
payload = {"text": tweet_text}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(TWEETS_URL, headers=headers, json=payload, timeout=10)
|
||||
|
||||
if response.status_code != 201:
|
||||
raise ToolExecutionError(
|
||||
f"Failed to post a tweet during execution of '{post_tweet.__name__}' tool. Request returned an error: {response.status_code} {response.text}"
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
tweet_id = response.json()["data"]["id"]
|
||||
return f"Tweet with id {tweet_id} posted successfully. URL: {get_tweet_url(tweet_id)}"
|
||||
|
|
@ -51,16 +48,12 @@ async def delete_tweet_by_id(
|
|||
) -> Annotated[str, "Success string confirming the tweet deletion"]:
|
||||
"""Delete a tweet on X (Twitter)."""
|
||||
|
||||
headers = {"Authorization": f"Bearer {context.authorization.token}"}
|
||||
headers = get_headers_with_token(context)
|
||||
url = f"{TWEETS_URL}/{tweet_id}"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.delete(url, headers=headers, timeout=10)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ToolExecutionError(
|
||||
f"Failed to delete the tweet during execution of '{delete_tweet_by_id.__name__}' tool. Request returned an error: {response.status_code} {response.text}"
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return f"Tweet with id {tweet_id} deleted successfully."
|
||||
|
||||
|
|
@ -72,33 +65,33 @@ async def search_recent_tweets_by_username(
|
|||
max_results: Annotated[
|
||||
int, "The maximum number of results to return. Cannot be less than 10"
|
||||
] = 10,
|
||||
) -> Annotated[dict, "Dictionary containing the search results"]:
|
||||
"""Search for recent tweets (last 7 days) on X (Twitter) by username. Includes replies and reposts."""
|
||||
) -> Annotated[dict[str, Any], "Dictionary containing the search results"]:
|
||||
"""Search for recent tweets (last 7 days) on X (Twitter) by username.
|
||||
Includes replies and reposts."""
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {context.authorization.token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
params = {
|
||||
headers = get_headers_with_token(context)
|
||||
params: dict[str, int | str] = {
|
||||
"query": f"from:{username}",
|
||||
"max_results": max(max_results, 10), # X API does not allow 'max_results' less than 10
|
||||
}
|
||||
url = "https://api.x.com/2/tweets/search/recent?expansions=author_id&user.fields=id,name,username,entities&tweet.fields=entities"
|
||||
url = (
|
||||
"https://api.x.com/2/tweets/search/recent?"
|
||||
"expansions=author_id&user.fields=id,name,username,entities&tweet.fields=entities"
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url, headers=headers, params=params, timeout=10)
|
||||
response.raise_for_status()
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ToolExecutionError(
|
||||
f"Failed to search recent tweets during execution of '{search_recent_tweets_by_username.__name__}' tool. Request returned an error: {response.status_code} {response.text}"
|
||||
)
|
||||
response_data: dict[str, Any] = response.json()
|
||||
|
||||
response_data = response.json()
|
||||
# Expand the URLs that are in the tweets
|
||||
response_data["data"] = expand_urls_in_tweets(
|
||||
response_data.get("data", []), delete_entities=True
|
||||
)
|
||||
|
||||
# Expand the urls that are in the tweets
|
||||
expand_urls_in_tweets(response_data.get("data", []), delete_entities=True)
|
||||
|
||||
parse_search_recent_tweets_response(response_data)
|
||||
# Parse the response data
|
||||
response_data = parse_search_recent_tweets_response(response_data)
|
||||
|
||||
return response_data
|
||||
|
||||
|
|
@ -115,44 +108,81 @@ async def search_recent_tweets_by_keywords(
|
|||
max_results: Annotated[
|
||||
int, "The maximum number of results to return. Cannot be less than 10"
|
||||
] = 10,
|
||||
) -> Annotated[dict, "Dictionary containing the search results"]:
|
||||
) -> Annotated[dict[str, Any], "Dictionary containing the search results"]:
|
||||
"""
|
||||
Search for recent tweets (last 7 days) on X (Twitter) by required keywords and phrases. Includes replies and reposts
|
||||
One of the following input parametersMUST be provided: keywords, phrases
|
||||
Search for recent tweets (last 7 days) on X (Twitter) by required keywords and phrases.
|
||||
Includes replies and reposts.
|
||||
One of the following input parameters MUST be provided: keywords, phrases
|
||||
"""
|
||||
|
||||
if not any([keywords, phrases]):
|
||||
raise ValueError(
|
||||
"At least one of keywords or phrases must be provided to the '{search_recent_tweets_by_keywords.__name__}' tool."
|
||||
raise RetryableToolError( # noqa: TRY003
|
||||
"No keywords or phrases provided",
|
||||
developer_message="Predicted inputs didn't contain any keywords or phrases",
|
||||
additional_prompt_content="Please provide at least one keyword or phrase for search",
|
||||
retry_after_ms=500, # Play nice with X API rate limits
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {context.authorization.token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
headers = get_headers_with_token(context)
|
||||
|
||||
query = "".join([f'"{phrase}" ' for phrase in (phrases or [])])
|
||||
if keywords:
|
||||
query += " ".join(keywords or [])
|
||||
|
||||
params = {
|
||||
"query": query,
|
||||
params: dict[str, int | str] = {
|
||||
"query": query.strip(),
|
||||
"max_results": max(max_results, 10), # X API does not allow 'max_results' less than 10
|
||||
}
|
||||
url = "https://api.x.com/2/tweets/search/recent?expansions=author_id&user.fields=id,name,username,entities&tweet.fields=entities"
|
||||
url = (
|
||||
"https://api.x.com/2/tweets/search/recent?"
|
||||
"expansions=author_id&user.fields=id,name,username,entities&tweet.fields=entities"
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url, headers=headers, params=params, timeout=10)
|
||||
response.raise_for_status()
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ToolExecutionError(
|
||||
f"Failed to search recent tweets during execution of '{search_recent_tweets_by_keywords.__name__}' tool. Request returned an error: {response.status_code} {response.text}"
|
||||
)
|
||||
response_data: dict[str, Any] = response.json()
|
||||
|
||||
response_data = response.json()
|
||||
# Expand the URLs that are in the tweets
|
||||
response_data["data"] = expand_urls_in_tweets(
|
||||
response_data.get("data", []), delete_entities=True
|
||||
)
|
||||
|
||||
# Expand the urls that are in the tweets
|
||||
expand_urls_in_tweets(response_data.get("data", []), delete_entities=True)
|
||||
|
||||
parse_search_recent_tweets_response(response_data)
|
||||
# Parse the response data
|
||||
response_data = parse_search_recent_tweets_response(response_data)
|
||||
|
||||
return response_data
|
||||
|
||||
|
||||
@tool(requires_auth=X(scopes=["tweet.read", "users.read"]))
|
||||
async def lookup_tweet_by_id(
|
||||
context: ToolContext,
|
||||
tweet_id: Annotated[str, "The ID of the tweet you want to look up"],
|
||||
) -> Annotated[dict[str, Any], "Dictionary containing the tweet data"]:
|
||||
"""Look up a tweet on X (Twitter) by tweet ID."""
|
||||
|
||||
headers = get_headers_with_token(context)
|
||||
params = {
|
||||
"expansions": "author_id",
|
||||
"user.fields": "id,name,username,entities",
|
||||
"tweet.fields": "entities",
|
||||
}
|
||||
url = f"{TWEETS_URL}/{tweet_id}"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url, headers=headers, params=params, timeout=10)
|
||||
response.raise_for_status()
|
||||
|
||||
response_data: dict[str, Any] = response.json()
|
||||
|
||||
# Get the tweet data
|
||||
tweet_data = response_data.get("data")
|
||||
if tweet_data:
|
||||
# Expand the URLs that are in the tweet
|
||||
expanded_tweet_list = expand_urls_in_tweets([tweet_data], delete_entities=True)
|
||||
response_data["data"] = expanded_tweet_list[0]
|
||||
else:
|
||||
response_data["data"] = {}
|
||||
|
||||
return response_data
|
||||
|
|
|
|||
|
|
@ -1,14 +1,20 @@
|
|||
from typing import Annotated
|
||||
|
||||
import httpx
|
||||
|
||||
from arcade.sdk import ToolContext, tool
|
||||
from arcade.sdk.auth import X
|
||||
from arcade.sdk.errors import ToolExecutionError
|
||||
from arcade_x.tools.utils import expand_urls_in_user_description, expand_urls_in_user_url
|
||||
from arcade.sdk.errors import RetryableToolError
|
||||
|
||||
from arcade_x.tools.utils import (
|
||||
expand_urls_in_user_description,
|
||||
expand_urls_in_user_url,
|
||||
get_headers_with_token,
|
||||
)
|
||||
|
||||
# Users Lookup Tools. See developer docs for additional available query parameters:
|
||||
# https://developer.x.com/en/docs/x-api/users/lookup/api-reference
|
||||
|
||||
|
||||
# Users Lookup Tools. See developer docs for additional available query parameters: https://developer.x.com/en/docs/x-api/users/lookup/api-reference
|
||||
@tool(requires_auth=X(scopes=["users.read", "tweet.read"]))
|
||||
async def lookup_single_user_by_username(
|
||||
context: ToolContext,
|
||||
|
|
@ -16,49 +22,43 @@ async def lookup_single_user_by_username(
|
|||
) -> Annotated[dict, "User information including id, name, username, and description"]:
|
||||
"""Look up a user on X (Twitter) by their username."""
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {context.authorization.token}",
|
||||
}
|
||||
url = f"https://api.x.com/2/users/by/username/{username}?user.fields=created_at,description,id,location,most_recent_tweet_id,name,pinned_tweet_id,profile_image_url,protected,public_metrics,url,username,verified,verified_type,withheld,entities"
|
||||
headers = get_headers_with_token(context)
|
||||
|
||||
user_fields = ",".join([
|
||||
"created_at",
|
||||
"description",
|
||||
"id",
|
||||
"location",
|
||||
"most_recent_tweet_id",
|
||||
"name",
|
||||
"pinned_tweet_id",
|
||||
"profile_image_url",
|
||||
"protected",
|
||||
"public_metrics",
|
||||
"url",
|
||||
"username",
|
||||
"verified",
|
||||
"verified_type",
|
||||
"withheld",
|
||||
"entities",
|
||||
])
|
||||
url = f"https://api.x.com/2/users/by/username/{username}?user.fields={user_fields}"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url, headers=headers, timeout=10)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ToolExecutionError(
|
||||
f"Failed to look up user during execution of '{lookup_single_user_by_username.__name__}' tool. Request returned an error: {response.status_code} {response.text}"
|
||||
)
|
||||
|
||||
if response.status_code == 404:
|
||||
# User not found
|
||||
raise RetryableToolError( # noqa: TRY003
|
||||
"User not found",
|
||||
developer_message=f"User with username '{username}' not found.",
|
||||
additional_prompt_content="Please check the username and try again.",
|
||||
retry_after_ms=500, # Play nice with X API rate limits
|
||||
)
|
||||
response.raise_for_status()
|
||||
# Parse the response JSON
|
||||
user_data = response.json()["data"]
|
||||
|
||||
expand_urls_in_user_description(user_data, delete_entities=False)
|
||||
expand_urls_in_user_url(user_data, delete_entities=True)
|
||||
user_data = expand_urls_in_user_description(user_data, delete_entities=False)
|
||||
user_data = expand_urls_in_user_url(user_data, delete_entities=True)
|
||||
|
||||
"""
|
||||
Example response["data"] structure:
|
||||
{
|
||||
"data": {
|
||||
"verified_type": str,
|
||||
"public_metrics": {
|
||||
"followers_count": int,
|
||||
"following_count": int,
|
||||
"tweet_count": int,
|
||||
"listed_count": int,
|
||||
"like_count": int
|
||||
},
|
||||
"id": str,
|
||||
"most_recent_tweet_id": str,
|
||||
"url": str,
|
||||
"verified": bool,
|
||||
"location": str,
|
||||
"description": str,
|
||||
"name": str,
|
||||
"username": str,
|
||||
"profile_image_url": str,
|
||||
"created_at": str,
|
||||
"protected": bool
|
||||
}
|
||||
}
|
||||
"""
|
||||
return {"data": user_data}
|
||||
|
|
|
|||
|
|
@ -1,38 +1,40 @@
|
|||
from typing import Any
|
||||
|
||||
from arcade.sdk import ToolContext
|
||||
from arcade.sdk.errors import ToolExecutionError
|
||||
|
||||
|
||||
def get_tweet_url(tweet_id: str) -> str:
|
||||
"""Get the URL of a tweet given its ID."""
|
||||
return f"https://x.com/x/status/{tweet_id}"
|
||||
|
||||
|
||||
def parse_search_recent_tweets_response(response_data: Any) -> dict:
|
||||
def get_headers_with_token(context: ToolContext) -> dict[str, str]:
|
||||
"""Get the headers for a request to the X API."""
|
||||
if context.authorization is None or context.authorization.token is None:
|
||||
raise ToolExecutionError( # noqa: TRY003
|
||||
"Missing Token. Authorization is required to post a tweet.",
|
||||
developer_message="Token is not set in the ToolContext.",
|
||||
)
|
||||
return {
|
||||
"Authorization": f"Bearer {context.authorization.token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
|
||||
def parse_search_recent_tweets_response(response_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Parses response from the X API search recent tweets endpoint.
|
||||
Returns a JSON string with the tweets data.
|
||||
|
||||
Example parsed response:
|
||||
"tweets": [
|
||||
{
|
||||
"author_id": "558248927",
|
||||
"id": "1838272933141319832",
|
||||
"edit_history_tweet_ids": [
|
||||
"1838272933141319832"
|
||||
],
|
||||
"text": "PR pending on @LangChainAI, will be integrated there soon! https://t.co/DPWd4lccQo",
|
||||
"tweet_url": "https://x.com/x/status/1838272933141319832",
|
||||
"author_username": "tomas_hk",
|
||||
"author_name": "Tomas Hernando Kofman"
|
||||
},
|
||||
]
|
||||
Returns the modified response data with added 'tweet_url', 'author_username', and 'author_name'.
|
||||
"""
|
||||
|
||||
if not sanity_check_tweets_data(response_data):
|
||||
return {"data": []}
|
||||
|
||||
# Add 'tweet_url' to each tweet
|
||||
for tweet in response_data["data"]:
|
||||
tweet["tweet_url"] = get_tweet_url(tweet["id"])
|
||||
|
||||
# Add 'author_username' and 'author_name' to each tweet
|
||||
for tweet_data, user_data in zip(response_data["data"], response_data["includes"]["users"]):
|
||||
tweet_data["author_username"] = user_data["username"]
|
||||
tweet_data["author_name"] = user_data["name"]
|
||||
|
|
@ -40,68 +42,71 @@ def parse_search_recent_tweets_response(response_data: Any) -> dict:
|
|||
return response_data
|
||||
|
||||
|
||||
def sanity_check_tweets_data(tweets_data: dict) -> bool:
|
||||
def sanity_check_tweets_data(tweets_data: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Sanity check the tweets data.
|
||||
Returns True if the tweets data is valid and contains tweets, False otherwise.
|
||||
"""
|
||||
if not tweets_data.get("data", []):
|
||||
if not tweets_data.get("data"):
|
||||
return False
|
||||
return tweets_data.get("includes", {}).get("users", [])
|
||||
# prefer clarity over appeasing linter here
|
||||
if not tweets_data.get("includes", {}).get("users"): # noqa: SIM103
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def expand_urls_in_tweets(tweets_data: list[dict], delete_entities: bool = True) -> None:
|
||||
def expand_urls_in_tweets(
|
||||
tweets_data: list[dict[str, Any]], delete_entities: bool = True
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Expands the urls in the test of the provided tweets.
|
||||
X shortens urls, and consequently, this can cause language models to hallucinate.
|
||||
See more about X's link shortner at https://help.x.com/en/using-x/url-shortener
|
||||
Returns a new list of tweets with expanded URLs.
|
||||
"""
|
||||
new_tweets = []
|
||||
for tweet_data in tweets_data:
|
||||
if "entities" in tweet_data and "urls" in tweet_data["entities"]:
|
||||
for url_entity in tweet_data["entities"]["urls"]:
|
||||
new_tweet = tweet_data.copy()
|
||||
if "entities" in new_tweet and "urls" in new_tweet["entities"]:
|
||||
for url_entity in new_tweet["entities"]["urls"]:
|
||||
short_url = url_entity["url"]
|
||||
expanded_url = url_entity["expanded_url"]
|
||||
tweet_data["text"] = tweet_data["text"].replace(short_url, expanded_url)
|
||||
new_tweet["text"] = new_tweet["text"].replace(short_url, expanded_url)
|
||||
|
||||
if delete_entities:
|
||||
tweet_data.pop(
|
||||
"entities", None
|
||||
) # Now that we've expanded the urls in the tweet, we no longer need the entities
|
||||
new_tweet.pop("entities", None)
|
||||
new_tweets.append(new_tweet)
|
||||
return new_tweets
|
||||
|
||||
|
||||
def expand_urls_in_user_description(user_data: dict, delete_entities: bool = True) -> None:
|
||||
def expand_urls_in_user_description(user_data: dict, delete_entities: bool = True) -> dict:
|
||||
"""
|
||||
Expands the urls in the description of the provided user.
|
||||
X shortens urls, and consequently, this can cause language models to hallucinate.
|
||||
See more about X's link shortner at https://help.x.com/en/using-x/url-shortener
|
||||
Returns a new user data dict with expanded URLs in the description.
|
||||
"""
|
||||
description_urls = user_data.get("entities", {}).get("description", {}).get("urls", [])
|
||||
description = user_data.get("description", "")
|
||||
new_user_data = user_data.copy()
|
||||
description_urls = new_user_data.get("entities", {}).get("description", {}).get("urls", [])
|
||||
description = new_user_data.get("description", "")
|
||||
for url_info in description_urls:
|
||||
t_co_link = url_info["url"]
|
||||
expanded_url = url_info["expanded_url"]
|
||||
description = description.replace(t_co_link, expanded_url)
|
||||
user_data["description"] = description
|
||||
new_user_data["description"] = description
|
||||
|
||||
if delete_entities:
|
||||
# Entities is no longer needed now that we have expanded the t.co links
|
||||
user_data.pop("entities", None)
|
||||
new_user_data.pop("entities", None)
|
||||
return new_user_data
|
||||
|
||||
|
||||
def expand_urls_in_user_url(user_data: dict, delete_entities: bool = True) -> None:
|
||||
def expand_urls_in_user_url(user_data: dict, delete_entities: bool = True) -> dict:
|
||||
"""
|
||||
Expands the urls in the url section of the provided user.
|
||||
X shortens urls, and consequently, this can cause language models to hallucinate.
|
||||
See more about X's link shortner at https://help.x.com/en/using-x/url-shortener
|
||||
Returns a new user data dict with expanded URLs in the URL field.
|
||||
"""
|
||||
url_urls = user_data.get("entities", {}).get("url", {}).get("urls", [])
|
||||
url = user_data.get("url", "")
|
||||
new_user_data = user_data.copy()
|
||||
url_urls = new_user_data.get("entities", {}).get("url", {}).get("urls", [])
|
||||
url = new_user_data.get("url", "")
|
||||
for url_info in url_urls:
|
||||
t_co_link = url_info["url"]
|
||||
expanded_url = url_info["expanded_url"]
|
||||
url = url.replace(t_co_link, expanded_url)
|
||||
user_data["url"] = url
|
||||
new_user_data["url"] = url
|
||||
|
||||
if delete_entities:
|
||||
# Entities is no longer needed now that we have expanded the t.co links
|
||||
user_data.pop("entities", None)
|
||||
new_user_data.pop("entities", None)
|
||||
return new_user_data
|
||||
|
|
|
|||
17
toolkits/x/conftest.py
Normal file
17
toolkits/x/conftest.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
import pytest
|
||||
from arcade.sdk import ToolContext
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_context():
|
||||
"""Fixture for the ToolContext with mock authorization."""
|
||||
return ToolContext(authorization={"token": "test_token", "user_id": "test_user"})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_httpx_client(mocker):
|
||||
"""Fixture to mock the httpx.AsyncClient."""
|
||||
# Mock the AsyncClient context manager
|
||||
mock_client = mocker.patch("httpx.AsyncClient", autospec=True)
|
||||
async_mock_client = mock_client.return_value.__aenter__.return_value
|
||||
return async_mock_client
|
||||
|
|
@ -1,19 +1,21 @@
|
|||
import arcade_x
|
||||
from arcade_x.tools.tweets import post_tweet
|
||||
|
||||
# TODO
|
||||
# delete_tweet_by_id,
|
||||
# search_recent_tweets_by_keywords,
|
||||
# search_recent_tweets_by_username,
|
||||
# from arcade_x.tools.users import lookup_single_user_by_username
|
||||
from arcade.sdk import ToolCatalog
|
||||
from arcade.sdk.eval import (
|
||||
BinaryCritic,
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
SimilarityCritic,
|
||||
tool_eval,
|
||||
)
|
||||
|
||||
import arcade_x
|
||||
from arcade_x.tools.tweets import (
|
||||
delete_tweet_by_id,
|
||||
lookup_tweet_by_id,
|
||||
post_tweet,
|
||||
search_recent_tweets_by_keywords,
|
||||
search_recent_tweets_by_username,
|
||||
)
|
||||
from arcade_x.tools.users import lookup_single_user_by_username
|
||||
|
||||
# Evaluation rubric
|
||||
rubric = EvalRubric(
|
||||
fail_threshold=0.7,
|
||||
|
|
@ -31,7 +33,10 @@ def x_eval_suite() -> EvalSuite:
|
|||
|
||||
suite = EvalSuite(
|
||||
name="X Tools Evaluation Suite",
|
||||
system_message="You are an AI assistant with access to the X (Twitter) tools. Use them to help answer the user's X-related tasks/questions.",
|
||||
system_message=(
|
||||
"You are an AI assistant with access to the X (Twitter) tools. Use them to "
|
||||
"help answer the user's X-related tasks/questions."
|
||||
),
|
||||
catalog=catalog,
|
||||
rubric=rubric,
|
||||
)
|
||||
|
|
@ -39,7 +44,10 @@ def x_eval_suite() -> EvalSuite:
|
|||
# Add cases
|
||||
suite.add_case(
|
||||
name="Post a tweet",
|
||||
user_message="Send out a tweet that says 'Hello World! Exciting stuff is happening over at Arcade AI!'",
|
||||
user_message=(
|
||||
"Send out a tweet that says 'Hello World! Exciting stuff is happening over "
|
||||
"at Arcade AI!'"
|
||||
),
|
||||
expected_tool_calls=[
|
||||
(
|
||||
post_tweet,
|
||||
|
|
@ -47,11 +55,108 @@ def x_eval_suite() -> EvalSuite:
|
|||
)
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(
|
||||
BinaryCritic(
|
||||
critic_field="tweet_text",
|
||||
weight=1.0,
|
||||
similarity_threshold=0.9,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Delete a tweet by ID",
|
||||
user_message="Please delete the tweet with ID '148975632'.",
|
||||
expected_tool_calls=[
|
||||
(
|
||||
delete_tweet_by_id,
|
||||
{"tweet_id": "148975632"},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(
|
||||
critic_field="tweet_id",
|
||||
weight=1.0,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Search recent tweets by username",
|
||||
user_message="Show me the recent tweets from 'elonmusk'.",
|
||||
expected_tool_calls=[
|
||||
(
|
||||
search_recent_tweets_by_username,
|
||||
{"username": "elonmusk", "max_results": 10},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(
|
||||
critic_field="username",
|
||||
weight=1.0,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Lookup user by username",
|
||||
user_message="Can you get information about the user '@jack'?",
|
||||
expected_tool_calls=[
|
||||
(
|
||||
lookup_single_user_by_username,
|
||||
{"username": "jack"},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(
|
||||
critic_field="username",
|
||||
weight=1.0,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Add a case for searching recent tweets by keywords
|
||||
suite.add_case(
|
||||
name="Search recent tweets by keywords",
|
||||
user_message="Find recent tweets containing 'Arcade AI'.",
|
||||
expected_tool_calls=[
|
||||
(
|
||||
search_recent_tweets_by_keywords,
|
||||
{
|
||||
"keywords": [],
|
||||
"phrases": ["Arcade AI"],
|
||||
"max_results": 10,
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(
|
||||
critic_field="keywords",
|
||||
weight=0.1,
|
||||
),
|
||||
BinaryCritic(
|
||||
critic_field="phrases",
|
||||
weight=0.9,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Extend the case to test lookup_tweet_by_id
|
||||
suite.extend_case(
|
||||
name="Lookup tweet by ID",
|
||||
user_message="Can you provide details about the tweet with ID '123456789'?",
|
||||
expected_tool_calls=[
|
||||
(
|
||||
lookup_tweet_by_id,
|
||||
{
|
||||
"tweet_id": "123456789",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(
|
||||
critic_field="tweet_id",
|
||||
weight=1.0,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
return suite
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "arcade_x"
|
||||
version = "0.1.0"
|
||||
version = "0.1.3"
|
||||
description = "LLM tools for interacting with X (Twitter)"
|
||||
authors = ["Arcade AI <dev@arcade-ai.com>"]
|
||||
|
||||
|
|
@ -8,10 +8,35 @@ authors = ["Arcade AI <dev@arcade-ai.com>"]
|
|||
python = "^3.10"
|
||||
arcade-ai = "0.1.*"
|
||||
httpx = "^0.27.2"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = "^8.3.0"
|
||||
pytest-cov = "^4.0.0"
|
||||
pytest-asyncio = "^0.24.0"
|
||||
pytest-mock = "^3.11.1"
|
||||
mypy = "^1.5.1"
|
||||
pre-commit = "^3.4.0"
|
||||
tox = "^4.11.1"
|
||||
ruff = "^0.7.4"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.mypy]
|
||||
files = ["arcade_x/**/*.py"]
|
||||
python_version = "3.10"
|
||||
disallow_untyped_defs = "True"
|
||||
disallow_any_unimported = "True"
|
||||
no_implicit_optional = "True"
|
||||
check_untyped_defs = "True"
|
||||
warn_return_any = "True"
|
||||
warn_unused_ignores = "True"
|
||||
show_error_codes = "True"
|
||||
ignore_missing_imports = "True"
|
||||
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
|
||||
[tool.coverage.report]
|
||||
skip_empty = true
|
||||
|
|
|
|||
190
toolkits/x/tests/test_tweets.py
Normal file
190
toolkits/x/tests/test_tweets.py
Normal file
|
|
@ -0,0 +1,190 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from arcade.sdk.errors import RetryableToolError, ToolExecutionError
|
||||
|
||||
from arcade_x.tools.tweets import (
|
||||
delete_tweet_by_id,
|
||||
lookup_tweet_by_id,
|
||||
post_tweet,
|
||||
search_recent_tweets_by_keywords,
|
||||
search_recent_tweets_by_username,
|
||||
)
|
||||
from arcade_x.tools.utils import get_tweet_url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_tweet_success(tool_context, mock_httpx_client):
|
||||
"""Test successful posting of a tweet."""
|
||||
# Mock response for a successful tweet post
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 201
|
||||
mock_response.json.return_value = {"data": {"id": "1234567890"}}
|
||||
mock_httpx_client.post.return_value = mock_response
|
||||
|
||||
tweet_text = "Hello, world!"
|
||||
result = await post_tweet(tool_context, tweet_text)
|
||||
|
||||
expected_url = get_tweet_url("1234567890")
|
||||
assert result == f"Tweet with id 1234567890 posted successfully. URL: {expected_url}"
|
||||
mock_httpx_client.post.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_tweet_failure(tool_context, mock_httpx_client):
|
||||
"""Test failure when posting a tweet due to API error."""
|
||||
# Mock response for a failed tweet post
|
||||
mock_response = httpx.HTTPStatusError(
|
||||
"Bad Request", request=MagicMock(), response=MagicMock(status_code=400)
|
||||
)
|
||||
mock_httpx_client.post.side_effect = mock_response
|
||||
|
||||
tweet_text = "Hello, world!"
|
||||
with pytest.raises(ToolExecutionError):
|
||||
await post_tweet(tool_context, tweet_text)
|
||||
|
||||
mock_httpx_client.post.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_tweet_by_id_success(tool_context, mock_httpx_client):
|
||||
"""Test successful deletion of a tweet by ID."""
|
||||
# Mock response for a successful tweet deletion
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_httpx_client.delete.return_value = mock_response
|
||||
|
||||
tweet_id = "1234567890"
|
||||
result = await delete_tweet_by_id(tool_context, tweet_id)
|
||||
|
||||
assert result == f"Tweet with id {tweet_id} deleted successfully."
|
||||
mock_httpx_client.delete.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_tweet_by_id_failure(tool_context, mock_httpx_client):
|
||||
"""Test failure when deleting a tweet due to API error."""
|
||||
# Mock response for a failed tweet deletion
|
||||
mock_response = httpx.HTTPStatusError(
|
||||
"Internal Server Error", request=MagicMock(), response=MagicMock(status_code=404)
|
||||
)
|
||||
mock_httpx_client.delete.side_effect = mock_response
|
||||
|
||||
tweet_id = "1234567890"
|
||||
with pytest.raises(ToolExecutionError):
|
||||
await delete_tweet_by_id(tool_context, tweet_id)
|
||||
|
||||
mock_httpx_client.delete.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_recent_tweets_by_username_success(tool_context, mock_httpx_client):
|
||||
"""Test successful search of recent tweets by username."""
|
||||
# Mock response for a successful tweet search
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"data": [
|
||||
{
|
||||
"id": "1234567890",
|
||||
"text": "Test tweet",
|
||||
"entities": {
|
||||
"urls": [
|
||||
{"url": "https://t.co/short", "expanded_url": "https://example.com/long"}
|
||||
]
|
||||
},
|
||||
}
|
||||
],
|
||||
"includes": {"users": [{"id": "0987654321", "name": "Test User", "username": "testuser"}]},
|
||||
}
|
||||
mock_httpx_client.get.return_value = mock_response
|
||||
|
||||
username = "testuser"
|
||||
result = await search_recent_tweets_by_username(tool_context, username)
|
||||
|
||||
assert "data" in result
|
||||
assert len(result["data"]) == 1
|
||||
assert result["data"][0]["text"] == "Test tweet"
|
||||
mock_httpx_client.get.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_recent_tweets_by_username_failure(tool_context, mock_httpx_client):
|
||||
"""Test failure when searching tweets due to API error."""
|
||||
# Mock response for a failed tweet search
|
||||
mock_response = httpx.HTTPStatusError(
|
||||
"Internal Server Error", request=MagicMock(), response=MagicMock(status_code=500)
|
||||
)
|
||||
mock_httpx_client.get.side_effect = mock_response
|
||||
|
||||
username = "testuser"
|
||||
with pytest.raises(ToolExecutionError):
|
||||
await search_recent_tweets_by_username(tool_context, username)
|
||||
|
||||
mock_httpx_client.get.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_recent_tweets_by_keywords_success(tool_context, mock_httpx_client):
|
||||
"""Test successful search of recent tweets by keywords."""
|
||||
# Mock response for a successful keyword search
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"data": [{"id": "1234567890", "text": "Keyword tweet", "entities": {}}],
|
||||
"includes": {"users": [{"id": "0987654321", "name": "Test User", "username": "testuser"}]},
|
||||
}
|
||||
mock_httpx_client.get.return_value = mock_response
|
||||
|
||||
keywords = ["test", "keyword"]
|
||||
result = await search_recent_tweets_by_keywords(tool_context, keywords=keywords)
|
||||
|
||||
assert "data" in result
|
||||
assert len(result["data"]) == 1
|
||||
assert result["data"][0]["text"] == "Keyword tweet"
|
||||
mock_httpx_client.get.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_recent_tweets_by_keywords_no_input(tool_context):
|
||||
"""Test error when no keywords or phrases are provided."""
|
||||
with pytest.raises(RetryableToolError) as exc_info:
|
||||
await search_recent_tweets_by_keywords(tool_context)
|
||||
|
||||
assert "No keywords or phrases provided" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lookup_tweet_by_id_success(tool_context, mock_httpx_client):
|
||||
"""Test successful lookup of a tweet by ID."""
|
||||
# Use MagicMock for the response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"data": {"id": "1234567890", "text": "Lookup tweet", "entities": {}}
|
||||
}
|
||||
mock_httpx_client.get.return_value = mock_response
|
||||
|
||||
tweet_id = "1234567890"
|
||||
result = await lookup_tweet_by_id(tool_context, tweet_id)
|
||||
|
||||
assert "data" in result
|
||||
assert result["data"]["text"] == "Lookup tweet"
|
||||
mock_httpx_client.get.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lookup_tweet_by_id_failure(tool_context, mock_httpx_client):
|
||||
"""Test failure when looking up a tweet due to API error."""
|
||||
# Mock response for a failed tweet lookup
|
||||
mock_response = httpx.HTTPStatusError(
|
||||
"Not Found", request=MagicMock(), response=MagicMock(status_code=404)
|
||||
)
|
||||
mock_httpx_client.get.side_effect = mock_response
|
||||
|
||||
tweet_id = "1234567890"
|
||||
with pytest.raises(ToolExecutionError):
|
||||
await lookup_tweet_by_id(tool_context, tweet_id)
|
||||
|
||||
mock_httpx_client.get.assert_called_once()
|
||||
78
toolkits/x/tests/test_users.py
Normal file
78
toolkits/x/tests/test_users.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from arcade.sdk.errors import ToolExecutionError
|
||||
|
||||
from arcade_x.tools.users import lookup_single_user_by_username
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lookup_single_user_by_username_success(tool_context, mock_httpx_client):
|
||||
"""Test successful lookup of a user by username."""
|
||||
# Mock response for a successful user lookup
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"data": {
|
||||
"id": "1234567890",
|
||||
"name": "Test User",
|
||||
"username": "testuser",
|
||||
"description": "This is a test user",
|
||||
# Additional fields can be added here as needed
|
||||
}
|
||||
}
|
||||
mock_httpx_client.get.return_value = mock_response
|
||||
|
||||
username = "testuser"
|
||||
result = await lookup_single_user_by_username(tool_context, username)
|
||||
|
||||
assert "data" in result
|
||||
assert result["data"]["username"] == "testuser"
|
||||
assert result["data"]["name"] == "Test User"
|
||||
mock_httpx_client.get.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lookup_single_user_by_username_user_not_found(tool_context, mock_httpx_client):
|
||||
"""Test behavior when looking up user fails due to API error"""
|
||||
# Mock response for user not found
|
||||
mock_response = httpx.HTTPStatusError(
|
||||
"Not Found", request=MagicMock(), response=MagicMock(status_code=404)
|
||||
)
|
||||
mock_httpx_client.get.side_effect = mock_response
|
||||
|
||||
username = "nonexistentuser"
|
||||
with pytest.raises(ToolExecutionError):
|
||||
await lookup_single_user_by_username(tool_context, username)
|
||||
|
||||
mock_httpx_client.get.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lookup_single_user_by_username_api_error(tool_context, mock_httpx_client):
|
||||
"""Test behavior when API returns an error other than 404."""
|
||||
# Mock response for API error
|
||||
mock_response = httpx.HTTPStatusError(
|
||||
"Internal Server Error", request=MagicMock(), response=MagicMock(status_code=500)
|
||||
)
|
||||
mock_httpx_client.get.side_effect = mock_response
|
||||
|
||||
username = "testuser"
|
||||
with pytest.raises(ToolExecutionError):
|
||||
await lookup_single_user_by_username(tool_context, username)
|
||||
|
||||
mock_httpx_client.get.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lookup_single_user_by_username_network_error(tool_context, mock_httpx_client):
|
||||
"""Test behavior when there is a network error during the request."""
|
||||
# Mock client.get to raise an HTTPError
|
||||
mock_httpx_client.get.side_effect = httpx.HTTPError("Network Error")
|
||||
|
||||
username = "testuser"
|
||||
with pytest.raises(ToolExecutionError):
|
||||
await lookup_single_user_by_username(tool_context, username)
|
||||
|
||||
mock_httpx_client.get.assert_called_once()
|
||||
Loading…
Reference in a new issue