Fix ruff (#64)

On the last few PRs I have noticed two problems:
1. `ruff format` fails even though it seems OK on our local machines
(sometimes, not always)
2. Nate's and Sam's machines kept flip-flopping a specific piece of
formatting back and forth, indicating a subtle difference of config
hiding somewhere
3. This was reproducible by running `ruff format` in the terminal,
followed by `make check`. The former would edit files, and then `make
check` would edit them back!

This PR addresses both issues, and further standardizes our editor &
linter configs to be super stable.
Specifically:
1. The main fix for the above, the pre-commit hook was pinned to a super
old version of ruff.
This resulted in subtle differences in behavior between our machines,
and on CI.

2. Moved ruff settings from `pyproject.toml` to `.ruff.toml`
pyproject files in subdirectories (e.g. `toolkits/**`) were overriding
the main pyproject file and erasing the custom ruff config we set at the
root. This meant that our ruff config was applied to `arcade` but not to
any of the other packages.
By moving the config to `.ruff.toml` at the root, all projects will
inherit the same ruff linting & formatting config.

4. Un-ignored the `.vscode/` directory so that we can share
vscode/cursor workspace settings.
This is valuable for standardizing settings like the default formatter
(ruff) and default test framework (pytest).
However, it's important that going forward we _only_ commit things here
that should apply across all of our machines.

5. To avoid any conflict between prettier and ruff, prettier now
explicitly ignores *.py files

6. Finally, `ruff format` and `make check` agree. A number of files are
newly auto-formatted.
This commit is contained in:
Nate Barbettini 2024-09-25 09:47:30 -07:00 committed by GitHub
parent 33621a79e4
commit 894fa878f1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 211 additions and 256 deletions

View file

@ -7,7 +7,7 @@ insert_final_newline = true
end_of_line = lf end_of_line = lf
indent_style = space indent_style = space
indent_size = 4 indent_size = 4
max_line_length = 120 max_line_length = 100 # This is also set in .ruff.toml for ruff
[*.{json,jsonc,yml,yaml}] [*.{json,jsonc,yml,yaml}]
indent_style = space indent_style = space

4
.gitignore vendored
View file

@ -167,10 +167,6 @@ dmypy.json
# Cython debug symbols # Cython debug symbols
cython_debug/ cython_debug/
# Vscode config files
.vscode/
!.vscode/launch.json # Exception: allow launch.json
# PyCharm # PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore

View file

@ -10,7 +10,7 @@ repos:
- id: trailing-whitespace - id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.1.6" rev: v0.6.7
hooks: hooks:
- id: ruff - id: ruff
args: [--fix] args: [--fix]

2
.prettierignore Normal file
View file

@ -0,0 +1,2 @@
# Ignore Python files for Prettier
*.py

View file

@ -1,4 +1,8 @@
# See https://prettier.io/docs/en/configuration # See https://prettier.io/docs/en/configuration
# Note: This prettier config is only for the non-python files in this repo.
# Python files are formatted with ruff and ignored in .prettierignore
trailingComma = "es5" trailingComma = "es5"
tabWidth = 4 tabWidth = 4
semi = false semi = false

63
.ruff.toml Normal file
View file

@ -0,0 +1,63 @@
target-version = "py39"
line-length = 100
fix = true
[lint]
select = [
# flake8-2020
"YTT",
# flake8-bandit
"S",
# flake8-bugbear
"B",
# flake8-builtins
"A",
# flake8-comprehensions
"C4",
# flake8-debugger
"T10",
# flake8-simplify
"SIM",
# isort
"I",
# mccabe
"C90",
# pycodestyle
"E", "W",
# pyflakes
"F",
# pygrep-hooks
"PGH",
# pyupgrade
"UP",
# ruff
"RUF",
# tryceratops
"TRY",
]
# TODO work to remove these
ignore = [
# LineTooLong
"E501",
# DoNotAssignLambda
"E731",
# raise from (cli specific)
"B904", # Previously "TRY200"
# Depends function in arg string
"B008",
# raise from (cli specific)
"B904",
# long message exceptions
"TRY003",
# subprocess.Popen
"S603",
]
[lint.per-file-ignores]
"**/tests/*" = ["S101"]
"toolkits/*" = ["A002", "TRY300", "C901", "C416", "S113", "RUF013", "SIM103"] # TODO: Remove everything here
[format]
preview = true
skip-magic-trailing-comma = false

3
.vscode/extensions.json vendored Normal file
View file

@ -0,0 +1,3 @@
{
"recommendations": ["charliermarsh.ruff", "esbenp.prettier-vscode"]
}

21
.vscode/settings.json vendored Normal file
View file

@ -0,0 +1,21 @@
{
"files.exclude": {
"**/__pycache__": true,
"**/.mypy_cache": true,
"**/.pytest_cache": true,
"**/.ruff_cache": true
},
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"[python]": {
"editor.defaultFormatter": "charliermarsh.ruff"
},
"[jsonc]": {
"editor.defaultFormatter": "esbenp.prettier-vscode",
"editor.formatOnSave": true
},
"[json]": {
"editor.defaultFormatter": "esbenp.prettier-vscode",
"editor.formatOnSave": true
}
}

View file

@ -136,13 +136,11 @@ class FullyQualifiedName:
) )
def __hash__(self) -> int: def __hash__(self) -> int:
return hash( return hash((
( self.name.lower(),
self.name.lower(), self.toolkit_name.lower(),
self.toolkit_name.lower(), (self.toolkit_version or "").lower(),
(self.toolkit_version or "").lower(), ))
)
)
def equals_ignoring_version(self, other: "FullyQualifiedName") -> bool: def equals_ignoring_version(self, other: "FullyQualifiedName") -> bool:
"""Check if two fully-qualified tool names are equal, ignoring the version.""" """Check if two fully-qualified tool names are equal, ignoring the version."""

View file

@ -103,15 +103,13 @@ class EvaluationResult:
expected: The expected value for the critic. expected: The expected value for the critic.
actual: The actual value for the critic. actual: The actual value for the critic.
""" """
self.results.append( self.results.append({
{ "field": field,
"field": field, **result,
**result, "weight": weight,
"weight": weight, "expected": expected,
"expected": expected, "actual": actual,
"actual": actual, })
}
)
def score_tool_selection(self, expected: str, actual: str, weight: float) -> float: def score_tool_selection(self, expected: str, actual: str, weight: float) -> float:
""" """
@ -658,12 +656,10 @@ def get_tool_args(chat_completion: Any) -> list[tuple[str, dict[str, Any]]]:
message = chat_completion.choices[0].message message = chat_completion.choices[0].message
if message.tool_calls: if message.tool_calls:
for tool_call in message.tool_calls: for tool_call in message.tool_calls:
tool_args_list.append( tool_args_list.append((
( tool_call.function.name,
tool_call.function.name, json.loads(tool_call.function.arguments),
json.loads(tool_call.function.arguments), ))
)
)
return tool_args_list return tool_args_list

View file

@ -60,69 +60,9 @@ ignore_missing_imports = "True"
[tool.pytest.ini_options] [tool.pytest.ini_options]
testpaths = ["tests"] testpaths = ["tests"]
[tool.ruff]
target-version = "py39"
line-length = 100
fix = true
select = [
# flake8-2020
"YTT",
# flake8-bandit
"S",
# flake8-bugbear
"B",
# flake8-builtins
"A",
# flake8-comprehensions
"C4",
# flake8-debugger
"T10",
# flake8-simplify
"SIM",
# isort
"I",
# mccabe
"C90",
# pycodestyle
"E", "W",
# pyflakes
"F",
# pygrep-hooks
"PGH",
# pyupgrade
"UP",
# ruff
"RUF",
# tryceratops
"TRY",
]
ignore = [ # TODO work to remove these
# LineTooLong
"E501",
# DoNotAssignLambda
"E731",
# raise from (cli specific)
"TRY200",
# Depends function in arg string
"B008",
# raise from (cli specific)
"B904",
# long message exceptions
"TRY003",
# subprocess.Popen
"S603",
]
[tool.ruff.format]
preview = true
[tool.coverage.report] [tool.coverage.report]
skip_empty = true skip_empty = true
[tool.coverage.run] [tool.coverage.run]
branch = true branch = true
source = ["arcade"] source = ["arcade"]
[tool.ruff.per-file-ignores]
"tests/*" = ["S101"]

View file

@ -30,12 +30,10 @@ test_cases = [
# Generate tool functions dynamically # Generate tool functions dynamically
def generate_tool_function(input_types: list[type], output_type: type | None): def generate_tool_function(input_types: list[type], output_type: type | None):
input_annotation = ", ".join( input_annotation = ", ".join([
[ f"param{i}: Annotated[{input_type.__name__}, 'Param {i + 1}']"
f"param{i}: Annotated[{input_type.__name__}, 'Param {i + 1}']" for i, input_type in enumerate(input_types)
for i, input_type in enumerate(input_types) ])
]
)
output_annotation = f" -> {output_type.__name__}" if output_type else "" output_annotation = f" -> {output_type.__name__}" if output_type else ""
func_code = f""" func_code = f"""

View file

@ -1,5 +1,5 @@
import os import os
from arcade.core.toolkit import Toolkit
import arcade_math import arcade_math
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
@ -7,6 +7,7 @@ from pydantic import BaseModel
from arcade.actor.fastapi.actor import FastAPIActor from arcade.actor.fastapi.actor import FastAPIActor
from arcade.client import AsyncArcade from arcade.client import AsyncArcade
from arcade.core.config import config from arcade.core.config import config
from arcade.core.toolkit import Toolkit
if not config.api or not config.api.key: if not config.api or not config.api.key:
raise ValueError("Arcade API key not set. Please run `arcade login`.") raise ValueError("Arcade API key not set. Please run `arcade login`.")
@ -47,6 +48,8 @@ async def postChat(request: ChatRequest, tool_choice: str = "execute"):
tool_choice=tool_choice, tool_choice=tool_choice,
user=config.user.email if config.user else None, user=config.user.email if config.user else None,
) )
return raw_response.choices
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
else:
return raw_response.choices

View file

@ -26,9 +26,7 @@ auth_response = client.auth.authorize(
# If authorization is not completed, prompt the user and poll for status # If authorization is not completed, prompt the user and poll for status
if auth_response.status != "completed": if auth_response.status != "completed":
print( print("Please complete the authorization challenge in your browser before continuing:")
"Please complete the authorization challenge in your browser before continuing:"
)
print(auth_response.auth_url) print(auth_response.auth_url)
input("Press Enter to continue...") input("Press Enter to continue...")

View file

@ -3,23 +3,23 @@ import json
from email.message import EmailMessage from email.message import EmailMessage
from email.mime.text import MIMEText from email.mime.text import MIMEText
from typing import Annotated, Optional from typing import Annotated, Optional
from arcade.core.errors import ToolExecutionError, ToolInputError
from googleapiclient.errors import HttpError
from arcade_google.tools.utils import (
DateRange,
parse_email,
get_draft_url,
get_sent_email_url,
get_email_in_trash_url,
parse_draft_email,
)
from google.oauth2.credentials import Credentials from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from arcade.core.errors import ToolExecutionError, ToolInputError
from arcade.core.schema import ToolContext from arcade.core.schema import ToolContext
from arcade.sdk import tool from arcade.sdk import tool
from arcade.sdk.auth import Google from arcade.sdk.auth import Google
from arcade_google.tools.utils import (
DateRange,
get_draft_url,
get_email_in_trash_url,
get_sent_email_url,
parse_draft_email,
parse_email,
)
# Email sending tools # Email sending tools
@ -42,9 +42,7 @@ async def send_email(
try: try:
# Set up the Gmail API client # Set up the Gmail API client
service = build( service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
"gmail", "v1", credentials=Credentials(context.authorization.token)
)
message = EmailMessage() message = EmailMessage()
message.set_content(body) message.set_content(body)
@ -62,9 +60,7 @@ async def send_email(
email = {"raw": encoded_message} email = {"raw": encoded_message}
# Send the email # Send the email
sent_message = ( sent_message = service.users().messages().send(userId="me", body=email).execute()
service.users().messages().send(userId="me", body=email).execute()
)
return f"Email with ID {sent_message['id']} sent: {get_sent_email_url(sent_message['id'])}" return f"Email with ID {sent_message['id']} sent: {get_sent_email_url(sent_message['id'])}"
except HttpError as e: except HttpError as e:
raise ToolExecutionError( raise ToolExecutionError(
@ -91,14 +87,10 @@ async def send_draft_email(
try: try:
# Set up the Gmail API client # Set up the Gmail API client
service = build( service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
"gmail", "v1", credentials=Credentials(context.authorization.token)
)
# Send the draft email # Send the draft email
sent_message = ( sent_message = service.users().drafts().send(userId="me", body={"id": id}).execute()
service.users().drafts().send(userId="me", body={"id": id}).execute()
)
# Construct the URL to the sent email # Construct the URL to the sent email
return f"Draft email with ID {sent_message['id']} sent: {get_sent_email_url(sent_message['id'])}" return f"Draft email with ID {sent_message['id']} sent: {get_sent_email_url(sent_message['id'])}"
@ -133,9 +125,7 @@ async def write_draft_email(
try: try:
# Set up the Gmail API client # Set up the Gmail API client
service = build( service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
"gmail", "v1", credentials=Credentials(context.authorization.token)
)
message = MIMEText(body) message = MIMEText(body)
message["to"] = recipient message["to"] = recipient
@ -151,9 +141,7 @@ async def write_draft_email(
# Create the draft # Create the draft
draft = {"message": {"raw": raw_message}} draft = {"message": {"raw": raw_message}}
draft_message = ( draft_message = service.users().drafts().create(userId="me", body=draft).execute()
service.users().drafts().create(userId="me", body=draft).execute()
)
return f"Draft email with ID {draft_message['id']} created: {get_draft_url(draft_message['id'])}" return f"Draft email with ID {draft_message['id']} created: {get_draft_url(draft_message['id'])}"
except HttpError as e: except HttpError as e:
raise ToolExecutionError( raise ToolExecutionError(
@ -187,9 +175,7 @@ async def update_draft_email(
try: try:
# Set up the Gmail API client # Set up the Gmail API client
service = build( service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
"gmail", "v1", credentials=Credentials(context.authorization.token)
)
message = MIMEText(body) message = MIMEText(body)
message["to"] = recipient message["to"] = recipient
@ -236,9 +222,7 @@ async def delete_draft_email(
try: try:
# Set up the Gmail API client # Set up the Gmail API client
service = build( service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
"gmail", "v1", credentials=Credentials(context.authorization.token)
)
# Delete the draft # Delete the draft
service.users().drafts().delete(userId="me", id=id).execute() service.users().drafts().delete(userId="me", id=id).execute()
@ -270,9 +254,7 @@ async def trash_email(
try: try:
# Set up the Gmail API client # Set up the Gmail API client
service = build( service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
"gmail", "v1", credentials=Credentials(context.authorization.token)
)
# Trash the email # Trash the email
service.users().messages().trash(userId="me", id=id).execute() service.users().messages().trash(userId="me", id=id).execute()
@ -298,33 +280,25 @@ async def trash_email(
async def list_draft_emails( async def list_draft_emails(
context: ToolContext, context: ToolContext,
n_drafts: Annotated[int, "Number of draft emails to read"] = 5, n_drafts: Annotated[int, "Number of draft emails to read"] = 5,
) -> Annotated[ ) -> Annotated[str, "A JSON string containing a list of draft email details and their IDs"]:
str, "A JSON string containing a list of draft email details and their IDs"
]:
""" """
Lists draft emails in the user's draft mailbox using the Gmail API. Lists draft emails in the user's draft mailbox using the Gmail API.
""" """
try: try:
# Set up the Gmail API client # Set up the Gmail API client
service = build( service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
"gmail", "v1", credentials=Credentials(context.authorization.token)
)
listed_drafts = service.users().drafts().list(userId="me").execute() listed_drafts = service.users().drafts().list(userId="me").execute()
if not listed_drafts: if not listed_drafts:
return {"emails": []} return {"emails": []}
draft_ids = [draft["id"] for draft in listed_drafts.get("drafts", [])][ draft_ids = [draft["id"] for draft in listed_drafts.get("drafts", [])][:n_drafts]
:n_drafts
]
emails = [] emails = []
for draft_id in draft_ids: for draft_id in draft_ids:
try: try:
draft_data = ( draft_data = service.users().drafts().get(userId="me", id=draft_id).execute()
service.users().drafts().get(userId="me", id=draft_id).execute()
)
draft_details = parse_draft_email(draft_data) draft_details = parse_draft_email(draft_data)
if draft_details: if draft_details:
emails.append(draft_details) emails.append(draft_details)
@ -352,15 +326,9 @@ async def list_draft_emails(
) )
async def list_emails_by_header( async def list_emails_by_header(
context: ToolContext, context: ToolContext,
sender: Annotated[ sender: Annotated[Optional[str], "The name or email address of the sender of the email"] = None,
Optional[str], "The name or email address of the sender of the email" recipient: Annotated[Optional[str], "The name or email address of the recipient"] = None,
] = None, subject: Annotated[Optional[str], "Words to find in the subject of the email"] = None,
recipient: Annotated[
Optional[str], "The name or email address of the recipient"
] = None,
subject: Annotated[
Optional[str], "Words to find in the subject of the email"
] = None,
body: Annotated[Optional[str], "Words to find in the body of the email"] = None, body: Annotated[Optional[str], "Words to find in the body of the email"] = None,
date_range: Annotated[Optional[DateRange], "The date range of the email"] = None, date_range: Annotated[Optional[DateRange], "The date range of the email"] = None,
limit: Annotated[Optional[int], "The maximum number of emails to return"] = 25, limit: Annotated[Optional[int], "The maximum number of emails to return"] = 25,
@ -394,9 +362,7 @@ async def list_emails_by_header(
try: try:
# Set up the Gmail API client # Set up the Gmail API client
service = build( service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
"gmail", "v1", credentials=Credentials(context.authorization.token)
)
# Perform the search # Perform the search
response = ( response = (
@ -413,9 +379,7 @@ async def list_emails_by_header(
emails = [] emails = []
for msg in messages: for msg in messages:
try: try:
email_data = ( email_data = service.users().messages().get(userId="me", id=msg["id"]).execute()
service.users().messages().get(userId="me", id=msg["id"]).execute()
)
email_details = parse_email(email_data) email_details = parse_email(email_data)
if email_details: if email_details:
emails.append(email_details) emails.append(email_details)
@ -449,13 +413,9 @@ async def list_emails(
""" """
try: try:
# Set up the Gmail API client # Set up the Gmail API client
service = build( service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
"gmail", "v1", credentials=Credentials(context.authorization.token)
)
messages = ( messages = service.users().messages().list(userId="me").execute().get("messages", [])
service.users().messages().list(userId="me").execute().get("messages", [])
)
if not messages: if not messages:
return {"emails": []} return {"emails": []}
@ -463,9 +423,7 @@ async def list_emails(
emails = [] emails = []
for msg in messages[:n_emails]: for msg in messages[:n_emails]:
try: try:
email_data = ( email_data = service.users().messages().get(userId="me", id=msg["id"]).execute()
service.users().messages().get(userId="me", id=msg["id"]).execute()
)
email_details = parse_email(email_data) email_details = parse_email(email_data)
if email_details: if email_details:
emails.append(email_details) emails.append(email_details)

View file

@ -1,8 +1,8 @@
from base64 import urlsafe_b64decode
import datetime import datetime
from enum import Enum
import re import re
from typing import Any, Dict, Optional from base64 import urlsafe_b64decode
from enum import Enum
from typing import Any, Optional, dict
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
@ -30,20 +30,18 @@ class DateRange(Enum):
elif self == DateRange.THIS_MONTH: elif self == DateRange.THIS_MONTH:
comparison_date = today.replace(day=1) comparison_date = today.replace(day=1)
elif self == DateRange.LAST_MONTH: elif self == DateRange.LAST_MONTH:
comparison_date = ( comparison_date = (today.replace(day=1) - datetime.timedelta(days=1)).replace(day=1)
today.replace(day=1) - datetime.timedelta(days=1)
).replace(day=1)
elif self == DateRange.THIS_YEAR: elif self == DateRange.THIS_YEAR:
comparison_date = today.replace(month=1, day=1) comparison_date = today.replace(month=1, day=1)
elif self == DateRange.LAST_MONTH: elif self == DateRange.LAST_MONTH:
comparison_date = ( comparison_date = (today.replace(month=1, day=1) - datetime.timedelta(days=1)).replace(
today.replace(month=1, day=1) - datetime.timedelta(days=1) month=1, day=1
).replace(month=1, day=1) )
return result + comparison_date.strftime("%Y/%m/%d") return result + comparison_date.strftime("%Y/%m/%d")
def parse_email(email_data: Dict[str, Any]) -> Optional[Dict[str, str]]: def parse_email(email_data: dict[str, Any]) -> Optional[dict[str, str]]:
""" """
Parse email data and extract relevant information. Parse email data and extract relevant information.
@ -71,7 +69,7 @@ def parse_email(email_data: Dict[str, Any]) -> Optional[Dict[str, str]]:
return None return None
def parse_draft_email(draft_email_data: Dict[str, Any]) -> Optional[Dict[str, str]]: def parse_draft_email(draft_email_data: dict[str, Any]) -> Optional[dict[str, str]]:
""" """
Parse draft email data and extract relevant information. Parse draft email data and extract relevant information.
@ -112,7 +110,7 @@ def get_email_in_trash_url(email_id):
return f"https://mail.google.com/mail/u/0/#trash/{email_id}" return f"https://mail.google.com/mail/u/0/#trash/{email_id}"
def _get_email_body(payload: Dict[str, Any]) -> Optional[str]: def _get_email_body(payload: dict[str, Any]) -> Optional[str]:
""" """
Extract email body from payload. Extract email body from payload.

View file

@ -1,27 +1,28 @@
import json import json
from arcade.core.errors import ToolExecutionError from unittest.mock import MagicMock, patch
from arcade_google.tools.utils import parse_draft_email, parse_email
import pytest import pytest
from unittest.mock import patch, MagicMock
from arcade_google.tools.gmail import ( from arcade_google.tools.gmail import (
send_email,
write_draft_email,
update_draft_email,
send_draft_email,
delete_draft_email, delete_draft_email,
list_draft_emails, list_draft_emails,
list_emails_by_header,
list_emails, list_emails,
list_emails_by_header,
send_draft_email,
send_email,
trash_email, trash_email,
update_draft_email,
write_draft_email,
) )
from arcade_google.tools.utils import parse_draft_email, parse_email
from arcade.core.schema import ToolContext, ToolAuthorizationContext
from googleapiclient.errors import HttpError from googleapiclient.errors import HttpError
from arcade.core.errors import ToolExecutionError
from arcade.core.schema import ToolAuthorizationContext, ToolContext
@pytest.fixture @pytest.fixture
def mock_context(): def mock_context():
mock_auth = ToolAuthorizationContext(token="fake-token") mock_auth = ToolAuthorizationContext(token="fake-token") # noqa: S106
return ToolContext(authorization=mock_auth) return ToolContext(authorization=mock_auth)
@ -214,9 +215,7 @@ async def test_get_draft_emails(mock_parse_draft_email, mock_build, mock_context
mock_build.return_value = mock_service mock_build.return_value = mock_service
# Mock the response from the Gmail list drafts API # Mock the response from the Gmail list drafts API
mock_service.users().drafts().list().execute.return_value = ( mock_service.users().drafts().list().execute.return_value = mock_drafts_list_response
mock_drafts_list_response
)
# Mock the response from the Gmail get drafts API # Mock the response from the Gmail get drafts API
mock_service.users().drafts().get().execute.return_value = mock_drafts_get_response mock_service.users().drafts().get().execute.return_value = mock_drafts_get_response
@ -291,22 +290,16 @@ async def test_search_emails_by_header(mock_parse_email, mock_build, mock_contex
mock_build.return_value = mock_service mock_build.return_value = mock_service
# Mock the response from the Gmail list messages API # Mock the response from the Gmail list messages API
mock_service.users().messages().list().execute.return_value = ( mock_service.users().messages().list().execute.return_value = mock_messages_list_response
mock_messages_list_response
)
# Mock the response from the Gmail get messages API # Mock the response from the Gmail get messages API
mock_service.users().messages().get().execute.return_value = ( mock_service.users().messages().get().execute.return_value = mock_messages_get_response
mock_messages_get_response
)
# Mock the parse_email function since parse_email doesn't accept object of type MagicMock # Mock the parse_email function since parse_email doesn't accept object of type MagicMock
mock_parse_email.return_value = parse_email(mock_messages_get_response) mock_parse_email.return_value = parse_email(mock_messages_get_response)
# Test happy path # Test happy path
result = await list_emails_by_header( result = await list_emails_by_header(context=mock_context, sender="noreply@github.com", limit=2)
context=mock_context, sender="noreply@github.com", limit=2
)
assert isinstance(result, str) assert isinstance(result, str)
result_json = json.loads(result) result_json = json.loads(result)
@ -322,9 +315,7 @@ async def test_search_emails_by_header(mock_parse_email, mock_build, mock_contex
) )
with pytest.raises(ToolExecutionError): with pytest.raises(ToolExecutionError):
await list_emails_by_header( await list_emails_by_header(context=mock_context, sender="noreply@github.com", limit=2)
context=mock_context, sender="noreply@github.com", limit=2
)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -373,14 +364,10 @@ async def test_get_emails(mock_parse_email, mock_build, mock_context):
mock_build.return_value = mock_service mock_build.return_value = mock_service
# Mock the response from the Gmail list messages API # Mock the response from the Gmail list messages API
mock_service.users().messages().list().execute.return_value = ( mock_service.users().messages().list().execute.return_value = mock_messages_list_response
mock_messages_list_response
)
# Mock the Gmail get messages API # Mock the Gmail get messages API
mock_service.users().messages().get().execute.return_value = ( mock_service.users().messages().get().execute.return_value = mock_messages_get_response
mock_messages_get_response
)
# Mock the parse_email function since parse_email doesn't accept object of type MagicMock # Mock the parse_email function since parse_email doesn't accept object of type MagicMock
mock_parse_email.return_value = parse_email(mock_messages_get_response) mock_parse_email.return_value = parse_email(mock_messages_get_response)

View file

@ -43,9 +43,7 @@ def math_eval_suite():
], ],
rubric=rubric, rubric=rubric,
critics=[ critics=[
BinaryCritic( BinaryCritic(critic_field="a", weight=0.5), # TODO: weight should be optional
critic_field="a", weight=0.5
), # TODO: weight should be optional
BinaryCritic(critic_field="b", weight=0.5), BinaryCritic(critic_field="b", weight=0.5),
], ],
) )

View file

@ -1,10 +1,10 @@
import pytest import pytest
from arcade_math.tools.arithmetic import ( from arcade_math.tools.arithmetic import (
add, add,
subtract,
multiply,
divide, divide,
multiply,
sqrt, sqrt,
subtract,
sum_list, sum_list,
sum_range, sum_range,
) )

View file

@ -1,7 +1,9 @@
import json import json
import os import os
import serpapi
from typing import Annotated, Any, Optional from typing import Annotated, Any, Optional
import serpapi
from arcade.sdk import tool from arcade.sdk import tool

View file

@ -165,9 +165,7 @@ def slack_eval_suite() -> EvalSuite:
], ],
critics=[ critics=[
SimilarityCritic(critic_field="user_name", weight=0.7), SimilarityCritic(critic_field="user_name", weight=0.7),
SimilarityCritic( SimilarityCritic(critic_field="message", weight=0.3, similarity_threshold=0.6),
critic_field="message", weight=0.3, similarity_threshold=0.6
),
], ],
) )

View file

@ -1,11 +1,11 @@
from typing import Annotated from typing import Annotated
from arcade.core.errors import ToolExecutionError
from arcade.sdk.auth import X
import requests import requests
from arcade.sdk import tool
from arcade.core.errors import ToolExecutionError
from arcade.core.schema import ToolContext from arcade.core.schema import ToolContext
from arcade.sdk import tool
from arcade.sdk.auth import X
from arcade_x.tools.utils import get_tweet_url, parse_search_recent_tweets_response from arcade_x.tools.utils import get_tweet_url, parse_search_recent_tweets_response
TWEETS_URL = "https://api.x.com/2/tweets" TWEETS_URL = "https://api.x.com/2/tweets"
@ -33,9 +33,7 @@ def post_tweet(
) )
tweet_id = response.json()["data"]["id"] tweet_id = response.json()["data"]["id"]
return ( return f"Tweet with id {tweet_id} posted successfully. URL: {get_tweet_url(tweet_id)}"
f"Tweet with id {tweet_id} posted successfully. URL: {get_tweet_url(tweet_id)}"
)
@tool(requires_auth=X(scopes=["tweet.read", "tweet.write", "users.read"])) @tool(requires_auth=X(scopes=["tweet.read", "tweet.write", "users.read"]))
@ -74,11 +72,11 @@ def search_recent_tweets_by_username(
} }
params = { params = {
"query": f"from:{username}", "query": f"from:{username}",
"max_results": max( "max_results": max(max_results, 10), # X API does not allow 'max_results' less than 10
max_results, 10
), # X API does not allow 'max_results' less than 10
} }
url = "https://api.x.com/2/tweets/search/recent?expansions=author_id&user.fields=id,name,username" url = (
"https://api.x.com/2/tweets/search/recent?expansions=author_id&user.fields=id,name,username"
)
response = requests.get(url, headers=headers, params=params) response = requests.get(url, headers=headers, params=params)
@ -95,12 +93,8 @@ def search_recent_tweets_by_username(
@tool(requires_auth=X(scopes=["tweet.read", "users.read"])) @tool(requires_auth=X(scopes=["tweet.read", "users.read"]))
def search_recent_tweets_by_keywords( def search_recent_tweets_by_keywords(
context: ToolContext, context: ToolContext,
keywords: Annotated[ keywords: Annotated[list[str], "List of keywords that must be present in the tweet"] = None,
list[str], "List of keywords that must be present in the tweet" phrases: Annotated[list[str], "List of phrases that must be present in the tweet"] = None,
] = None,
phrases: Annotated[
list[str], "List of phrases that must be present in the tweet"
] = None,
max_results: Annotated[ max_results: Annotated[
int, "The maximum number of results to return. Cannot be less than 10" int, "The maximum number of results to return. Cannot be less than 10"
] = 10, ] = 10,
@ -122,11 +116,11 @@ def search_recent_tweets_by_keywords(
query = " ".join([f'"{phrase}"' for phrase in phrases]) + " ".join(keywords) query = " ".join([f'"{phrase}"' for phrase in phrases]) + " ".join(keywords)
params = { params = {
"query": query, "query": query,
"max_results": max( "max_results": max(max_results, 10), # X API does not allow 'max_results' less than 10
max_results, 10
), # X API does not allow 'max_results' less than 10
} }
url = "https://api.x.com/2/tweets/search/recent?expansions=author_id&user.fields=id,name,username" url = (
"https://api.x.com/2/tweets/search/recent?expansions=author_id&user.fields=id,name,username"
)
response = requests.get(url, headers=headers, params=params) response = requests.get(url, headers=headers, params=params)

View file

@ -1,10 +1,11 @@
from typing import Annotated from typing import Annotated
from arcade.core.errors import ToolExecutionError
from arcade.sdk.auth import X
import requests
from arcade.sdk import tool
import requests
from arcade.core.errors import ToolExecutionError
from arcade.core.schema import ToolContext from arcade.core.schema import ToolContext
from arcade.sdk import tool
from arcade.sdk.auth import X
# Users Lookup Tools. See developer docs for additional available query parameters: https://developer.x.com/en/docs/x-api/users/lookup/api-reference # Users Lookup Tools. See developer docs for additional available query parameters: https://developer.x.com/en/docs/x-api/users/lookup/api-reference

View file

@ -1,4 +1,5 @@
import json import json
from requests import Response from requests import Response
@ -38,9 +39,7 @@ def parse_search_recent_tweets_response(response: Response) -> str:
for tweet in tweets_data["data"]: for tweet in tweets_data["data"]:
tweet["tweet_url"] = get_tweet_url(tweet["id"]) tweet["tweet_url"] = get_tweet_url(tweet["id"])
for tweet_data, user_data in zip( for tweet_data, user_data in zip(tweets_data["data"], tweets_data["includes"]["users"]):
tweets_data["data"], tweets_data["includes"]["users"]
):
tweet_data["author_username"] = user_data["username"] tweet_data["author_username"] = user_data["username"]
tweet_data["author_name"] = user_data["name"] tweet_data["author_name"] = user_data["name"]

View file

@ -43,9 +43,7 @@ def x_eval_suite() -> EvalSuite:
expected_tool_calls=[ expected_tool_calls=[
( (
post_tweet, post_tweet,
{ {"tweet_text": "Hello World! Exciting stuff is happening over at Arcade AI!"},
"tweet_text": "Hello World! Exciting stuff is happening over at Arcade AI!"
},
) )
], ],
critics=[ critics=[