Add Gmail Thread Tools (#159)
# PR Description
1. This PR adds three new tools:
- GetThread (by ID)
- ListThreads
- SearchThreads
2. This PR updates the return type for various Gmail tools from str to
dict.
3. This PR adds evals and tests for the added tools
This commit is contained in:
parent
82afd7ec70
commit
2798cc0820
5 changed files with 445 additions and 69 deletions
|
|
@ -1,5 +1,4 @@
|
|||
import base64
|
||||
import json
|
||||
from email.message import EmailMessage
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Annotated, Optional
|
||||
|
|
@ -20,6 +19,7 @@ from arcade_google.tools.utils import (
|
|||
get_sent_email_url,
|
||||
parse_draft_email,
|
||||
parse_email,
|
||||
remove_none_values,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -36,7 +36,7 @@ async def send_email(
|
|||
recipient: Annotated[str, "The recipient of the email"],
|
||||
cc: Annotated[Optional[list[str]], "CC recipients of the email"] = None,
|
||||
bcc: Annotated[Optional[list[str]], "BCC recipients of the email"] = None,
|
||||
) -> Annotated[str, "A confirmation message with the sent email ID and URL"]:
|
||||
) -> Annotated[dict, "A dictionary containing the sent email details"]:
|
||||
"""
|
||||
Send an email using the Gmail API.
|
||||
"""
|
||||
|
|
@ -61,7 +61,10 @@ async def send_email(
|
|||
|
||||
# Send the email
|
||||
sent_message = service.users().messages().send(userId="me", body=email).execute()
|
||||
return f"Email with ID {sent_message['id']} sent: {get_sent_email_url(sent_message['id'])}"
|
||||
|
||||
email = parse_email(sent_message)
|
||||
email["url"] = get_sent_email_url(sent_message["id"])
|
||||
return email
|
||||
|
||||
|
||||
@tool(
|
||||
|
|
@ -71,7 +74,7 @@ async def send_email(
|
|||
)
|
||||
async def send_draft_email(
|
||||
context: ToolContext, email_id: Annotated[str, "The ID of the draft to send"]
|
||||
) -> Annotated[str, "A confirmation message with the sent email ID and URL"]:
|
||||
) -> Annotated[dict, "A dictionary containing the sent email details"]:
|
||||
"""
|
||||
Send a draft email using the Gmail API.
|
||||
"""
|
||||
|
|
@ -82,10 +85,9 @@ async def send_draft_email(
|
|||
# Send the draft email
|
||||
sent_message = service.users().drafts().send(userId="me", body={"id": email_id}).execute()
|
||||
|
||||
# Construct the URL to the sent email
|
||||
return (
|
||||
f"Draft email with ID {sent_message['id']} sent: {get_sent_email_url(sent_message['id'])}"
|
||||
)
|
||||
email = parse_email(sent_message)
|
||||
email["url"] = get_sent_email_url(sent_message["id"])
|
||||
return email
|
||||
|
||||
|
||||
# Draft Management Tools
|
||||
|
|
@ -101,7 +103,7 @@ async def write_draft_email(
|
|||
recipient: Annotated[str, "The recipient of the draft email"],
|
||||
cc: Annotated[Optional[list[str]], "CC recipients of the draft email"] = None,
|
||||
bcc: Annotated[Optional[list[str]], "BCC recipients of the draft email"] = None,
|
||||
) -> Annotated[str, "A confirmation message with the draft email ID and URL"]:
|
||||
) -> Annotated[dict, "A dictionary containing the created draft email details"]:
|
||||
"""
|
||||
Compose a new email draft using the Gmail API.
|
||||
"""
|
||||
|
|
@ -123,9 +125,9 @@ async def write_draft_email(
|
|||
draft = {"message": {"raw": raw_message}}
|
||||
|
||||
draft_message = service.users().drafts().create(userId="me", body=draft).execute()
|
||||
return (
|
||||
f"Draft email with ID {draft_message['id']} created: {get_draft_url(draft_message['id'])}"
|
||||
)
|
||||
email = parse_draft_email(draft_message)
|
||||
email["url"] = get_draft_url(draft_message["id"])
|
||||
return email
|
||||
|
||||
|
||||
@tool(
|
||||
|
|
@ -141,7 +143,7 @@ async def update_draft_email(
|
|||
recipient: Annotated[str, "The recipient of the draft email"],
|
||||
cc: Annotated[Optional[list[str]], "CC recipients of the draft email"] = None,
|
||||
bcc: Annotated[Optional[list[str]], "BCC recipients of the draft email"] = None,
|
||||
) -> Annotated[str, "A confirmation message with the updated draft email ID and URL"]:
|
||||
) -> Annotated[dict, "A dictionary containing the updated draft email details"]:
|
||||
"""
|
||||
Update an existing email draft using the Gmail API.
|
||||
"""
|
||||
|
|
@ -166,7 +168,10 @@ async def update_draft_email(
|
|||
updated_draft_message = (
|
||||
service.users().drafts().update(userId="me", id=draft_email_id, body=draft).execute()
|
||||
)
|
||||
return f"Draft email with ID {updated_draft_message['id']} updated: {get_draft_url(updated_draft_message['id'])}"
|
||||
|
||||
email = parse_draft_email(updated_draft_message)
|
||||
email["url"] = get_draft_url(updated_draft_message["id"])
|
||||
return email
|
||||
|
||||
|
||||
@tool(
|
||||
|
|
@ -198,7 +203,7 @@ async def delete_draft_email(
|
|||
)
|
||||
async def trash_email(
|
||||
context: ToolContext, email_id: Annotated[str, "The ID of the email to trash"]
|
||||
) -> Annotated[str, "A confirmation message with the trashed email ID and URL"]:
|
||||
) -> Annotated[dict, "A dictionary containing the trashed email details"]:
|
||||
"""
|
||||
Move an email to the trash folder using the Gmail API.
|
||||
"""
|
||||
|
|
@ -207,9 +212,11 @@ async def trash_email(
|
|||
service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
|
||||
|
||||
# Trash the email
|
||||
service.users().messages().trash(userId="me", id=email_id).execute()
|
||||
trashed_email = service.users().messages().trash(userId="me", id=email_id).execute()
|
||||
|
||||
return f"Email with ID {email_id} trashed successfully: {get_email_in_trash_url(email_id)}"
|
||||
email = parse_email(trashed_email)
|
||||
email["url"] = get_email_in_trash_url(trashed_email["id"])
|
||||
return email
|
||||
|
||||
|
||||
# Draft Search Tools
|
||||
|
|
@ -221,7 +228,7 @@ async def trash_email(
|
|||
async def list_draft_emails(
|
||||
context: ToolContext,
|
||||
n_drafts: Annotated[int, "Number of draft emails to read"] = 5,
|
||||
) -> Annotated[str, "A JSON string containing a list of draft email details and their IDs"]:
|
||||
) -> Annotated[dict, "A dictionary containing a list of draft email details"]:
|
||||
"""
|
||||
Lists draft emails in the user's draft mailbox using the Gmail API.
|
||||
"""
|
||||
|
|
@ -245,7 +252,7 @@ async def list_draft_emails(
|
|||
except Exception as e:
|
||||
print(f"Error reading draft email {draft_id}: {e}")
|
||||
|
||||
return json.dumps({"emails": emails})
|
||||
return {"emails": emails}
|
||||
|
||||
|
||||
# Email Search Tools
|
||||
|
|
@ -263,11 +270,11 @@ async def list_emails_by_header(
|
|||
date_range: Annotated[Optional[DateRange], "The date range of the email"] = None,
|
||||
limit: Annotated[Optional[int], "The maximum number of emails to return"] = 25,
|
||||
) -> Annotated[
|
||||
str, "A JSON string containing a list of email details matching the search criteria"
|
||||
dict, "A dictionary containing a list of email details matching the search criteria"
|
||||
]:
|
||||
"""
|
||||
Search for emails by header using the Gmail API.
|
||||
At least one of the following parametersMUST be provided: sender, recipient, subject, body.
|
||||
At least one of the following parameters MUST be provided: sender, recipient, subject, body.
|
||||
"""
|
||||
if not any([sender, recipient, subject, body]):
|
||||
raise RetryableToolError(
|
||||
|
|
@ -281,10 +288,10 @@ async def list_emails_by_header(
|
|||
messages = fetch_messages(service, query, limit)
|
||||
|
||||
if not messages:
|
||||
return json.dumps({"emails": []})
|
||||
return {"emails": []}
|
||||
|
||||
emails = process_messages(service, messages)
|
||||
return json.dumps({"emails": emails})
|
||||
return {"emails": emails}
|
||||
|
||||
|
||||
def process_messages(service, messages):
|
||||
|
|
@ -307,7 +314,7 @@ def process_messages(service, messages):
|
|||
async def list_emails(
|
||||
context: ToolContext,
|
||||
n_emails: Annotated[int, "Number of emails to read"] = 5,
|
||||
) -> Annotated[str, "A JSON string containing a list of email details"]:
|
||||
) -> Annotated[dict, "A dictionary containing a list of email details"]:
|
||||
"""
|
||||
Read emails from a Gmail account and extract plain text content.
|
||||
"""
|
||||
|
|
@ -329,4 +336,109 @@ async def list_emails(
|
|||
except Exception as e:
|
||||
print(f"Error reading email {msg['id']}: {e}")
|
||||
|
||||
return json.dumps({"emails": emails})
|
||||
return {"emails": emails}
|
||||
|
||||
|
||||
@tool(
|
||||
requires_auth=Google(
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
)
|
||||
)
|
||||
async def search_threads(
|
||||
context: ToolContext,
|
||||
page_token: Annotated[
|
||||
Optional[str], "Page token to retrieve a specific page of results in the list"
|
||||
] = None,
|
||||
max_results: Annotated[int, "The maximum number of threads to return"] = 10,
|
||||
include_spam_trash: Annotated[bool, "Whether to include spam and trash in the results"] = False,
|
||||
label_ids: Annotated[Optional[list[str]], "The IDs of labels to filter by"] = None,
|
||||
sender: Annotated[Optional[str], "The name or email address of the sender of the email"] = None,
|
||||
recipient: Annotated[Optional[str], "The name or email address of the recipient"] = None,
|
||||
subject: Annotated[Optional[str], "Words to find in the subject of the email"] = None,
|
||||
body: Annotated[Optional[str], "Words to find in the body of the email"] = None,
|
||||
date_range: Annotated[Optional[DateRange], "The date range of the email"] = None,
|
||||
) -> Annotated[dict, "A dictionary containing a list of thread details"]:
|
||||
"""Search for threads in the user's mailbox"""
|
||||
service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
|
||||
|
||||
query = (
|
||||
build_query_string(sender, recipient, subject, body, date_range)
|
||||
if any([sender, recipient, subject, body, date_range])
|
||||
else None
|
||||
)
|
||||
|
||||
params = {
|
||||
"userId": "me",
|
||||
"maxResults": min(max_results, 500),
|
||||
"pageToken": page_token,
|
||||
"includeSpamTrash": include_spam_trash,
|
||||
"labelIds": label_ids,
|
||||
"q": query,
|
||||
}
|
||||
params = remove_none_values(params)
|
||||
|
||||
threads = []
|
||||
next_page_token = None
|
||||
# Paginate through thread pages until we have the desired number of threads
|
||||
while len(threads) < max_results:
|
||||
response = service.users().threads().list(**params).execute()
|
||||
|
||||
threads.extend(response.get("threads", []))
|
||||
next_page_token = response.get("nextPageToken")
|
||||
|
||||
if not next_page_token:
|
||||
break
|
||||
|
||||
params["pageToken"] = next_page_token
|
||||
params["maxResults"] = min(max_results - len(threads), 500)
|
||||
|
||||
return {
|
||||
"threads": threads,
|
||||
"num_threads": len(threads),
|
||||
"next_page_token": next_page_token,
|
||||
}
|
||||
|
||||
|
||||
@tool(
|
||||
requires_auth=Google(
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
)
|
||||
)
|
||||
async def list_threads(
|
||||
context: ToolContext,
|
||||
page_token: Annotated[
|
||||
Optional[str], "Page token to retrieve a specific page of results in the list"
|
||||
] = None,
|
||||
max_results: Annotated[int, "The maximum number of threads to return"] = 10,
|
||||
include_spam_trash: Annotated[bool, "Whether to include spam and trash in the results"] = False,
|
||||
) -> Annotated[dict, "A dictionary containing a list of thread details"]:
|
||||
"""List threads in the user's mailbox."""
|
||||
return await search_threads(context, page_token, max_results, include_spam_trash)
|
||||
|
||||
|
||||
@tool(
|
||||
requires_auth=Google(
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
)
|
||||
)
|
||||
async def get_thread(
|
||||
context: ToolContext,
|
||||
thread_id: Annotated[str, "The ID of the thread to retrieve"],
|
||||
metadata_headers: Annotated[
|
||||
Optional[list[str]], "When given and format is METADATA, only include headers specified."
|
||||
] = None,
|
||||
) -> Annotated[dict, "A dictionary containing the thread details"]:
|
||||
"""Get the specified thread by ID."""
|
||||
params = {
|
||||
"userId": "me",
|
||||
"id": thread_id,
|
||||
"format": "full",
|
||||
"metadataHeaders": metadata_headers,
|
||||
}
|
||||
params = remove_none_values(params)
|
||||
|
||||
service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
|
||||
thread = service.users().threads().get(**params).execute()
|
||||
thread["messages"] = [parse_email(message) for message in thread.get("messages", [])]
|
||||
|
||||
return thread
|
||||
|
|
|
|||
|
|
@ -232,7 +232,9 @@ class SendUpdatesOptions(Enum):
|
|||
EXTERNAL_ONLY = "externalOnly" # Notifications are sent to non-Google Calendar guests only.
|
||||
|
||||
|
||||
# Utils for Google Drive tools
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# Google Drive Models and Enums
|
||||
# ---------------------------------------------------------------------------- #
|
||||
class Corpora(str, Enum):
|
||||
"""
|
||||
Bodies of items (files/documents) to which the query applies.
|
||||
|
|
|
|||
|
|
@ -82,16 +82,17 @@ def parse_email(email_data: dict[str, Any]) -> Optional[dict[str, str]]:
|
|||
Optional[Dict[str, str]]: Parsed email details or None if parsing fails.
|
||||
"""
|
||||
try:
|
||||
payload = email_data["payload"]
|
||||
headers = {d["name"].lower(): d["value"] for d in payload["headers"]}
|
||||
payload = email_data.get("payload", {})
|
||||
headers = {d["name"].lower(): d["value"] for d in payload.get("headers", [])}
|
||||
|
||||
body_data = _get_email_body(payload)
|
||||
|
||||
return {
|
||||
"id": email_data.get("id", ""),
|
||||
"thread_id": email_data.get("threadId", ""),
|
||||
"from": headers.get("from", ""),
|
||||
"date": headers.get("date", ""),
|
||||
"subject": headers.get("subject", "No subject"),
|
||||
"subject": headers.get("subject", ""),
|
||||
"body": _clean_email_body(body_data) if body_data else "",
|
||||
}
|
||||
except Exception as e:
|
||||
|
|
@ -110,17 +111,18 @@ def parse_draft_email(draft_email_data: dict[str, Any]) -> Optional[dict[str, st
|
|||
Optional[Dict[str, str]]: Parsed draft email details or None if parsing fails.
|
||||
"""
|
||||
try:
|
||||
message = draft_email_data["message"]
|
||||
payload = message["payload"]
|
||||
headers = {d["name"].lower(): d["value"] for d in payload["headers"]}
|
||||
message = draft_email_data.get("message", {})
|
||||
payload = message.get("payload", {})
|
||||
headers = {d["name"].lower(): d["value"] for d in payload.get("headers", [])}
|
||||
|
||||
body_data = _get_email_body(payload)
|
||||
|
||||
return {
|
||||
"id": draft_email_data.get("id", ""),
|
||||
"thread_id": draft_email_data.get("threadId", ""),
|
||||
"from": headers.get("from", ""),
|
||||
"date": headers.get("internaldate", ""),
|
||||
"subject": headers.get("subject", "No subject"),
|
||||
"subject": headers.get("subject", ""),
|
||||
"body": _clean_email_body(body_data) if body_data else "",
|
||||
}
|
||||
except Exception as e:
|
||||
|
|
@ -226,7 +228,7 @@ def _update_datetime(day: Day | None, time: TimeSlot | None, time_zone: str) ->
|
|||
|
||||
def build_query_string(sender, recipient, subject, body, date_range):
|
||||
"""
|
||||
Helper function to build a query string for Gmail list_emails_by_header tool.
|
||||
Helper function to build a query string for Gmail list_emails_by_header and search_threads tools.
|
||||
"""
|
||||
query = []
|
||||
if sender:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
import arcade_google
|
||||
from arcade_google.tools.gmail import (
|
||||
get_thread,
|
||||
list_threads,
|
||||
search_threads,
|
||||
send_email,
|
||||
)
|
||||
from arcade_google.tools.utils import DateRange
|
||||
|
||||
from arcade.sdk import ToolCatalog
|
||||
from arcade.sdk.eval import (
|
||||
|
|
@ -57,4 +61,100 @@ def gmail_eval_suite() -> EvalSuite:
|
|||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Simple list threads",
|
||||
user_message="Get 42 threads like right now i even wanna see the ones in my trash",
|
||||
expected_tool_calls=[
|
||||
(
|
||||
list_threads,
|
||||
{"max_results": 42, "include_spam_trash": True},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="max_results", weight=0.5),
|
||||
BinaryCritic(critic_field="include_spam_trash", weight=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
history = [
|
||||
{"role": "user", "content": "list 1 thread"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_X8V5Hw9iJ3wfB8WMZf8omAMi",
|
||||
"type": "function",
|
||||
"function": {"name": "Google_ListThreads", "arguments": '{"max_results":1}'},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": '{"next_page_token":"10321400718999360131","num_threads":1,"threads":[{"historyId":"61691","id":"1934a8f8deccb749","snippet":"Hi Joe, I hope this email finds you well. Thank you for being a part of our community."}]}',
|
||||
"tool_call_id": "call_X8V5Hw9iJ3wfB8WMZf8omAMi",
|
||||
"name": "Google_ListThreads",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Here is one email thread:\n\n- **Snippet:** Hi Joe, I hope this email finds you well. Thank you for being a part of our community.\n- **Thread ID:** 1934a8f8deccb749\n- **History ID:** 61691",
|
||||
},
|
||||
]
|
||||
suite.add_case(
|
||||
name="List threads with history",
|
||||
user_message="Get the next 5 threads",
|
||||
additional_messages=history,
|
||||
expected_tool_calls=[
|
||||
(
|
||||
list_threads,
|
||||
{
|
||||
"max_results": 5,
|
||||
"page_token": "10321400718999360131",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="max_results", weight=0.2),
|
||||
BinaryCritic(critic_field="page_token", weight=0.8),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Search threads",
|
||||
user_message="Search for threads from johndoe@example.com to janedoe@example.com about that talk about 'Arcade AI' from yesterday",
|
||||
expected_tool_calls=[
|
||||
(
|
||||
search_threads,
|
||||
{
|
||||
"sender": "johndoe@example.com",
|
||||
"recipient": "janedoe@example.com",
|
||||
"body": "Arcade AI",
|
||||
"date_range": DateRange.YESTERDAY,
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="sender", weight=0.25),
|
||||
BinaryCritic(critic_field="recipient", weight=0.25),
|
||||
SimilarityCritic(critic_field="body", weight=0.25),
|
||||
BinaryCritic(critic_field="date_range", weight=0.25),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get a thread by ID",
|
||||
user_message="Get the thread r-124325435467568867667878874565464564563523424323524235242412",
|
||||
expected_tool_calls=[
|
||||
(
|
||||
get_thread,
|
||||
{
|
||||
"thread_id": "r-124325435467568867667878874565464564563523424323524235242412",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="thread_id", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
return suite
|
||||
|
|
|
|||
|
|
@ -1,12 +1,14 @@
|
|||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from arcade_google.tools.gmail import (
|
||||
delete_draft_email,
|
||||
get_thread,
|
||||
list_draft_emails,
|
||||
list_emails,
|
||||
list_emails_by_header,
|
||||
list_threads,
|
||||
search_threads,
|
||||
send_draft_email,
|
||||
send_email,
|
||||
trash_email,
|
||||
|
|
@ -40,8 +42,11 @@ async def test_send_email(mock_build, mock_context):
|
|||
recipient="test@example.com",
|
||||
)
|
||||
|
||||
assert "Email with ID" in result
|
||||
assert "sent" in result
|
||||
assert isinstance(result, dict)
|
||||
assert "id" in result
|
||||
assert "thread_id" in result
|
||||
assert "subject" in result
|
||||
assert "body" in result
|
||||
|
||||
# Test http error
|
||||
mock_service.users().messages().send().execute.side_effect = HttpError(
|
||||
|
|
@ -72,8 +77,11 @@ async def test_write_draft_email(mock_build, mock_context):
|
|||
recipient="draft@example.com",
|
||||
)
|
||||
|
||||
assert "Draft email with ID" in result
|
||||
assert "created" in result
|
||||
assert isinstance(result, dict)
|
||||
assert "id" in result
|
||||
assert "thread_id" in result
|
||||
assert "subject" in result
|
||||
assert "body" in result
|
||||
|
||||
# Test http error
|
||||
mock_service.users().drafts().create().execute.side_effect = HttpError(
|
||||
|
|
@ -105,8 +113,11 @@ async def test_update_draft_email(mock_build, mock_context):
|
|||
recipient="updated@example.com",
|
||||
)
|
||||
|
||||
assert "Draft email with ID" in result
|
||||
assert "updated" in result
|
||||
assert isinstance(result, dict)
|
||||
assert "id" in result
|
||||
assert "thread_id" in result
|
||||
assert "subject" in result
|
||||
assert "body" in result
|
||||
|
||||
# Test http error
|
||||
mock_service.users().drafts().update().execute.side_effect = HttpError(
|
||||
|
|
@ -133,8 +144,11 @@ async def test_send_draft_email(mock_build, mock_context):
|
|||
# Test happy path
|
||||
result = await send_draft_email(context=mock_context, email_id="draft456")
|
||||
|
||||
assert "Draft email with ID" in result
|
||||
assert "sent" in result
|
||||
assert isinstance(result, dict)
|
||||
assert "id" in result
|
||||
assert "thread_id" in result
|
||||
assert "subject" in result
|
||||
assert "body" in result
|
||||
|
||||
# Test http error
|
||||
mock_service.users().drafts().send().execute.side_effect = HttpError(
|
||||
|
|
@ -226,12 +240,10 @@ async def test_get_draft_emails(mock_parse_draft_email, mock_build, mock_context
|
|||
# Test happy path
|
||||
result = await list_draft_emails(context=mock_context, n_drafts=2)
|
||||
|
||||
assert isinstance(result, str)
|
||||
result_json = json.loads(result)
|
||||
assert isinstance(result_json, dict)
|
||||
assert "emails" in result_json
|
||||
assert len(result_json["emails"]) == 1
|
||||
assert all("id" in draft and "subject" in draft for draft in result_json["emails"])
|
||||
assert isinstance(result, dict)
|
||||
assert "emails" in result
|
||||
assert len(result["emails"]) == 1
|
||||
assert all("id" in draft and "subject" in draft for draft in result["emails"])
|
||||
|
||||
# Test http error
|
||||
mock_service.users().drafts().list().execute.side_effect = HttpError(
|
||||
|
|
@ -301,12 +313,10 @@ async def test_search_emails_by_header(mock_parse_email, mock_build, mock_contex
|
|||
# Test happy path
|
||||
result = await list_emails_by_header(context=mock_context, sender="noreply@github.com", limit=2)
|
||||
|
||||
assert isinstance(result, str)
|
||||
result_json = json.loads(result)
|
||||
assert isinstance(result_json, dict)
|
||||
assert "emails" in result_json
|
||||
assert len(result_json["emails"]) == 2
|
||||
assert all("id" in email and "subject" in email for email in result_json["emails"])
|
||||
assert isinstance(result, dict)
|
||||
assert "emails" in result
|
||||
assert len(result["emails"]) == 2
|
||||
assert all("id" in email and "subject" in email for email in result["emails"])
|
||||
|
||||
# Test http error
|
||||
mock_service.users().messages().list().execute.side_effect = HttpError(
|
||||
|
|
@ -375,16 +385,13 @@ async def test_get_emails(mock_parse_email, mock_build, mock_context):
|
|||
# Test happy path
|
||||
result = await list_emails(context=mock_context, n_emails=1)
|
||||
|
||||
# Assert the result
|
||||
assert isinstance(result, str)
|
||||
result_json = json.loads(result)
|
||||
assert isinstance(result_json, dict)
|
||||
assert "emails" in result_json
|
||||
assert len(result_json["emails"]) == 1
|
||||
assert "id" in result_json["emails"][0]
|
||||
assert "subject" in result_json["emails"][0]
|
||||
assert "date" in result_json["emails"][0]
|
||||
assert "body" in result_json["emails"][0]
|
||||
assert isinstance(result, dict)
|
||||
assert "emails" in result
|
||||
assert len(result["emails"]) == 1
|
||||
assert "id" in result["emails"][0]
|
||||
assert "subject" in result["emails"][0]
|
||||
assert "date" in result["emails"][0]
|
||||
assert "body" in result["emails"][0]
|
||||
|
||||
# Test http error
|
||||
mock_service.users().messages().list().execute.side_effect = HttpError(
|
||||
|
|
@ -406,10 +413,11 @@ async def test_trash_email(mock_build, mock_context):
|
|||
email_id = "123456"
|
||||
result = await trash_email(context=mock_context, email_id=email_id)
|
||||
|
||||
assert (
|
||||
f"Email with ID {email_id} trashed successfully: https://mail.google.com/mail/u/0/#trash/{email_id}"
|
||||
== result
|
||||
)
|
||||
assert isinstance(result, dict)
|
||||
assert "id" in result
|
||||
assert "thread_id" in result
|
||||
assert "subject" in result
|
||||
assert "body" in result
|
||||
|
||||
# Test http error
|
||||
mock_service.users().messages().trash().execute.side_effect = HttpError(
|
||||
|
|
@ -419,3 +427,155 @@ async def test_trash_email(mock_build, mock_context):
|
|||
|
||||
with pytest.raises(ToolExecutionError):
|
||||
await trash_email(context=mock_context, email_id="nonexistent_email")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("arcade_google.tools.gmail.build")
|
||||
async def test_search_threads(mock_build, mock_context):
|
||||
mock_service = MagicMock()
|
||||
mock_build.return_value = mock_service
|
||||
|
||||
# Setup mock response data
|
||||
mock_threads_list_response = {
|
||||
"threads": [
|
||||
{
|
||||
"id": "thread1",
|
||||
"snippet": "Thread snippet 1",
|
||||
},
|
||||
{
|
||||
"id": "thread2",
|
||||
"snippet": "Thread snippet 2",
|
||||
},
|
||||
],
|
||||
"nextPageToken": "next_token_123",
|
||||
"resultSizeEstimate": 2,
|
||||
}
|
||||
|
||||
# Mock the Gmail API threads().list() method
|
||||
mock_service.users().threads().list().execute.return_value = mock_threads_list_response
|
||||
|
||||
# Test happy path
|
||||
result = await search_threads(
|
||||
context=mock_context,
|
||||
sender="test@example.com",
|
||||
max_results=2,
|
||||
)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "threads" in result
|
||||
assert len(result["threads"]) == 2
|
||||
assert result["threads"][0]["id"] == "thread1"
|
||||
assert "next_page_token" in result
|
||||
|
||||
# Test error handling
|
||||
mock_service.users().threads().list().execute.side_effect = HttpError(
|
||||
resp=MagicMock(status=400),
|
||||
content=b'{"error": {"message": "Invalid request"}}',
|
||||
)
|
||||
|
||||
with pytest.raises(ToolExecutionError):
|
||||
await search_threads(
|
||||
context=mock_context,
|
||||
sender="test@example.com",
|
||||
max_results=2,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("arcade_google.tools.gmail.build")
|
||||
async def test_list_threads(mock_build, mock_context):
|
||||
mock_service = MagicMock()
|
||||
mock_build.return_value = mock_service
|
||||
|
||||
# Setup mock response data
|
||||
mock_threads_list_response = {
|
||||
"threads": [
|
||||
{
|
||||
"id": "thread1",
|
||||
"snippet": "Thread snippet 1",
|
||||
},
|
||||
{
|
||||
"id": "thread2",
|
||||
"snippet": "Thread snippet 2",
|
||||
},
|
||||
],
|
||||
"nextPageToken": "next_token_123",
|
||||
"resultSizeEstimate": 2,
|
||||
}
|
||||
|
||||
# Mock the Gmail API threads().list() method
|
||||
mock_service.users().threads().list().execute.return_value = mock_threads_list_response
|
||||
|
||||
# Test happy path
|
||||
result = await list_threads(
|
||||
context=mock_context,
|
||||
max_results=2,
|
||||
)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "threads" in result
|
||||
assert len(result["threads"]) == 2
|
||||
assert result["threads"][0]["id"] == "thread1"
|
||||
assert "next_page_token" in result
|
||||
|
||||
# Test error handling
|
||||
mock_service.users().threads().list().execute.side_effect = HttpError(
|
||||
resp=MagicMock(status=400),
|
||||
content=b'{"error": {"message": "Invalid request"}}',
|
||||
)
|
||||
|
||||
with pytest.raises(ToolExecutionError):
|
||||
await list_threads(
|
||||
context=mock_context,
|
||||
max_results=2,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("arcade_google.tools.gmail.build")
|
||||
async def test_get_thread(mock_build, mock_context):
|
||||
mock_service = MagicMock()
|
||||
mock_build.return_value = mock_service
|
||||
|
||||
# Setup mock response data
|
||||
mock_thread_get_response = {
|
||||
"id": "thread1",
|
||||
"messages": [
|
||||
{
|
||||
"id": "message1",
|
||||
"snippet": "Message snippet 1",
|
||||
},
|
||||
{
|
||||
"id": "message2",
|
||||
"snippet": "Message snippet 2",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
# Mock the Gmail API threads().get() method
|
||||
mock_service.users().threads().get().execute.return_value = mock_thread_get_response
|
||||
|
||||
# Test happy path
|
||||
result = await get_thread(
|
||||
context=mock_context,
|
||||
thread_id="thread1",
|
||||
)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "id" in result
|
||||
assert result["id"] == "thread1"
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 2
|
||||
assert result["messages"][0]["id"] == "message1"
|
||||
|
||||
# Test error handling
|
||||
mock_service.users().threads().get().execute.side_effect = HttpError(
|
||||
resp=MagicMock(status=404),
|
||||
content=b'{"error": {"message": "Thread not found"}}',
|
||||
)
|
||||
|
||||
with pytest.raises(ToolExecutionError):
|
||||
await get_thread(
|
||||
context=mock_context,
|
||||
thread_id="invalid_thread",
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue