Fix ruff (#64)
On the last few PRs I have noticed two problems: 1. `ruff format` fails even though it seems OK on our local machines (sometimes, not always) 2. Nate's and Sam's machines kept flip-flopping a specific piece of formatting back and forth, indicating a subtle difference of config hiding somewhere 3. This was reproducible by running `ruff format` in the terminal, followed by `make check`. The former would edit files, and then `make check` would edit them back! This PR addresses both issues, and further standardizes our editor & linter configs to be super stable. Specifically: 1. The main fix for the above, the pre-commit hook was pinned to a super old version of ruff. This resulted in subtle differences in behavior between our machines, and on CI. 2. Moved ruff settings from `pyproject.toml` to `.ruff.toml` pyproject files in subdirectories (e.g. `toolkits/**`) were overriding the main pyproject file and erasing the custom ruff config we set at the root. This meant that our ruff config was applied to `arcade` but not to any of the other packages. By moving the config to `.ruff.toml` at the root, all projects will inherit the same ruff linting & formatting config. 4. Un-ignored the `.vscode/` directory so that we can share vscode/cursor workspace settings. This is valuable for standardizing settings like the default formatter (ruff) and default test framework (pytest). However, it's important that going forward we _only_ commit things here that should apply across all of our machines. 5. To avoid any conflict between prettier and ruff, prettier now explicitly ignores *.py files 6. Finally, `ruff format` and `make check` agree. A number of files are newly auto-formatted.
This commit is contained in:
parent
33621a79e4
commit
894fa878f1
25 changed files with 211 additions and 256 deletions
|
|
@ -7,7 +7,7 @@ insert_final_newline = true
|
||||||
end_of_line = lf
|
end_of_line = lf
|
||||||
indent_style = space
|
indent_style = space
|
||||||
indent_size = 4
|
indent_size = 4
|
||||||
max_line_length = 120
|
max_line_length = 100 # This is also set in .ruff.toml for ruff
|
||||||
|
|
||||||
[*.{json,jsonc,yml,yaml}]
|
[*.{json,jsonc,yml,yaml}]
|
||||||
indent_style = space
|
indent_style = space
|
||||||
|
|
|
||||||
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -167,10 +167,6 @@ dmypy.json
|
||||||
# Cython debug symbols
|
# Cython debug symbols
|
||||||
cython_debug/
|
cython_debug/
|
||||||
|
|
||||||
# Vscode config files
|
|
||||||
.vscode/
|
|
||||||
!.vscode/launch.json # Exception: allow launch.json
|
|
||||||
|
|
||||||
# PyCharm
|
# PyCharm
|
||||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ repos:
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
|
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: "v0.1.6"
|
rev: v0.6.7
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
args: [--fix]
|
args: [--fix]
|
||||||
|
|
|
||||||
2
.prettierignore
Normal file
2
.prettierignore
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
# Ignore Python files for Prettier
|
||||||
|
*.py
|
||||||
|
|
@ -1,4 +1,8 @@
|
||||||
# See https://prettier.io/docs/en/configuration
|
# See https://prettier.io/docs/en/configuration
|
||||||
|
|
||||||
|
# Note: This prettier config is only for the non-python files in this repo.
|
||||||
|
# Python files are formatted with ruff and ignored in .prettierignore
|
||||||
|
|
||||||
trailingComma = "es5"
|
trailingComma = "es5"
|
||||||
tabWidth = 4
|
tabWidth = 4
|
||||||
semi = false
|
semi = false
|
||||||
|
|
|
||||||
63
.ruff.toml
Normal file
63
.ruff.toml
Normal file
|
|
@ -0,0 +1,63 @@
|
||||||
|
target-version = "py39"
|
||||||
|
line-length = 100
|
||||||
|
fix = true
|
||||||
|
|
||||||
|
[lint]
|
||||||
|
select = [
|
||||||
|
# flake8-2020
|
||||||
|
"YTT",
|
||||||
|
# flake8-bandit
|
||||||
|
"S",
|
||||||
|
# flake8-bugbear
|
||||||
|
"B",
|
||||||
|
# flake8-builtins
|
||||||
|
"A",
|
||||||
|
# flake8-comprehensions
|
||||||
|
"C4",
|
||||||
|
# flake8-debugger
|
||||||
|
"T10",
|
||||||
|
# flake8-simplify
|
||||||
|
"SIM",
|
||||||
|
# isort
|
||||||
|
"I",
|
||||||
|
# mccabe
|
||||||
|
"C90",
|
||||||
|
# pycodestyle
|
||||||
|
"E", "W",
|
||||||
|
# pyflakes
|
||||||
|
"F",
|
||||||
|
# pygrep-hooks
|
||||||
|
"PGH",
|
||||||
|
# pyupgrade
|
||||||
|
"UP",
|
||||||
|
# ruff
|
||||||
|
"RUF",
|
||||||
|
# tryceratops
|
||||||
|
"TRY",
|
||||||
|
]
|
||||||
|
|
||||||
|
# TODO work to remove these
|
||||||
|
ignore = [
|
||||||
|
# LineTooLong
|
||||||
|
"E501",
|
||||||
|
# DoNotAssignLambda
|
||||||
|
"E731",
|
||||||
|
# raise from (cli specific)
|
||||||
|
"B904", # Previously "TRY200"
|
||||||
|
# Depends function in arg string
|
||||||
|
"B008",
|
||||||
|
# raise from (cli specific)
|
||||||
|
"B904",
|
||||||
|
# long message exceptions
|
||||||
|
"TRY003",
|
||||||
|
# subprocess.Popen
|
||||||
|
"S603",
|
||||||
|
]
|
||||||
|
|
||||||
|
[lint.per-file-ignores]
|
||||||
|
"**/tests/*" = ["S101"]
|
||||||
|
"toolkits/*" = ["A002", "TRY300", "C901", "C416", "S113", "RUF013", "SIM103"] # TODO: Remove everything here
|
||||||
|
|
||||||
|
[format]
|
||||||
|
preview = true
|
||||||
|
skip-magic-trailing-comma = false
|
||||||
3
.vscode/extensions.json
vendored
Normal file
3
.vscode/extensions.json
vendored
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
{
|
||||||
|
"recommendations": ["charliermarsh.ruff", "esbenp.prettier-vscode"]
|
||||||
|
}
|
||||||
21
.vscode/settings.json
vendored
Normal file
21
.vscode/settings.json
vendored
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
{
|
||||||
|
"files.exclude": {
|
||||||
|
"**/__pycache__": true,
|
||||||
|
"**/.mypy_cache": true,
|
||||||
|
"**/.pytest_cache": true,
|
||||||
|
"**/.ruff_cache": true
|
||||||
|
},
|
||||||
|
"python.testing.unittestEnabled": false,
|
||||||
|
"python.testing.pytestEnabled": true,
|
||||||
|
"[python]": {
|
||||||
|
"editor.defaultFormatter": "charliermarsh.ruff"
|
||||||
|
},
|
||||||
|
"[jsonc]": {
|
||||||
|
"editor.defaultFormatter": "esbenp.prettier-vscode",
|
||||||
|
"editor.formatOnSave": true
|
||||||
|
},
|
||||||
|
"[json]": {
|
||||||
|
"editor.defaultFormatter": "esbenp.prettier-vscode",
|
||||||
|
"editor.formatOnSave": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -136,13 +136,11 @@ class FullyQualifiedName:
|
||||||
)
|
)
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
return hash(
|
return hash((
|
||||||
(
|
self.name.lower(),
|
||||||
self.name.lower(),
|
self.toolkit_name.lower(),
|
||||||
self.toolkit_name.lower(),
|
(self.toolkit_version or "").lower(),
|
||||||
(self.toolkit_version or "").lower(),
|
))
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def equals_ignoring_version(self, other: "FullyQualifiedName") -> bool:
|
def equals_ignoring_version(self, other: "FullyQualifiedName") -> bool:
|
||||||
"""Check if two fully-qualified tool names are equal, ignoring the version."""
|
"""Check if two fully-qualified tool names are equal, ignoring the version."""
|
||||||
|
|
|
||||||
|
|
@ -103,15 +103,13 @@ class EvaluationResult:
|
||||||
expected: The expected value for the critic.
|
expected: The expected value for the critic.
|
||||||
actual: The actual value for the critic.
|
actual: The actual value for the critic.
|
||||||
"""
|
"""
|
||||||
self.results.append(
|
self.results.append({
|
||||||
{
|
"field": field,
|
||||||
"field": field,
|
**result,
|
||||||
**result,
|
"weight": weight,
|
||||||
"weight": weight,
|
"expected": expected,
|
||||||
"expected": expected,
|
"actual": actual,
|
||||||
"actual": actual,
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def score_tool_selection(self, expected: str, actual: str, weight: float) -> float:
|
def score_tool_selection(self, expected: str, actual: str, weight: float) -> float:
|
||||||
"""
|
"""
|
||||||
|
|
@ -658,12 +656,10 @@ def get_tool_args(chat_completion: Any) -> list[tuple[str, dict[str, Any]]]:
|
||||||
message = chat_completion.choices[0].message
|
message = chat_completion.choices[0].message
|
||||||
if message.tool_calls:
|
if message.tool_calls:
|
||||||
for tool_call in message.tool_calls:
|
for tool_call in message.tool_calls:
|
||||||
tool_args_list.append(
|
tool_args_list.append((
|
||||||
(
|
tool_call.function.name,
|
||||||
tool_call.function.name,
|
json.loads(tool_call.function.arguments),
|
||||||
json.loads(tool_call.function.arguments),
|
))
|
||||||
)
|
|
||||||
)
|
|
||||||
return tool_args_list
|
return tool_args_list
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -60,69 +60,9 @@ ignore_missing_imports = "True"
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
|
|
||||||
[tool.ruff]
|
|
||||||
target-version = "py39"
|
|
||||||
line-length = 100
|
|
||||||
fix = true
|
|
||||||
select = [
|
|
||||||
# flake8-2020
|
|
||||||
"YTT",
|
|
||||||
# flake8-bandit
|
|
||||||
"S",
|
|
||||||
# flake8-bugbear
|
|
||||||
"B",
|
|
||||||
# flake8-builtins
|
|
||||||
"A",
|
|
||||||
# flake8-comprehensions
|
|
||||||
"C4",
|
|
||||||
# flake8-debugger
|
|
||||||
"T10",
|
|
||||||
# flake8-simplify
|
|
||||||
"SIM",
|
|
||||||
# isort
|
|
||||||
"I",
|
|
||||||
# mccabe
|
|
||||||
"C90",
|
|
||||||
# pycodestyle
|
|
||||||
"E", "W",
|
|
||||||
# pyflakes
|
|
||||||
"F",
|
|
||||||
# pygrep-hooks
|
|
||||||
"PGH",
|
|
||||||
# pyupgrade
|
|
||||||
"UP",
|
|
||||||
# ruff
|
|
||||||
"RUF",
|
|
||||||
# tryceratops
|
|
||||||
"TRY",
|
|
||||||
]
|
|
||||||
ignore = [ # TODO work to remove these
|
|
||||||
# LineTooLong
|
|
||||||
"E501",
|
|
||||||
# DoNotAssignLambda
|
|
||||||
"E731",
|
|
||||||
# raise from (cli specific)
|
|
||||||
"TRY200",
|
|
||||||
# Depends function in arg string
|
|
||||||
"B008",
|
|
||||||
# raise from (cli specific)
|
|
||||||
"B904",
|
|
||||||
# long message exceptions
|
|
||||||
"TRY003",
|
|
||||||
# subprocess.Popen
|
|
||||||
"S603",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff.format]
|
|
||||||
preview = true
|
|
||||||
|
|
||||||
[tool.coverage.report]
|
[tool.coverage.report]
|
||||||
skip_empty = true
|
skip_empty = true
|
||||||
|
|
||||||
[tool.coverage.run]
|
[tool.coverage.run]
|
||||||
branch = true
|
branch = true
|
||||||
source = ["arcade"]
|
source = ["arcade"]
|
||||||
|
|
||||||
|
|
||||||
[tool.ruff.per-file-ignores]
|
|
||||||
"tests/*" = ["S101"]
|
|
||||||
|
|
|
||||||
|
|
@ -30,12 +30,10 @@ test_cases = [
|
||||||
|
|
||||||
# Generate tool functions dynamically
|
# Generate tool functions dynamically
|
||||||
def generate_tool_function(input_types: list[type], output_type: type | None):
|
def generate_tool_function(input_types: list[type], output_type: type | None):
|
||||||
input_annotation = ", ".join(
|
input_annotation = ", ".join([
|
||||||
[
|
f"param{i}: Annotated[{input_type.__name__}, 'Param {i + 1}']"
|
||||||
f"param{i}: Annotated[{input_type.__name__}, 'Param {i + 1}']"
|
for i, input_type in enumerate(input_types)
|
||||||
for i, input_type in enumerate(input_types)
|
])
|
||||||
]
|
|
||||||
)
|
|
||||||
output_annotation = f" -> {output_type.__name__}" if output_type else ""
|
output_annotation = f" -> {output_type.__name__}" if output_type else ""
|
||||||
|
|
||||||
func_code = f"""
|
func_code = f"""
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import os
|
import os
|
||||||
from arcade.core.toolkit import Toolkit
|
|
||||||
import arcade_math
|
import arcade_math
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
@ -7,6 +7,7 @@ from pydantic import BaseModel
|
||||||
from arcade.actor.fastapi.actor import FastAPIActor
|
from arcade.actor.fastapi.actor import FastAPIActor
|
||||||
from arcade.client import AsyncArcade
|
from arcade.client import AsyncArcade
|
||||||
from arcade.core.config import config
|
from arcade.core.config import config
|
||||||
|
from arcade.core.toolkit import Toolkit
|
||||||
|
|
||||||
if not config.api or not config.api.key:
|
if not config.api or not config.api.key:
|
||||||
raise ValueError("Arcade API key not set. Please run `arcade login`.")
|
raise ValueError("Arcade API key not set. Please run `arcade login`.")
|
||||||
|
|
@ -47,6 +48,8 @@ async def postChat(request: ChatRequest, tool_choice: str = "execute"):
|
||||||
tool_choice=tool_choice,
|
tool_choice=tool_choice,
|
||||||
user=config.user.email if config.user else None,
|
user=config.user.email if config.user else None,
|
||||||
)
|
)
|
||||||
return raw_response.choices
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
else:
|
||||||
|
return raw_response.choices
|
||||||
|
|
|
||||||
|
|
@ -26,9 +26,7 @@ auth_response = client.auth.authorize(
|
||||||
|
|
||||||
# If authorization is not completed, prompt the user and poll for status
|
# If authorization is not completed, prompt the user and poll for status
|
||||||
if auth_response.status != "completed":
|
if auth_response.status != "completed":
|
||||||
print(
|
print("Please complete the authorization challenge in your browser before continuing:")
|
||||||
"Please complete the authorization challenge in your browser before continuing:"
|
|
||||||
)
|
|
||||||
print(auth_response.auth_url)
|
print(auth_response.auth_url)
|
||||||
input("Press Enter to continue...")
|
input("Press Enter to continue...")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,23 +3,23 @@ import json
|
||||||
from email.message import EmailMessage
|
from email.message import EmailMessage
|
||||||
from email.mime.text import MIMEText
|
from email.mime.text import MIMEText
|
||||||
from typing import Annotated, Optional
|
from typing import Annotated, Optional
|
||||||
from arcade.core.errors import ToolExecutionError, ToolInputError
|
|
||||||
from googleapiclient.errors import HttpError
|
|
||||||
|
|
||||||
from arcade_google.tools.utils import (
|
|
||||||
DateRange,
|
|
||||||
parse_email,
|
|
||||||
get_draft_url,
|
|
||||||
get_sent_email_url,
|
|
||||||
get_email_in_trash_url,
|
|
||||||
parse_draft_email,
|
|
||||||
)
|
|
||||||
from google.oauth2.credentials import Credentials
|
from google.oauth2.credentials import Credentials
|
||||||
from googleapiclient.discovery import build
|
from googleapiclient.discovery import build
|
||||||
|
from googleapiclient.errors import HttpError
|
||||||
|
|
||||||
|
from arcade.core.errors import ToolExecutionError, ToolInputError
|
||||||
from arcade.core.schema import ToolContext
|
from arcade.core.schema import ToolContext
|
||||||
from arcade.sdk import tool
|
from arcade.sdk import tool
|
||||||
from arcade.sdk.auth import Google
|
from arcade.sdk.auth import Google
|
||||||
|
from arcade_google.tools.utils import (
|
||||||
|
DateRange,
|
||||||
|
get_draft_url,
|
||||||
|
get_email_in_trash_url,
|
||||||
|
get_sent_email_url,
|
||||||
|
parse_draft_email,
|
||||||
|
parse_email,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Email sending tools
|
# Email sending tools
|
||||||
|
|
@ -42,9 +42,7 @@ async def send_email(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Set up the Gmail API client
|
# Set up the Gmail API client
|
||||||
service = build(
|
service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
|
||||||
"gmail", "v1", credentials=Credentials(context.authorization.token)
|
|
||||||
)
|
|
||||||
|
|
||||||
message = EmailMessage()
|
message = EmailMessage()
|
||||||
message.set_content(body)
|
message.set_content(body)
|
||||||
|
|
@ -62,9 +60,7 @@ async def send_email(
|
||||||
email = {"raw": encoded_message}
|
email = {"raw": encoded_message}
|
||||||
|
|
||||||
# Send the email
|
# Send the email
|
||||||
sent_message = (
|
sent_message = service.users().messages().send(userId="me", body=email).execute()
|
||||||
service.users().messages().send(userId="me", body=email).execute()
|
|
||||||
)
|
|
||||||
return f"Email with ID {sent_message['id']} sent: {get_sent_email_url(sent_message['id'])}"
|
return f"Email with ID {sent_message['id']} sent: {get_sent_email_url(sent_message['id'])}"
|
||||||
except HttpError as e:
|
except HttpError as e:
|
||||||
raise ToolExecutionError(
|
raise ToolExecutionError(
|
||||||
|
|
@ -91,14 +87,10 @@ async def send_draft_email(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Set up the Gmail API client
|
# Set up the Gmail API client
|
||||||
service = build(
|
service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
|
||||||
"gmail", "v1", credentials=Credentials(context.authorization.token)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Send the draft email
|
# Send the draft email
|
||||||
sent_message = (
|
sent_message = service.users().drafts().send(userId="me", body={"id": id}).execute()
|
||||||
service.users().drafts().send(userId="me", body={"id": id}).execute()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Construct the URL to the sent email
|
# Construct the URL to the sent email
|
||||||
return f"Draft email with ID {sent_message['id']} sent: {get_sent_email_url(sent_message['id'])}"
|
return f"Draft email with ID {sent_message['id']} sent: {get_sent_email_url(sent_message['id'])}"
|
||||||
|
|
@ -133,9 +125,7 @@ async def write_draft_email(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Set up the Gmail API client
|
# Set up the Gmail API client
|
||||||
service = build(
|
service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
|
||||||
"gmail", "v1", credentials=Credentials(context.authorization.token)
|
|
||||||
)
|
|
||||||
|
|
||||||
message = MIMEText(body)
|
message = MIMEText(body)
|
||||||
message["to"] = recipient
|
message["to"] = recipient
|
||||||
|
|
@ -151,9 +141,7 @@ async def write_draft_email(
|
||||||
# Create the draft
|
# Create the draft
|
||||||
draft = {"message": {"raw": raw_message}}
|
draft = {"message": {"raw": raw_message}}
|
||||||
|
|
||||||
draft_message = (
|
draft_message = service.users().drafts().create(userId="me", body=draft).execute()
|
||||||
service.users().drafts().create(userId="me", body=draft).execute()
|
|
||||||
)
|
|
||||||
return f"Draft email with ID {draft_message['id']} created: {get_draft_url(draft_message['id'])}"
|
return f"Draft email with ID {draft_message['id']} created: {get_draft_url(draft_message['id'])}"
|
||||||
except HttpError as e:
|
except HttpError as e:
|
||||||
raise ToolExecutionError(
|
raise ToolExecutionError(
|
||||||
|
|
@ -187,9 +175,7 @@ async def update_draft_email(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Set up the Gmail API client
|
# Set up the Gmail API client
|
||||||
service = build(
|
service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
|
||||||
"gmail", "v1", credentials=Credentials(context.authorization.token)
|
|
||||||
)
|
|
||||||
|
|
||||||
message = MIMEText(body)
|
message = MIMEText(body)
|
||||||
message["to"] = recipient
|
message["to"] = recipient
|
||||||
|
|
@ -236,9 +222,7 @@ async def delete_draft_email(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Set up the Gmail API client
|
# Set up the Gmail API client
|
||||||
service = build(
|
service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
|
||||||
"gmail", "v1", credentials=Credentials(context.authorization.token)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Delete the draft
|
# Delete the draft
|
||||||
service.users().drafts().delete(userId="me", id=id).execute()
|
service.users().drafts().delete(userId="me", id=id).execute()
|
||||||
|
|
@ -270,9 +254,7 @@ async def trash_email(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Set up the Gmail API client
|
# Set up the Gmail API client
|
||||||
service = build(
|
service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
|
||||||
"gmail", "v1", credentials=Credentials(context.authorization.token)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Trash the email
|
# Trash the email
|
||||||
service.users().messages().trash(userId="me", id=id).execute()
|
service.users().messages().trash(userId="me", id=id).execute()
|
||||||
|
|
@ -298,33 +280,25 @@ async def trash_email(
|
||||||
async def list_draft_emails(
|
async def list_draft_emails(
|
||||||
context: ToolContext,
|
context: ToolContext,
|
||||||
n_drafts: Annotated[int, "Number of draft emails to read"] = 5,
|
n_drafts: Annotated[int, "Number of draft emails to read"] = 5,
|
||||||
) -> Annotated[
|
) -> Annotated[str, "A JSON string containing a list of draft email details and their IDs"]:
|
||||||
str, "A JSON string containing a list of draft email details and their IDs"
|
|
||||||
]:
|
|
||||||
"""
|
"""
|
||||||
Lists draft emails in the user's draft mailbox using the Gmail API.
|
Lists draft emails in the user's draft mailbox using the Gmail API.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Set up the Gmail API client
|
# Set up the Gmail API client
|
||||||
service = build(
|
service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
|
||||||
"gmail", "v1", credentials=Credentials(context.authorization.token)
|
|
||||||
)
|
|
||||||
|
|
||||||
listed_drafts = service.users().drafts().list(userId="me").execute()
|
listed_drafts = service.users().drafts().list(userId="me").execute()
|
||||||
|
|
||||||
if not listed_drafts:
|
if not listed_drafts:
|
||||||
return {"emails": []}
|
return {"emails": []}
|
||||||
|
|
||||||
draft_ids = [draft["id"] for draft in listed_drafts.get("drafts", [])][
|
draft_ids = [draft["id"] for draft in listed_drafts.get("drafts", [])][:n_drafts]
|
||||||
:n_drafts
|
|
||||||
]
|
|
||||||
|
|
||||||
emails = []
|
emails = []
|
||||||
for draft_id in draft_ids:
|
for draft_id in draft_ids:
|
||||||
try:
|
try:
|
||||||
draft_data = (
|
draft_data = service.users().drafts().get(userId="me", id=draft_id).execute()
|
||||||
service.users().drafts().get(userId="me", id=draft_id).execute()
|
|
||||||
)
|
|
||||||
draft_details = parse_draft_email(draft_data)
|
draft_details = parse_draft_email(draft_data)
|
||||||
if draft_details:
|
if draft_details:
|
||||||
emails.append(draft_details)
|
emails.append(draft_details)
|
||||||
|
|
@ -352,15 +326,9 @@ async def list_draft_emails(
|
||||||
)
|
)
|
||||||
async def list_emails_by_header(
|
async def list_emails_by_header(
|
||||||
context: ToolContext,
|
context: ToolContext,
|
||||||
sender: Annotated[
|
sender: Annotated[Optional[str], "The name or email address of the sender of the email"] = None,
|
||||||
Optional[str], "The name or email address of the sender of the email"
|
recipient: Annotated[Optional[str], "The name or email address of the recipient"] = None,
|
||||||
] = None,
|
subject: Annotated[Optional[str], "Words to find in the subject of the email"] = None,
|
||||||
recipient: Annotated[
|
|
||||||
Optional[str], "The name or email address of the recipient"
|
|
||||||
] = None,
|
|
||||||
subject: Annotated[
|
|
||||||
Optional[str], "Words to find in the subject of the email"
|
|
||||||
] = None,
|
|
||||||
body: Annotated[Optional[str], "Words to find in the body of the email"] = None,
|
body: Annotated[Optional[str], "Words to find in the body of the email"] = None,
|
||||||
date_range: Annotated[Optional[DateRange], "The date range of the email"] = None,
|
date_range: Annotated[Optional[DateRange], "The date range of the email"] = None,
|
||||||
limit: Annotated[Optional[int], "The maximum number of emails to return"] = 25,
|
limit: Annotated[Optional[int], "The maximum number of emails to return"] = 25,
|
||||||
|
|
@ -394,9 +362,7 @@ async def list_emails_by_header(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Set up the Gmail API client
|
# Set up the Gmail API client
|
||||||
service = build(
|
service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
|
||||||
"gmail", "v1", credentials=Credentials(context.authorization.token)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Perform the search
|
# Perform the search
|
||||||
response = (
|
response = (
|
||||||
|
|
@ -413,9 +379,7 @@ async def list_emails_by_header(
|
||||||
emails = []
|
emails = []
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
try:
|
try:
|
||||||
email_data = (
|
email_data = service.users().messages().get(userId="me", id=msg["id"]).execute()
|
||||||
service.users().messages().get(userId="me", id=msg["id"]).execute()
|
|
||||||
)
|
|
||||||
email_details = parse_email(email_data)
|
email_details = parse_email(email_data)
|
||||||
if email_details:
|
if email_details:
|
||||||
emails.append(email_details)
|
emails.append(email_details)
|
||||||
|
|
@ -449,13 +413,9 @@ async def list_emails(
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Set up the Gmail API client
|
# Set up the Gmail API client
|
||||||
service = build(
|
service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
|
||||||
"gmail", "v1", credentials=Credentials(context.authorization.token)
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = (
|
messages = service.users().messages().list(userId="me").execute().get("messages", [])
|
||||||
service.users().messages().list(userId="me").execute().get("messages", [])
|
|
||||||
)
|
|
||||||
|
|
||||||
if not messages:
|
if not messages:
|
||||||
return {"emails": []}
|
return {"emails": []}
|
||||||
|
|
@ -463,9 +423,7 @@ async def list_emails(
|
||||||
emails = []
|
emails = []
|
||||||
for msg in messages[:n_emails]:
|
for msg in messages[:n_emails]:
|
||||||
try:
|
try:
|
||||||
email_data = (
|
email_data = service.users().messages().get(userId="me", id=msg["id"]).execute()
|
||||||
service.users().messages().get(userId="me", id=msg["id"]).execute()
|
|
||||||
)
|
|
||||||
email_details = parse_email(email_data)
|
email_details = parse_email(email_data)
|
||||||
if email_details:
|
if email_details:
|
||||||
emails.append(email_details)
|
emails.append(email_details)
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
from base64 import urlsafe_b64decode
|
|
||||||
import datetime
|
import datetime
|
||||||
from enum import Enum
|
|
||||||
import re
|
import re
|
||||||
from typing import Any, Dict, Optional
|
from base64 import urlsafe_b64decode
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Optional, dict
|
||||||
|
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
|
|
@ -30,20 +30,18 @@ class DateRange(Enum):
|
||||||
elif self == DateRange.THIS_MONTH:
|
elif self == DateRange.THIS_MONTH:
|
||||||
comparison_date = today.replace(day=1)
|
comparison_date = today.replace(day=1)
|
||||||
elif self == DateRange.LAST_MONTH:
|
elif self == DateRange.LAST_MONTH:
|
||||||
comparison_date = (
|
comparison_date = (today.replace(day=1) - datetime.timedelta(days=1)).replace(day=1)
|
||||||
today.replace(day=1) - datetime.timedelta(days=1)
|
|
||||||
).replace(day=1)
|
|
||||||
elif self == DateRange.THIS_YEAR:
|
elif self == DateRange.THIS_YEAR:
|
||||||
comparison_date = today.replace(month=1, day=1)
|
comparison_date = today.replace(month=1, day=1)
|
||||||
elif self == DateRange.LAST_MONTH:
|
elif self == DateRange.LAST_MONTH:
|
||||||
comparison_date = (
|
comparison_date = (today.replace(month=1, day=1) - datetime.timedelta(days=1)).replace(
|
||||||
today.replace(month=1, day=1) - datetime.timedelta(days=1)
|
month=1, day=1
|
||||||
).replace(month=1, day=1)
|
)
|
||||||
|
|
||||||
return result + comparison_date.strftime("%Y/%m/%d")
|
return result + comparison_date.strftime("%Y/%m/%d")
|
||||||
|
|
||||||
|
|
||||||
def parse_email(email_data: Dict[str, Any]) -> Optional[Dict[str, str]]:
|
def parse_email(email_data: dict[str, Any]) -> Optional[dict[str, str]]:
|
||||||
"""
|
"""
|
||||||
Parse email data and extract relevant information.
|
Parse email data and extract relevant information.
|
||||||
|
|
||||||
|
|
@ -71,7 +69,7 @@ def parse_email(email_data: Dict[str, Any]) -> Optional[Dict[str, str]]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def parse_draft_email(draft_email_data: Dict[str, Any]) -> Optional[Dict[str, str]]:
|
def parse_draft_email(draft_email_data: dict[str, Any]) -> Optional[dict[str, str]]:
|
||||||
"""
|
"""
|
||||||
Parse draft email data and extract relevant information.
|
Parse draft email data and extract relevant information.
|
||||||
|
|
||||||
|
|
@ -112,7 +110,7 @@ def get_email_in_trash_url(email_id):
|
||||||
return f"https://mail.google.com/mail/u/0/#trash/{email_id}"
|
return f"https://mail.google.com/mail/u/0/#trash/{email_id}"
|
||||||
|
|
||||||
|
|
||||||
def _get_email_body(payload: Dict[str, Any]) -> Optional[str]:
|
def _get_email_body(payload: dict[str, Any]) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Extract email body from payload.
|
Extract email body from payload.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,27 +1,28 @@
|
||||||
import json
|
import json
|
||||||
from arcade.core.errors import ToolExecutionError
|
from unittest.mock import MagicMock, patch
|
||||||
from arcade_google.tools.utils import parse_draft_email, parse_email
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
from arcade_google.tools.gmail import (
|
from arcade_google.tools.gmail import (
|
||||||
send_email,
|
|
||||||
write_draft_email,
|
|
||||||
update_draft_email,
|
|
||||||
send_draft_email,
|
|
||||||
delete_draft_email,
|
delete_draft_email,
|
||||||
list_draft_emails,
|
list_draft_emails,
|
||||||
list_emails_by_header,
|
|
||||||
list_emails,
|
list_emails,
|
||||||
|
list_emails_by_header,
|
||||||
|
send_draft_email,
|
||||||
|
send_email,
|
||||||
trash_email,
|
trash_email,
|
||||||
|
update_draft_email,
|
||||||
|
write_draft_email,
|
||||||
)
|
)
|
||||||
|
from arcade_google.tools.utils import parse_draft_email, parse_email
|
||||||
from arcade.core.schema import ToolContext, ToolAuthorizationContext
|
|
||||||
from googleapiclient.errors import HttpError
|
from googleapiclient.errors import HttpError
|
||||||
|
|
||||||
|
from arcade.core.errors import ToolExecutionError
|
||||||
|
from arcade.core.schema import ToolAuthorizationContext, ToolContext
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_context():
|
def mock_context():
|
||||||
mock_auth = ToolAuthorizationContext(token="fake-token")
|
mock_auth = ToolAuthorizationContext(token="fake-token") # noqa: S106
|
||||||
return ToolContext(authorization=mock_auth)
|
return ToolContext(authorization=mock_auth)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -214,9 +215,7 @@ async def test_get_draft_emails(mock_parse_draft_email, mock_build, mock_context
|
||||||
mock_build.return_value = mock_service
|
mock_build.return_value = mock_service
|
||||||
|
|
||||||
# Mock the response from the Gmail list drafts API
|
# Mock the response from the Gmail list drafts API
|
||||||
mock_service.users().drafts().list().execute.return_value = (
|
mock_service.users().drafts().list().execute.return_value = mock_drafts_list_response
|
||||||
mock_drafts_list_response
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock the response from the Gmail get drafts API
|
# Mock the response from the Gmail get drafts API
|
||||||
mock_service.users().drafts().get().execute.return_value = mock_drafts_get_response
|
mock_service.users().drafts().get().execute.return_value = mock_drafts_get_response
|
||||||
|
|
@ -291,22 +290,16 @@ async def test_search_emails_by_header(mock_parse_email, mock_build, mock_contex
|
||||||
mock_build.return_value = mock_service
|
mock_build.return_value = mock_service
|
||||||
|
|
||||||
# Mock the response from the Gmail list messages API
|
# Mock the response from the Gmail list messages API
|
||||||
mock_service.users().messages().list().execute.return_value = (
|
mock_service.users().messages().list().execute.return_value = mock_messages_list_response
|
||||||
mock_messages_list_response
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock the response from the Gmail get messages API
|
# Mock the response from the Gmail get messages API
|
||||||
mock_service.users().messages().get().execute.return_value = (
|
mock_service.users().messages().get().execute.return_value = mock_messages_get_response
|
||||||
mock_messages_get_response
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock the parse_email function since parse_email doesn't accept object of type MagicMock
|
# Mock the parse_email function since parse_email doesn't accept object of type MagicMock
|
||||||
mock_parse_email.return_value = parse_email(mock_messages_get_response)
|
mock_parse_email.return_value = parse_email(mock_messages_get_response)
|
||||||
|
|
||||||
# Test happy path
|
# Test happy path
|
||||||
result = await list_emails_by_header(
|
result = await list_emails_by_header(context=mock_context, sender="noreply@github.com", limit=2)
|
||||||
context=mock_context, sender="noreply@github.com", limit=2
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(result, str)
|
assert isinstance(result, str)
|
||||||
result_json = json.loads(result)
|
result_json = json.loads(result)
|
||||||
|
|
@ -322,9 +315,7 @@ async def test_search_emails_by_header(mock_parse_email, mock_build, mock_contex
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ToolExecutionError):
|
with pytest.raises(ToolExecutionError):
|
||||||
await list_emails_by_header(
|
await list_emails_by_header(context=mock_context, sender="noreply@github.com", limit=2)
|
||||||
context=mock_context, sender="noreply@github.com", limit=2
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -373,14 +364,10 @@ async def test_get_emails(mock_parse_email, mock_build, mock_context):
|
||||||
mock_build.return_value = mock_service
|
mock_build.return_value = mock_service
|
||||||
|
|
||||||
# Mock the response from the Gmail list messages API
|
# Mock the response from the Gmail list messages API
|
||||||
mock_service.users().messages().list().execute.return_value = (
|
mock_service.users().messages().list().execute.return_value = mock_messages_list_response
|
||||||
mock_messages_list_response
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock the Gmail get messages API
|
# Mock the Gmail get messages API
|
||||||
mock_service.users().messages().get().execute.return_value = (
|
mock_service.users().messages().get().execute.return_value = mock_messages_get_response
|
||||||
mock_messages_get_response
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock the parse_email function since parse_email doesn't accept object of type MagicMock
|
# Mock the parse_email function since parse_email doesn't accept object of type MagicMock
|
||||||
mock_parse_email.return_value = parse_email(mock_messages_get_response)
|
mock_parse_email.return_value = parse_email(mock_messages_get_response)
|
||||||
|
|
|
||||||
|
|
@ -43,9 +43,7 @@ def math_eval_suite():
|
||||||
],
|
],
|
||||||
rubric=rubric,
|
rubric=rubric,
|
||||||
critics=[
|
critics=[
|
||||||
BinaryCritic(
|
BinaryCritic(critic_field="a", weight=0.5), # TODO: weight should be optional
|
||||||
critic_field="a", weight=0.5
|
|
||||||
), # TODO: weight should be optional
|
|
||||||
BinaryCritic(critic_field="b", weight=0.5),
|
BinaryCritic(critic_field="b", weight=0.5),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
import pytest
|
import pytest
|
||||||
from arcade_math.tools.arithmetic import (
|
from arcade_math.tools.arithmetic import (
|
||||||
add,
|
add,
|
||||||
subtract,
|
|
||||||
multiply,
|
|
||||||
divide,
|
divide,
|
||||||
|
multiply,
|
||||||
sqrt,
|
sqrt,
|
||||||
|
subtract,
|
||||||
sum_list,
|
sum_list,
|
||||||
sum_range,
|
sum_range,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import serpapi
|
|
||||||
from typing import Annotated, Any, Optional
|
from typing import Annotated, Any, Optional
|
||||||
|
|
||||||
|
import serpapi
|
||||||
|
|
||||||
from arcade.sdk import tool
|
from arcade.sdk import tool
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -165,9 +165,7 @@ def slack_eval_suite() -> EvalSuite:
|
||||||
],
|
],
|
||||||
critics=[
|
critics=[
|
||||||
SimilarityCritic(critic_field="user_name", weight=0.7),
|
SimilarityCritic(critic_field="user_name", weight=0.7),
|
||||||
SimilarityCritic(
|
SimilarityCritic(critic_field="message", weight=0.3, similarity_threshold=0.6),
|
||||||
critic_field="message", weight=0.3, similarity_threshold=0.6
|
|
||||||
),
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from arcade.core.errors import ToolExecutionError
|
|
||||||
from arcade.sdk.auth import X
|
|
||||||
import requests
|
import requests
|
||||||
from arcade.sdk import tool
|
|
||||||
|
|
||||||
|
from arcade.core.errors import ToolExecutionError
|
||||||
from arcade.core.schema import ToolContext
|
from arcade.core.schema import ToolContext
|
||||||
|
from arcade.sdk import tool
|
||||||
|
from arcade.sdk.auth import X
|
||||||
from arcade_x.tools.utils import get_tweet_url, parse_search_recent_tweets_response
|
from arcade_x.tools.utils import get_tweet_url, parse_search_recent_tweets_response
|
||||||
|
|
||||||
TWEETS_URL = "https://api.x.com/2/tweets"
|
TWEETS_URL = "https://api.x.com/2/tweets"
|
||||||
|
|
@ -33,9 +33,7 @@ def post_tweet(
|
||||||
)
|
)
|
||||||
|
|
||||||
tweet_id = response.json()["data"]["id"]
|
tweet_id = response.json()["data"]["id"]
|
||||||
return (
|
return f"Tweet with id {tweet_id} posted successfully. URL: {get_tweet_url(tweet_id)}"
|
||||||
f"Tweet with id {tweet_id} posted successfully. URL: {get_tweet_url(tweet_id)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool(requires_auth=X(scopes=["tweet.read", "tweet.write", "users.read"]))
|
@tool(requires_auth=X(scopes=["tweet.read", "tweet.write", "users.read"]))
|
||||||
|
|
@ -74,11 +72,11 @@ def search_recent_tweets_by_username(
|
||||||
}
|
}
|
||||||
params = {
|
params = {
|
||||||
"query": f"from:{username}",
|
"query": f"from:{username}",
|
||||||
"max_results": max(
|
"max_results": max(max_results, 10), # X API does not allow 'max_results' less than 10
|
||||||
max_results, 10
|
|
||||||
), # X API does not allow 'max_results' less than 10
|
|
||||||
}
|
}
|
||||||
url = "https://api.x.com/2/tweets/search/recent?expansions=author_id&user.fields=id,name,username"
|
url = (
|
||||||
|
"https://api.x.com/2/tweets/search/recent?expansions=author_id&user.fields=id,name,username"
|
||||||
|
)
|
||||||
|
|
||||||
response = requests.get(url, headers=headers, params=params)
|
response = requests.get(url, headers=headers, params=params)
|
||||||
|
|
||||||
|
|
@ -95,12 +93,8 @@ def search_recent_tweets_by_username(
|
||||||
@tool(requires_auth=X(scopes=["tweet.read", "users.read"]))
|
@tool(requires_auth=X(scopes=["tweet.read", "users.read"]))
|
||||||
def search_recent_tweets_by_keywords(
|
def search_recent_tweets_by_keywords(
|
||||||
context: ToolContext,
|
context: ToolContext,
|
||||||
keywords: Annotated[
|
keywords: Annotated[list[str], "List of keywords that must be present in the tweet"] = None,
|
||||||
list[str], "List of keywords that must be present in the tweet"
|
phrases: Annotated[list[str], "List of phrases that must be present in the tweet"] = None,
|
||||||
] = None,
|
|
||||||
phrases: Annotated[
|
|
||||||
list[str], "List of phrases that must be present in the tweet"
|
|
||||||
] = None,
|
|
||||||
max_results: Annotated[
|
max_results: Annotated[
|
||||||
int, "The maximum number of results to return. Cannot be less than 10"
|
int, "The maximum number of results to return. Cannot be less than 10"
|
||||||
] = 10,
|
] = 10,
|
||||||
|
|
@ -122,11 +116,11 @@ def search_recent_tweets_by_keywords(
|
||||||
query = " ".join([f'"{phrase}"' for phrase in phrases]) + " ".join(keywords)
|
query = " ".join([f'"{phrase}"' for phrase in phrases]) + " ".join(keywords)
|
||||||
params = {
|
params = {
|
||||||
"query": query,
|
"query": query,
|
||||||
"max_results": max(
|
"max_results": max(max_results, 10), # X API does not allow 'max_results' less than 10
|
||||||
max_results, 10
|
|
||||||
), # X API does not allow 'max_results' less than 10
|
|
||||||
}
|
}
|
||||||
url = "https://api.x.com/2/tweets/search/recent?expansions=author_id&user.fields=id,name,username"
|
url = (
|
||||||
|
"https://api.x.com/2/tweets/search/recent?expansions=author_id&user.fields=id,name,username"
|
||||||
|
)
|
||||||
|
|
||||||
response = requests.get(url, headers=headers, params=params)
|
response = requests.get(url, headers=headers, params=params)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,11 @@
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from arcade.core.errors import ToolExecutionError
|
|
||||||
from arcade.sdk.auth import X
|
|
||||||
import requests
|
|
||||||
from arcade.sdk import tool
|
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from arcade.core.errors import ToolExecutionError
|
||||||
from arcade.core.schema import ToolContext
|
from arcade.core.schema import ToolContext
|
||||||
|
from arcade.sdk import tool
|
||||||
|
from arcade.sdk.auth import X
|
||||||
|
|
||||||
|
|
||||||
# Users Lookup Tools. See developer docs for additional available query parameters: https://developer.x.com/en/docs/x-api/users/lookup/api-reference
|
# Users Lookup Tools. See developer docs for additional available query parameters: https://developer.x.com/en/docs/x-api/users/lookup/api-reference
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from requests import Response
|
from requests import Response
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -38,9 +39,7 @@ def parse_search_recent_tweets_response(response: Response) -> str:
|
||||||
for tweet in tweets_data["data"]:
|
for tweet in tweets_data["data"]:
|
||||||
tweet["tweet_url"] = get_tweet_url(tweet["id"])
|
tweet["tweet_url"] = get_tweet_url(tweet["id"])
|
||||||
|
|
||||||
for tweet_data, user_data in zip(
|
for tweet_data, user_data in zip(tweets_data["data"], tweets_data["includes"]["users"]):
|
||||||
tweets_data["data"], tweets_data["includes"]["users"]
|
|
||||||
):
|
|
||||||
tweet_data["author_username"] = user_data["username"]
|
tweet_data["author_username"] = user_data["username"]
|
||||||
tweet_data["author_name"] = user_data["name"]
|
tweet_data["author_name"] = user_data["name"]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -43,9 +43,7 @@ def x_eval_suite() -> EvalSuite:
|
||||||
expected_tool_calls=[
|
expected_tool_calls=[
|
||||||
(
|
(
|
||||||
post_tweet,
|
post_tweet,
|
||||||
{
|
{"tweet_text": "Hello World! Exciting stuff is happening over at Arcade AI!"},
|
||||||
"tweet_text": "Hello World! Exciting stuff is happening over at Arcade AI!"
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
critics=[
|
critics=[
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue