diff --git a/toolkits/x/arcade_x/tools/tweets.py b/toolkits/x/arcade_x/tools/tweets.py index 9aa71034..302aecb6 100644 --- a/toolkits/x/arcade_x/tools/tweets.py +++ b/toolkits/x/arcade_x/tools/tweets.py @@ -6,6 +6,7 @@ from arcade.sdk.auth import X from arcade.sdk.errors import RetryableToolError from arcade_x.tools.utils import ( + expand_attached_media, expand_long_tweet, expand_urls_in_tweets, get_headers_with_token, @@ -81,13 +82,13 @@ async def search_recent_tweets_by_username( max(max_results, 10), 100 ), # X API does not allow 'max_results' less than 10 or greater than 100 "next_token": next_token, + "expansions": "author_id", + "user.fields": "id,name,username,entities", + "tweet.fields": "entities,note_tweet", } - params = remove_none_values(params) + params = expand_attached_media(remove_none_values(params)) - url = ( - "https://api.x.com/2/tweets/search/recent?" - "expansions=author_id&user.fields=id,name,username,entities&tweet.fields=entities,note_tweet" - ) + url = f"{TWEETS_URL}/search/recent" async with httpx.AsyncClient() as client: response = await client.get(url, headers=headers, params=params, timeout=10) @@ -151,13 +152,13 @@ async def search_recent_tweets_by_keywords( max(max_results, 10), 100 ), # X API does not allow 'max_results' less than 10 or greater than 100 "next_token": next_token, + "expansions": "author_id", + "user.fields": "id,name,username,entities", + "tweet.fields": "entities,note_tweet", } - params = remove_none_values(params) + params = expand_attached_media(remove_none_values(params)) - url = ( - "https://api.x.com/2/tweets/search/recent?" - "expansions=author_id&user.fields=id,name,username,entities&tweet.fields=entities,note_tweet" - ) + url = f"{TWEETS_URL}/search/recent" async with httpx.AsyncClient() as client: response = await client.get(url, headers=headers, params=params, timeout=10) @@ -192,6 +193,8 @@ async def lookup_tweet_by_id( "user.fields": "id,name,username,entities", "tweet.fields": "entities,note_tweet", } + params = expand_attached_media(params) + url = f"{TWEETS_URL}/{tweet_id}" async with httpx.AsyncClient() as client: diff --git a/toolkits/x/arcade_x/tools/utils.py b/toolkits/x/arcade_x/tools/utils.py index e7f4c6ca..7c9c9e8c 100644 --- a/toolkits/x/arcade_x/tools/utils.py +++ b/toolkits/x/arcade_x/tools/utils.py @@ -137,3 +137,24 @@ def remove_none_values(params: dict) -> dict: A new dictionary with None values removed """ return {k: v for k, v in params.items() if v is not None} + + +def expand_attached_media(params: dict) -> dict: + """ + Include attached media metadata in the request parameters. + """ + params["expansions"] += ",attachments.media_keys" + params["tweet.fields"] += ",attachments" + params["media.fields"] = ",".join([ + # media_key, url and type are returned by default, added here for clarity + "media_key", + "url", + "type", + "duration_ms", + "height", + "width", + "preview_image_url", + "alt_text", + "public_metrics", + ]) + return params diff --git a/toolkits/x/evals/eval_x_tools.py b/toolkits/x/evals/eval_x_tools.py index b344d4a5..5882a60c 100644 --- a/toolkits/x/evals/eval_x_tools.py +++ b/toolkits/x/evals/eval_x_tools.py @@ -80,9 +80,7 @@ def x_eval_suite() -> EvalSuite: expected_tool_calls=[ ExpectedToolCall( func=post_tweet, - args={ - "tweet_text": "Hello World! Exciting stuff is happening over at Arcade AI!", - }, + args={"tweet_text": "Hello World! Exciting stuff is happening over at Arcade AI!"}, ) ], critics=[ @@ -139,7 +137,7 @@ def x_eval_suite() -> EvalSuite: "max_results": 42, "next_token": "b26v89c19zqg8o3frr3tekall7a7ooom3sctaw30rz62l", }, - ) + ), ], critics=[ BinaryCritic( @@ -164,7 +162,7 @@ def x_eval_suite() -> EvalSuite: ExpectedToolCall( func=lookup_single_user_by_username, args={"username": "jack"}, - ) + ), ], critics=[ BinaryCritic( @@ -186,7 +184,7 @@ def x_eval_suite() -> EvalSuite: "phrases": ["Arcade AI"], "max_results": 10, }, - ) + ), ], critics=[ BinaryCritic( @@ -208,7 +206,7 @@ def x_eval_suite() -> EvalSuite: ExpectedToolCall( func=lookup_tweet_by_id, args={"tweet_id": "123456789"}, - ) + ), ], critics=[ BinaryCritic( diff --git a/toolkits/x/pyproject.toml b/toolkits/x/pyproject.toml index 5ac6acc0..f31a84ba 100644 --- a/toolkits/x/pyproject.toml +++ b/toolkits/x/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "arcade_x" -version = "0.1.8" +version = "0.1.9" description = "LLM tools for interacting with X (Twitter)" authors = ["Arcade AI "] diff --git a/toolkits/x/tests/test_tweets.py b/toolkits/x/tests/test_tweets.py index 4ec3a4a8..eb1e98f9 100644 --- a/toolkits/x/tests/test_tweets.py +++ b/toolkits/x/tests/test_tweets.py @@ -117,7 +117,12 @@ async def test_search_recent_tweets_by_username_success(tool_context, mock_httpx }, } ], - "includes": {"users": [{"id": "0987654321", "name": "Test User", "username": "testuser"}]}, + "includes": { + "users": [{"id": "0987654321", "name": "Test User", "username": "testuser"}], + "media": [ + {"media_key": "1234567890", "type": "photo", "url": "https://example.com/photo.jpg"} + ], + }, } mock_httpx_client.get.return_value = mock_response @@ -127,6 +132,12 @@ async def test_search_recent_tweets_by_username_success(tool_context, mock_httpx assert "data" in result assert len(result["data"]) == 1 assert result["data"][0]["text"] == full_tweet_text + + assert "includes" in result + assert "media" in result["includes"] + assert len(result["includes"]["media"]) == 1 + assert result["includes"]["media"][0]["url"] == "https://example.com/photo.jpg" + mock_httpx_client.get.assert_called_once() @@ -168,16 +179,27 @@ async def test_search_recent_tweets_by_keywords_success(tool_context, mock_httpx "entities": {}, } ], - "includes": {"users": [{"id": "0987654321", "name": "Test User", "username": "testuser"}]}, + "includes": { + "users": [{"id": "0987654321", "name": "Test User", "username": "testuser"}], + "media": [ + {"media_key": "1234567890", "type": "photo", "url": "https://example.com/photo.jpg"} + ], + }, } mock_httpx_client.get.return_value = mock_response keywords = ["test", "keyword"] - result = await search_recent_tweets_by_keywords(tool_context, keywords=keywords) + result = await search_recent_tweets_by_keywords(context=tool_context, keywords=keywords) assert "data" in result assert len(result["data"]) == 1 assert result["data"][0]["text"] == full_tweet_text + + assert "includes" in result + assert "media" in result["includes"] + assert len(result["includes"]["media"]) == 1 + assert result["includes"]["media"][0]["url"] == "https://example.com/photo.jpg" + mock_httpx_client.get.assert_called_once() @@ -207,7 +229,12 @@ async def test_lookup_tweet_by_id_success(tool_context, mock_httpx_client): }, "text": truncated_tweet_text, "entities": {}, - } + }, + "includes": { + "media": [ + {"media_key": "1234567890", "type": "photo", "url": "https://example.com/photo.jpg"} + ] + }, } mock_httpx_client.get.return_value = mock_response @@ -216,6 +243,12 @@ async def test_lookup_tweet_by_id_success(tool_context, mock_httpx_client): assert "data" in result assert result["data"]["text"] == full_tweet_text + + assert "includes" in result + assert "media" in result["includes"] + assert len(result["includes"]["media"]) == 1 + assert result["includes"]["media"][0]["url"] == "https://example.com/photo.jpg" + mock_httpx_client.get.assert_called_once()