From 07c52100f3e2daca69b2123e206515ac51e2c064 Mon Sep 17 00:00:00 2001 From: Eric Gustin <34000337+EricGustin@users.noreply.github.com> Date: Wed, 9 Jul 2025 16:00:09 -0700 Subject: [PATCH] Split and rename multiple toolkits (#438) # PR Description ## Split toolkits This PR splits the `Microsoft`, `Google`, and `Search` toolkits into multiple toolkits each. * `Microsoft` --> `OutlookCalendar`, `OutlookMail`. * `Google` -----> `GoogleCalendar`, `GoogleContacts`, `GoogleDocs`, `GoogleDrive`, `Gmail`, `GoogleSheets` * `Search` -----> `GoogleFinance`, `GoogleFlights`, `GoogleHotels`, `GoogleJobs`, `GoogleMaps`, `GoogleNews`, `GoogleSearch`, `GoogleShopping`, `Walmart`, `Youtube` > The original monolithic toolkits (`Microsoft`, `Google`, `Search`) are not removed in this PR. The plan is to keep those toolkits around while we > 1. Stop documenting the toolkits, > 2. Stop displaying the toolkits in the dashboard, and > 3. Help customers migrate over to the new split toolkits. ## Rename toolkits This PR renames the following toolkits * `Web` ------------> `Firecrawl` * `CodeSandbox` ---> `E2B` > The `Web` and `CodeSandbox` toolkits are not removed in this PR. The plan is to keep them around while we > 1. Stop documenting the toolkits, > 2. Stop displaying the toolkits in the dashboard, and > 3. Help customers migrate over to the new renamed toolkits. ## Rename tools Since toolkit names were changed, this called for some tools to be renamed as well. * `GoogleSearch.SearchGoogle` ----------------> `GoogleSearch.Search` * `GoogleShopping.SearchShoppingProducts` ---> `GoogleShopping.SearchProducts` * `Walmart.SearchWalmartProducts` ------------> `Walmart.SearchProducts` * `Walmart.GetWalmartProductDetails` ---------> `Walmart.GetProductDetails` * `Youtube.SearchYoutubeVideos` --------------> `Youtube.SearchForVideos` ## Google File Picker Improvements to the Google File Picker experience were also added in this PR. The following tools will ALWAYS provide llm_instructions in their response to "let the end-user know that they have the option to select more files via the file picker url if they want to": * `GoogleDocs.SearchDocuments` * `GoogleDocs.SearchAndRetrieveDocuments` * `GoogleDrive.GetFileTreeStructure` The following tools will only provide the file picker URL if a 404 or 403 from the Google API: * `GoogleDocs.InsertTextAtEndOfDocument` * `GoogleDocs.GetDocumentById` * `GoogleSheets.GetSpreadsheet` * `GoogleSheets.WriteToCell` Also, a standalone `GoogleDrive.GenerateGoogleFilePickerUrl` tool exists. ## Other * The `SearchDocuments` and `SearchAndRetrieveDocuments` tools used to be organized within the Drive portion of the Google toolkit, but I moved these into the new GoogleDocs toolkit because they are specific to Docs. # Progress - [x] `OutlookCalendar` - [x] `OutlookMail` - [x] `GoogleFinance` - [x] `GoogleFlights` - [x] `GoogleHotels` - [x] `GoogleJobs` - [x] `GoogleMaps` - [x] `GoogleNews` - [x] `GoogleSearch` - [x] `GoogleShopping` - [x] `Walmart` - [x] `Youtube` - [x] `GoogleCalendar` - [x] `GoogleContacts` - [x] `GoogleDocs` - [x] `GoogleDrive` - [x] `Gmail` - [x] `GoogleSheets` - [x] `Firecrawl` - [x] `E2B` - [x] File picker # Discussion * Repeated code is a consequence of splitting toolkits that use the same provider. I am open to any ideas that would allow multiple toolkits to reference common code. Comment your ideas in this PR. --- .github/workflows/test-toolkits.yml | 11 +- docker/toolkits.txt | 20 + toolkits/e2b/.pre-commit-config.yaml | 18 + toolkits/e2b/.ruff.toml | 47 + toolkits/e2b/LICENSE | 21 + toolkits/e2b/Makefile | 55 + toolkits/e2b/arcade_e2b/__init__.py | 3 + toolkits/e2b/arcade_e2b/enums.py | 10 + toolkits/e2b/arcade_e2b/tools/__init__.py | 4 + toolkits/e2b/arcade_e2b/tools/create_chart.py | 31 + toolkits/e2b/arcade_e2b/tools/run_code.py | 27 + toolkits/e2b/evals/eval_e2b.py | 120 +++ toolkits/e2b/pyproject.toml | 57 ++ toolkits/e2b/tests/__init__.py | 0 toolkits/e2b/tests/test_e2b.py | 74 ++ toolkits/firecrawl/.pre-commit-config.yaml | 18 + toolkits/firecrawl/.ruff.toml | 47 + toolkits/firecrawl/LICENSE | 21 + toolkits/firecrawl/Makefile | 55 + .../firecrawl/arcade_firecrawl/__init__.py | 17 + toolkits/firecrawl/arcade_firecrawl/enums.py | 11 + .../arcade_firecrawl/tools/__init__.py | 17 + .../firecrawl/arcade_firecrawl/tools/crawl.py | 121 +++ .../firecrawl/arcade_firecrawl/tools/map.py | 33 + .../arcade_firecrawl/tools/scrape.py | 49 + toolkits/firecrawl/evals/eval_firecrawl.py | 244 +++++ toolkits/firecrawl/pyproject.toml | 54 + toolkits/firecrawl/tests/__init__.py | 0 toolkits/firecrawl/tests/test_firecrawl.py | 129 +++ toolkits/gmail/.pre-commit-config.yaml | 18 + toolkits/gmail/.ruff.toml | 46 + toolkits/gmail/Makefile | 55 + toolkits/gmail/arcade_gmail/__init__.py | 0 toolkits/gmail/arcade_gmail/constants.py | 18 + toolkits/gmail/arcade_gmail/enums.py | 11 + toolkits/gmail/arcade_gmail/exceptions.py | 19 + toolkits/gmail/arcade_gmail/tools/__init__.py | 39 + toolkits/gmail/arcade_gmail/tools/gmail.py | 664 ++++++++++++ toolkits/gmail/arcade_gmail/utils.py | 509 +++++++++ toolkits/gmail/evals/eval_google_gmail.py | 431 ++++++++ toolkits/gmail/pyproject.toml | 64 ++ toolkits/gmail/tests/__init__.py | 0 toolkits/gmail/tests/test_gmail.py | 951 +++++++++++++++++ .../google_calendar/.pre-commit-config.yaml | 18 + toolkits/google_calendar/.ruff.toml | 46 + toolkits/google_calendar/Makefile | 55 + .../arcade_google_calendar/__init__.py | 17 + .../arcade_google_calendar/enums.py | 14 + .../arcade_google_calendar/tools/__init__.py | 17 + .../arcade_google_calendar/tools/calendar.py | 510 +++++++++ .../arcade_google_calendar/utils.py | 249 +++++ .../evals/eval_google_calendar.py | 215 ++++ toolkits/google_calendar/pyproject.toml | 63 ++ toolkits/google_calendar/tests/__init__.py | 0 .../google_calendar/tests/test_calendar.py | 582 +++++++++++ .../google_contacts/.pre-commit-config.yaml | 18 + toolkits/google_contacts/.ruff.toml | 46 + toolkits/google_contacts/Makefile | 55 + .../arcade_google_contacts/__init__.py | 7 + .../arcade_google_contacts/constants.py | 1 + .../arcade_google_contacts/tools/__init__.py | 7 + .../arcade_google_contacts/tools/contacts.py | 96 ++ .../arcade_google_contacts/utils.py | 49 + .../evals/eval_google_contacts.py | 135 +++ toolkits/google_contacts/pyproject.toml | 63 ++ toolkits/google_contacts/tests/__init__.py | 0 .../google_contacts/tests/test_contacts.py | 100 ++ toolkits/google_docs/.pre-commit-config.yaml | 18 + toolkits/google_docs/.ruff.toml | 46 + toolkits/google_docs/Makefile | 55 + .../arcade_google_docs/__init__.py | 17 + .../arcade_google_docs/decorators.py | 24 + .../arcade_google_docs/doc_to_html.py | 99 ++ .../arcade_google_docs/doc_to_markdown.py | 64 ++ .../google_docs/arcade_google_docs/enum.py | 116 +++ .../arcade_google_docs/file_picker.py | 49 + .../arcade_google_docs/templates.py | 5 + .../arcade_google_docs/tools/__init__.py | 19 + .../arcade_google_docs/tools/create.py | 82 ++ .../arcade_google_docs/tools/get.py | 35 + .../arcade_google_docs/tools/search.py | 219 ++++ .../arcade_google_docs/tools/update.py | 60 ++ .../google_docs/arcade_google_docs/utils.py | 119 +++ toolkits/google_docs/conftest.py | 967 ++++++++++++++++++ .../google_docs/evals/eval_google_docs.py | 384 +++++++ toolkits/google_docs/pyproject.toml | 62 ++ toolkits/google_docs/tests/__init__.py | 0 .../google_docs/tests/test_doc_to_markdown.py | 10 + .../google_docs/tests/test_google_docs.py | 179 ++++ toolkits/google_docs/tests/test_search.py | 276 +++++ toolkits/google_drive/.pre-commit-config.yaml | 18 + toolkits/google_drive/.ruff.toml | 46 + toolkits/google_drive/Makefile | 55 + .../arcade_google_drive/__init__.py | 3 + .../google_drive/arcade_google_drive/enums.py | 116 +++ .../arcade_google_drive/templates.py | 5 + .../arcade_google_drive/tools/__init__.py | 6 + .../arcade_google_drive/tools/drive.py | 167 +++ .../google_drive/arcade_google_drive/utils.py | 114 +++ toolkits/google_drive/conftest.py | 197 ++++ .../google_drive/evals/eval_google_drive.py | 131 +++ .../evals/eval_tools_understand_filepicker.py | 70 ++ toolkits/google_drive/pyproject.toml | 62 ++ toolkits/google_drive/tests/__init__.py | 0 toolkits/google_drive/tests/test_drive.py | 238 +++++ .../google_finance/.pre-commit-config.yaml | 18 + toolkits/google_finance/.ruff.toml | 47 + toolkits/google_finance/LICENSE | 21 + toolkits/google_finance/Makefile | 55 + .../arcade_google_finance/__init__.py | 3 + .../arcade_google_finance/enums.py | 12 + .../arcade_google_finance/tools/__init__.py | 3 + .../tools/google_finance.py | 86 ++ .../arcade_google_finance/utils.py | 48 + toolkits/google_finance/pyproject.toml | 56 + .../google_flights/.pre-commit-config.yaml | 18 + toolkits/google_flights/.ruff.toml | 47 + toolkits/google_flights/LICENSE | 21 + toolkits/google_flights/Makefile | 55 + .../arcade_google_flights/__init__.py | 3 + .../arcade_google_flights/enums.py | 53 + .../arcade_google_flights/tools/__init__.py | 5 + .../tools/google_flights.py | 61 ++ .../arcade_google_flights/utils.py | 68 ++ toolkits/google_flights/pyproject.toml | 56 + .../google_hotels/.pre-commit-config.yaml | 18 + toolkits/google_hotels/.ruff.toml | 47 + toolkits/google_hotels/LICENSE | 21 + toolkits/google_hotels/Makefile | 55 + .../arcade_google_hotels/__init__.py | 3 + .../arcade_google_hotels/enums.py | 17 + .../arcade_google_hotels/tools/__init__.py | 3 + .../tools/google_hotels.py | 58 ++ .../arcade_google_hotels/utils.py | 48 + toolkits/google_hotels/pyproject.toml | 56 + toolkits/google_jobs/.pre-commit-config.yaml | 18 + toolkits/google_jobs/.ruff.toml | 47 + toolkits/google_jobs/LICENSE | 21 + toolkits/google_jobs/Makefile | 55 + .../arcade_google_jobs/__init__.py | 3 + .../arcade_google_jobs/constants.py | 5 + .../google_jobs/arcade_google_jobs/enums.py | 0 .../arcade_google_jobs/exceptions.py | 17 + .../arcade_google_jobs/google_data.py | 33 + .../arcade_google_jobs/tools/__init__.py | 3 + .../arcade_google_jobs/tools/google_jobs.py | 65 ++ .../google_jobs/arcade_google_jobs/utils.py | 48 + .../google_jobs/evals/eval_google_jobs.py | 157 +++ toolkits/google_jobs/pyproject.toml | 56 + toolkits/google_jobs/tests/__init__.py | 0 .../google_jobs/tests/test_google_jobs.py | 90 ++ toolkits/google_maps/.pre-commit-config.yaml | 18 + toolkits/google_maps/.ruff.toml | 47 + toolkits/google_maps/LICENSE | 21 + toolkits/google_maps/Makefile | 55 + .../arcade_google_maps/__init__.py | 6 + .../arcade_google_maps/constants.py | 14 + .../google_maps/arcade_google_maps/enums.py | 35 + .../arcade_google_maps/exceptions.py | 25 + .../arcade_google_maps/google_data.py | 281 +++++ .../arcade_google_maps/tools/__init__.py | 9 + .../arcade_google_maps/tools/google_maps.py | 100 ++ .../google_maps/arcade_google_maps/utils.py | 175 ++++ .../evals/eval_google_maps_directions.py | 226 ++++ toolkits/google_maps/pyproject.toml | 56 + toolkits/google_maps/tests/__init__.py | 0 .../tests/test_google_maps_directions.py | 131 +++ toolkits/google_news/.pre-commit-config.yaml | 18 + toolkits/google_news/.ruff.toml | 47 + toolkits/google_news/LICENSE | 21 + toolkits/google_news/Makefile | 55 + .../arcade_google_news/__init__.py | 3 + .../arcade_google_news/constants.py | 6 + .../arcade_google_news/exceptions.py | 25 + .../arcade_google_news/google_data.py | 281 +++++ .../arcade_google_news/tools/__init__.py | 3 + .../arcade_google_news/tools/google_news.py | 47 + .../google_news/arcade_google_news/utils.py | 64 ++ toolkits/google_news/pyproject.toml | 56 + .../google_search/.pre-commit-config.yaml | 18 + toolkits/google_search/.ruff.toml | 46 + toolkits/google_search/LICENSE | 21 + toolkits/google_search/Makefile | 55 + .../arcade_google_search/__init__.py | 3 + .../arcade_google_search/tools/__init__.py | 3 + .../tools/google_search.py | 21 + .../arcade_google_search/utils.py | 48 + .../google_search/evals/eval_google_search.py | 240 +++++ toolkits/google_search/pyproject.toml | 59 ++ toolkits/google_search/tests/__init__.py | 0 .../google_search/tests/test_google_search.py | 49 + toolkits/google_search/tests/test_utils.py | 68 ++ .../google_sheets/.pre-commit-config.yaml | 18 + toolkits/google_sheets/.ruff.toml | 46 + toolkits/google_sheets/LICENSE | 21 + toolkits/google_sheets/Makefile | 55 + .../arcade_google_sheets/__init__.py | 7 + .../arcade_google_sheets/constants.py | 2 + .../arcade_google_sheets/decorators.py | 24 + .../arcade_google_sheets/enums.py | 25 + .../arcade_google_sheets/file_picker.py | 49 + .../arcade_google_sheets/models.py | 241 +++++ .../arcade_google_sheets/tools/__init__.py | 4 + .../arcade_google_sheets/tools/read.py | 42 + .../arcade_google_sheets/tools/write.py | 114 +++ .../arcade_google_sheets/types.py | 1 + .../arcade_google_sheets/utils.py | 548 ++++++++++ .../google_sheets/evals/eval_google_sheets.py | 169 +++ toolkits/google_sheets/pyproject.toml | 63 ++ toolkits/google_sheets/tests/__init__.py | 0 .../google_sheets/tests/test_sheets_models.py | 84 ++ .../google_sheets/tests/test_sheets_utils.py | 542 ++++++++++ .../google_shopping/.pre-commit-config.yaml | 18 + toolkits/google_shopping/.ruff.toml | 46 + toolkits/google_shopping/LICENSE | 21 + toolkits/google_shopping/Makefile | 55 + .../arcade_google_shopping/__init__.py | 3 + .../arcade_google_shopping/constants.py | 10 + .../arcade_google_shopping/exceptions.py | 25 + .../arcade_google_shopping/google_data.py | 468 +++++++++ .../arcade_google_shopping/tools/__init__.py | 3 + .../tools/google_shopping.py | 66 ++ .../arcade_google_shopping/utils.py | 117 +++ toolkits/google_shopping/pyproject.toml | 59 ++ .../outlook_calendar/.pre-commit-config.yaml | 18 + toolkits/outlook_calendar/.ruff.toml | 47 + toolkits/outlook_calendar/LICENSE | 21 + toolkits/outlook_calendar/Makefile | 55 + .../arcade_outlook_calendar/__init__.py | 7 + .../arcade_outlook_calendar/_utils.py | 225 ++++ .../arcade_outlook_calendar/client.py | 26 + .../arcade_outlook_calendar/constants.py | 138 +++ .../arcade_outlook_calendar/models.py | 288 ++++++ .../arcade_outlook_calendar/tools/__init__.py | 7 + .../tools/create_event.py | 81 ++ .../tools/get_event.py | 29 + .../tools/list_events_in_time_range.py | 59 ++ .../evals/additional_messages.py | 28 + .../evals/eval_create_event.py | 94 ++ .../outlook_calendar/evals/eval_get_event.py | 53 + .../evals/eval_list_events_in_time_range.py | 77 ++ toolkits/outlook_calendar/pyproject.toml | 61 ++ toolkits/outlook_calendar/tests/__init__.py | 0 .../outlook_calendar/tests/test_models.py | 385 +++++++ toolkits/outlook_calendar/tests/test_utils.py | 118 +++ toolkits/outlook_mail/.pre-commit-config.yaml | 18 + toolkits/outlook_mail/.ruff.toml | 47 + toolkits/outlook_mail/LICENSE | 21 + toolkits/outlook_mail/Makefile | 55 + .../arcade_outlook_mail/__init__.py | 24 + .../arcade_outlook_mail/_utils.py | 120 +++ .../arcade_outlook_mail/client.py | 26 + .../arcade_outlook_mail/constants.py | 18 + .../outlook_mail/arcade_outlook_mail/enums.py | 65 ++ .../arcade_outlook_mail/message.py | 218 ++++ .../arcade_outlook_mail/tools/__init__.py | 28 + .../arcade_outlook_mail/tools/read.py | 122 +++ .../arcade_outlook_mail/tools/send.py | 94 ++ .../arcade_outlook_mail/tools/write.py | 115 +++ .../outlook_mail/evals/additional_messages.py | 83 ++ toolkits/outlook_mail/evals/eval_read.py | 210 ++++ toolkits/outlook_mail/evals/eval_send.py | 127 +++ toolkits/outlook_mail/evals/eval_write.py | 104 ++ toolkits/outlook_mail/pyproject.toml | 60 ++ toolkits/outlook_mail/tests/__init__.py | 0 toolkits/outlook_mail/tests/test_message.py | 249 +++++ toolkits/outlook_mail/tests/test_recipient.py | 43 + toolkits/outlook_mail/tests/test_utils.py | 55 + toolkits/walmart/.pre-commit-config.yaml | 18 + toolkits/walmart/.ruff.toml | 46 + toolkits/walmart/LICENSE | 21 + toolkits/walmart/Makefile | 55 + toolkits/walmart/arcade_walmart/__init__.py | 3 + toolkits/walmart/arcade_walmart/enums.py | 21 + .../walmart/arcade_walmart/tools/__init__.py | 3 + .../walmart/arcade_walmart/tools/walmart.py | 95 ++ toolkits/walmart/arcade_walmart/utils.py | 120 +++ toolkits/walmart/pyproject.toml | 59 ++ toolkits/youtube/.pre-commit-config.yaml | 18 + toolkits/youtube/.ruff.toml | 46 + toolkits/youtube/LICENSE | 21 + toolkits/youtube/Makefile | 55 + toolkits/youtube/arcade_youtube/__init__.py | 3 + toolkits/youtube/arcade_youtube/constants.py | 7 + toolkits/youtube/arcade_youtube/exceptions.py | 25 + .../youtube/arcade_youtube/google_data.py | 281 +++++ .../youtube/arcade_youtube/tools/__init__.py | 3 + .../youtube/arcade_youtube/tools/youtube.py | 101 ++ toolkits/youtube/arcade_youtube/utils.py | 169 +++ toolkits/youtube/pyproject.toml | 59 ++ 290 files changed, 22664 insertions(+), 1 deletion(-) create mode 100644 toolkits/e2b/.pre-commit-config.yaml create mode 100644 toolkits/e2b/.ruff.toml create mode 100644 toolkits/e2b/LICENSE create mode 100644 toolkits/e2b/Makefile create mode 100644 toolkits/e2b/arcade_e2b/__init__.py create mode 100644 toolkits/e2b/arcade_e2b/enums.py create mode 100644 toolkits/e2b/arcade_e2b/tools/__init__.py create mode 100644 toolkits/e2b/arcade_e2b/tools/create_chart.py create mode 100644 toolkits/e2b/arcade_e2b/tools/run_code.py create mode 100644 toolkits/e2b/evals/eval_e2b.py create mode 100644 toolkits/e2b/pyproject.toml create mode 100644 toolkits/e2b/tests/__init__.py create mode 100644 toolkits/e2b/tests/test_e2b.py create mode 100644 toolkits/firecrawl/.pre-commit-config.yaml create mode 100644 toolkits/firecrawl/.ruff.toml create mode 100644 toolkits/firecrawl/LICENSE create mode 100644 toolkits/firecrawl/Makefile create mode 100644 toolkits/firecrawl/arcade_firecrawl/__init__.py create mode 100644 toolkits/firecrawl/arcade_firecrawl/enums.py create mode 100644 toolkits/firecrawl/arcade_firecrawl/tools/__init__.py create mode 100644 toolkits/firecrawl/arcade_firecrawl/tools/crawl.py create mode 100644 toolkits/firecrawl/arcade_firecrawl/tools/map.py create mode 100644 toolkits/firecrawl/arcade_firecrawl/tools/scrape.py create mode 100644 toolkits/firecrawl/evals/eval_firecrawl.py create mode 100644 toolkits/firecrawl/pyproject.toml create mode 100644 toolkits/firecrawl/tests/__init__.py create mode 100644 toolkits/firecrawl/tests/test_firecrawl.py create mode 100644 toolkits/gmail/.pre-commit-config.yaml create mode 100644 toolkits/gmail/.ruff.toml create mode 100644 toolkits/gmail/Makefile create mode 100644 toolkits/gmail/arcade_gmail/__init__.py create mode 100644 toolkits/gmail/arcade_gmail/constants.py create mode 100644 toolkits/gmail/arcade_gmail/enums.py create mode 100644 toolkits/gmail/arcade_gmail/exceptions.py create mode 100644 toolkits/gmail/arcade_gmail/tools/__init__.py create mode 100644 toolkits/gmail/arcade_gmail/tools/gmail.py create mode 100644 toolkits/gmail/arcade_gmail/utils.py create mode 100644 toolkits/gmail/evals/eval_google_gmail.py create mode 100644 toolkits/gmail/pyproject.toml create mode 100644 toolkits/gmail/tests/__init__.py create mode 100644 toolkits/gmail/tests/test_gmail.py create mode 100644 toolkits/google_calendar/.pre-commit-config.yaml create mode 100644 toolkits/google_calendar/.ruff.toml create mode 100644 toolkits/google_calendar/Makefile create mode 100644 toolkits/google_calendar/arcade_google_calendar/__init__.py create mode 100644 toolkits/google_calendar/arcade_google_calendar/enums.py create mode 100644 toolkits/google_calendar/arcade_google_calendar/tools/__init__.py create mode 100644 toolkits/google_calendar/arcade_google_calendar/tools/calendar.py create mode 100644 toolkits/google_calendar/arcade_google_calendar/utils.py create mode 100644 toolkits/google_calendar/evals/eval_google_calendar.py create mode 100644 toolkits/google_calendar/pyproject.toml create mode 100644 toolkits/google_calendar/tests/__init__.py create mode 100644 toolkits/google_calendar/tests/test_calendar.py create mode 100644 toolkits/google_contacts/.pre-commit-config.yaml create mode 100644 toolkits/google_contacts/.ruff.toml create mode 100644 toolkits/google_contacts/Makefile create mode 100644 toolkits/google_contacts/arcade_google_contacts/__init__.py create mode 100644 toolkits/google_contacts/arcade_google_contacts/constants.py create mode 100644 toolkits/google_contacts/arcade_google_contacts/tools/__init__.py create mode 100644 toolkits/google_contacts/arcade_google_contacts/tools/contacts.py create mode 100644 toolkits/google_contacts/arcade_google_contacts/utils.py create mode 100644 toolkits/google_contacts/evals/eval_google_contacts.py create mode 100644 toolkits/google_contacts/pyproject.toml create mode 100644 toolkits/google_contacts/tests/__init__.py create mode 100644 toolkits/google_contacts/tests/test_contacts.py create mode 100644 toolkits/google_docs/.pre-commit-config.yaml create mode 100644 toolkits/google_docs/.ruff.toml create mode 100644 toolkits/google_docs/Makefile create mode 100644 toolkits/google_docs/arcade_google_docs/__init__.py create mode 100644 toolkits/google_docs/arcade_google_docs/decorators.py create mode 100644 toolkits/google_docs/arcade_google_docs/doc_to_html.py create mode 100644 toolkits/google_docs/arcade_google_docs/doc_to_markdown.py create mode 100644 toolkits/google_docs/arcade_google_docs/enum.py create mode 100644 toolkits/google_docs/arcade_google_docs/file_picker.py create mode 100644 toolkits/google_docs/arcade_google_docs/templates.py create mode 100644 toolkits/google_docs/arcade_google_docs/tools/__init__.py create mode 100644 toolkits/google_docs/arcade_google_docs/tools/create.py create mode 100644 toolkits/google_docs/arcade_google_docs/tools/get.py create mode 100644 toolkits/google_docs/arcade_google_docs/tools/search.py create mode 100644 toolkits/google_docs/arcade_google_docs/tools/update.py create mode 100644 toolkits/google_docs/arcade_google_docs/utils.py create mode 100644 toolkits/google_docs/conftest.py create mode 100644 toolkits/google_docs/evals/eval_google_docs.py create mode 100644 toolkits/google_docs/pyproject.toml create mode 100644 toolkits/google_docs/tests/__init__.py create mode 100644 toolkits/google_docs/tests/test_doc_to_markdown.py create mode 100644 toolkits/google_docs/tests/test_google_docs.py create mode 100644 toolkits/google_docs/tests/test_search.py create mode 100644 toolkits/google_drive/.pre-commit-config.yaml create mode 100644 toolkits/google_drive/.ruff.toml create mode 100644 toolkits/google_drive/Makefile create mode 100644 toolkits/google_drive/arcade_google_drive/__init__.py create mode 100644 toolkits/google_drive/arcade_google_drive/enums.py create mode 100644 toolkits/google_drive/arcade_google_drive/templates.py create mode 100644 toolkits/google_drive/arcade_google_drive/tools/__init__.py create mode 100644 toolkits/google_drive/arcade_google_drive/tools/drive.py create mode 100644 toolkits/google_drive/arcade_google_drive/utils.py create mode 100644 toolkits/google_drive/conftest.py create mode 100644 toolkits/google_drive/evals/eval_google_drive.py create mode 100644 toolkits/google_drive/evals/eval_tools_understand_filepicker.py create mode 100644 toolkits/google_drive/pyproject.toml create mode 100644 toolkits/google_drive/tests/__init__.py create mode 100644 toolkits/google_drive/tests/test_drive.py create mode 100644 toolkits/google_finance/.pre-commit-config.yaml create mode 100644 toolkits/google_finance/.ruff.toml create mode 100644 toolkits/google_finance/LICENSE create mode 100644 toolkits/google_finance/Makefile create mode 100644 toolkits/google_finance/arcade_google_finance/__init__.py create mode 100644 toolkits/google_finance/arcade_google_finance/enums.py create mode 100644 toolkits/google_finance/arcade_google_finance/tools/__init__.py create mode 100644 toolkits/google_finance/arcade_google_finance/tools/google_finance.py create mode 100644 toolkits/google_finance/arcade_google_finance/utils.py create mode 100644 toolkits/google_finance/pyproject.toml create mode 100644 toolkits/google_flights/.pre-commit-config.yaml create mode 100644 toolkits/google_flights/.ruff.toml create mode 100644 toolkits/google_flights/LICENSE create mode 100644 toolkits/google_flights/Makefile create mode 100644 toolkits/google_flights/arcade_google_flights/__init__.py create mode 100644 toolkits/google_flights/arcade_google_flights/enums.py create mode 100644 toolkits/google_flights/arcade_google_flights/tools/__init__.py create mode 100644 toolkits/google_flights/arcade_google_flights/tools/google_flights.py create mode 100644 toolkits/google_flights/arcade_google_flights/utils.py create mode 100644 toolkits/google_flights/pyproject.toml create mode 100644 toolkits/google_hotels/.pre-commit-config.yaml create mode 100644 toolkits/google_hotels/.ruff.toml create mode 100644 toolkits/google_hotels/LICENSE create mode 100644 toolkits/google_hotels/Makefile create mode 100644 toolkits/google_hotels/arcade_google_hotels/__init__.py create mode 100644 toolkits/google_hotels/arcade_google_hotels/enums.py create mode 100644 toolkits/google_hotels/arcade_google_hotels/tools/__init__.py create mode 100644 toolkits/google_hotels/arcade_google_hotels/tools/google_hotels.py create mode 100644 toolkits/google_hotels/arcade_google_hotels/utils.py create mode 100644 toolkits/google_hotels/pyproject.toml create mode 100644 toolkits/google_jobs/.pre-commit-config.yaml create mode 100644 toolkits/google_jobs/.ruff.toml create mode 100644 toolkits/google_jobs/LICENSE create mode 100644 toolkits/google_jobs/Makefile create mode 100644 toolkits/google_jobs/arcade_google_jobs/__init__.py create mode 100644 toolkits/google_jobs/arcade_google_jobs/constants.py create mode 100644 toolkits/google_jobs/arcade_google_jobs/enums.py create mode 100644 toolkits/google_jobs/arcade_google_jobs/exceptions.py create mode 100644 toolkits/google_jobs/arcade_google_jobs/google_data.py create mode 100644 toolkits/google_jobs/arcade_google_jobs/tools/__init__.py create mode 100644 toolkits/google_jobs/arcade_google_jobs/tools/google_jobs.py create mode 100644 toolkits/google_jobs/arcade_google_jobs/utils.py create mode 100644 toolkits/google_jobs/evals/eval_google_jobs.py create mode 100644 toolkits/google_jobs/pyproject.toml create mode 100644 toolkits/google_jobs/tests/__init__.py create mode 100644 toolkits/google_jobs/tests/test_google_jobs.py create mode 100644 toolkits/google_maps/.pre-commit-config.yaml create mode 100644 toolkits/google_maps/.ruff.toml create mode 100644 toolkits/google_maps/LICENSE create mode 100644 toolkits/google_maps/Makefile create mode 100644 toolkits/google_maps/arcade_google_maps/__init__.py create mode 100644 toolkits/google_maps/arcade_google_maps/constants.py create mode 100644 toolkits/google_maps/arcade_google_maps/enums.py create mode 100644 toolkits/google_maps/arcade_google_maps/exceptions.py create mode 100644 toolkits/google_maps/arcade_google_maps/google_data.py create mode 100644 toolkits/google_maps/arcade_google_maps/tools/__init__.py create mode 100644 toolkits/google_maps/arcade_google_maps/tools/google_maps.py create mode 100644 toolkits/google_maps/arcade_google_maps/utils.py create mode 100644 toolkits/google_maps/evals/eval_google_maps_directions.py create mode 100644 toolkits/google_maps/pyproject.toml create mode 100644 toolkits/google_maps/tests/__init__.py create mode 100644 toolkits/google_maps/tests/test_google_maps_directions.py create mode 100644 toolkits/google_news/.pre-commit-config.yaml create mode 100644 toolkits/google_news/.ruff.toml create mode 100644 toolkits/google_news/LICENSE create mode 100644 toolkits/google_news/Makefile create mode 100644 toolkits/google_news/arcade_google_news/__init__.py create mode 100644 toolkits/google_news/arcade_google_news/constants.py create mode 100644 toolkits/google_news/arcade_google_news/exceptions.py create mode 100644 toolkits/google_news/arcade_google_news/google_data.py create mode 100644 toolkits/google_news/arcade_google_news/tools/__init__.py create mode 100644 toolkits/google_news/arcade_google_news/tools/google_news.py create mode 100644 toolkits/google_news/arcade_google_news/utils.py create mode 100644 toolkits/google_news/pyproject.toml create mode 100644 toolkits/google_search/.pre-commit-config.yaml create mode 100644 toolkits/google_search/.ruff.toml create mode 100644 toolkits/google_search/LICENSE create mode 100644 toolkits/google_search/Makefile create mode 100644 toolkits/google_search/arcade_google_search/__init__.py create mode 100644 toolkits/google_search/arcade_google_search/tools/__init__.py create mode 100644 toolkits/google_search/arcade_google_search/tools/google_search.py create mode 100644 toolkits/google_search/arcade_google_search/utils.py create mode 100644 toolkits/google_search/evals/eval_google_search.py create mode 100644 toolkits/google_search/pyproject.toml create mode 100644 toolkits/google_search/tests/__init__.py create mode 100644 toolkits/google_search/tests/test_google_search.py create mode 100644 toolkits/google_search/tests/test_utils.py create mode 100644 toolkits/google_sheets/.pre-commit-config.yaml create mode 100644 toolkits/google_sheets/.ruff.toml create mode 100644 toolkits/google_sheets/LICENSE create mode 100644 toolkits/google_sheets/Makefile create mode 100644 toolkits/google_sheets/arcade_google_sheets/__init__.py create mode 100644 toolkits/google_sheets/arcade_google_sheets/constants.py create mode 100644 toolkits/google_sheets/arcade_google_sheets/decorators.py create mode 100644 toolkits/google_sheets/arcade_google_sheets/enums.py create mode 100644 toolkits/google_sheets/arcade_google_sheets/file_picker.py create mode 100644 toolkits/google_sheets/arcade_google_sheets/models.py create mode 100644 toolkits/google_sheets/arcade_google_sheets/tools/__init__.py create mode 100644 toolkits/google_sheets/arcade_google_sheets/tools/read.py create mode 100644 toolkits/google_sheets/arcade_google_sheets/tools/write.py create mode 100644 toolkits/google_sheets/arcade_google_sheets/types.py create mode 100644 toolkits/google_sheets/arcade_google_sheets/utils.py create mode 100644 toolkits/google_sheets/evals/eval_google_sheets.py create mode 100644 toolkits/google_sheets/pyproject.toml create mode 100644 toolkits/google_sheets/tests/__init__.py create mode 100644 toolkits/google_sheets/tests/test_sheets_models.py create mode 100644 toolkits/google_sheets/tests/test_sheets_utils.py create mode 100644 toolkits/google_shopping/.pre-commit-config.yaml create mode 100644 toolkits/google_shopping/.ruff.toml create mode 100644 toolkits/google_shopping/LICENSE create mode 100644 toolkits/google_shopping/Makefile create mode 100644 toolkits/google_shopping/arcade_google_shopping/__init__.py create mode 100644 toolkits/google_shopping/arcade_google_shopping/constants.py create mode 100644 toolkits/google_shopping/arcade_google_shopping/exceptions.py create mode 100644 toolkits/google_shopping/arcade_google_shopping/google_data.py create mode 100644 toolkits/google_shopping/arcade_google_shopping/tools/__init__.py create mode 100644 toolkits/google_shopping/arcade_google_shopping/tools/google_shopping.py create mode 100644 toolkits/google_shopping/arcade_google_shopping/utils.py create mode 100644 toolkits/google_shopping/pyproject.toml create mode 100644 toolkits/outlook_calendar/.pre-commit-config.yaml create mode 100644 toolkits/outlook_calendar/.ruff.toml create mode 100644 toolkits/outlook_calendar/LICENSE create mode 100644 toolkits/outlook_calendar/Makefile create mode 100644 toolkits/outlook_calendar/arcade_outlook_calendar/__init__.py create mode 100644 toolkits/outlook_calendar/arcade_outlook_calendar/_utils.py create mode 100644 toolkits/outlook_calendar/arcade_outlook_calendar/client.py create mode 100644 toolkits/outlook_calendar/arcade_outlook_calendar/constants.py create mode 100644 toolkits/outlook_calendar/arcade_outlook_calendar/models.py create mode 100644 toolkits/outlook_calendar/arcade_outlook_calendar/tools/__init__.py create mode 100644 toolkits/outlook_calendar/arcade_outlook_calendar/tools/create_event.py create mode 100644 toolkits/outlook_calendar/arcade_outlook_calendar/tools/get_event.py create mode 100644 toolkits/outlook_calendar/arcade_outlook_calendar/tools/list_events_in_time_range.py create mode 100644 toolkits/outlook_calendar/evals/additional_messages.py create mode 100644 toolkits/outlook_calendar/evals/eval_create_event.py create mode 100644 toolkits/outlook_calendar/evals/eval_get_event.py create mode 100644 toolkits/outlook_calendar/evals/eval_list_events_in_time_range.py create mode 100644 toolkits/outlook_calendar/pyproject.toml create mode 100644 toolkits/outlook_calendar/tests/__init__.py create mode 100644 toolkits/outlook_calendar/tests/test_models.py create mode 100644 toolkits/outlook_calendar/tests/test_utils.py create mode 100644 toolkits/outlook_mail/.pre-commit-config.yaml create mode 100644 toolkits/outlook_mail/.ruff.toml create mode 100644 toolkits/outlook_mail/LICENSE create mode 100644 toolkits/outlook_mail/Makefile create mode 100644 toolkits/outlook_mail/arcade_outlook_mail/__init__.py create mode 100644 toolkits/outlook_mail/arcade_outlook_mail/_utils.py create mode 100644 toolkits/outlook_mail/arcade_outlook_mail/client.py create mode 100644 toolkits/outlook_mail/arcade_outlook_mail/constants.py create mode 100644 toolkits/outlook_mail/arcade_outlook_mail/enums.py create mode 100644 toolkits/outlook_mail/arcade_outlook_mail/message.py create mode 100644 toolkits/outlook_mail/arcade_outlook_mail/tools/__init__.py create mode 100644 toolkits/outlook_mail/arcade_outlook_mail/tools/read.py create mode 100644 toolkits/outlook_mail/arcade_outlook_mail/tools/send.py create mode 100644 toolkits/outlook_mail/arcade_outlook_mail/tools/write.py create mode 100644 toolkits/outlook_mail/evals/additional_messages.py create mode 100644 toolkits/outlook_mail/evals/eval_read.py create mode 100644 toolkits/outlook_mail/evals/eval_send.py create mode 100644 toolkits/outlook_mail/evals/eval_write.py create mode 100644 toolkits/outlook_mail/pyproject.toml create mode 100644 toolkits/outlook_mail/tests/__init__.py create mode 100644 toolkits/outlook_mail/tests/test_message.py create mode 100644 toolkits/outlook_mail/tests/test_recipient.py create mode 100644 toolkits/outlook_mail/tests/test_utils.py create mode 100644 toolkits/walmart/.pre-commit-config.yaml create mode 100644 toolkits/walmart/.ruff.toml create mode 100644 toolkits/walmart/LICENSE create mode 100644 toolkits/walmart/Makefile create mode 100644 toolkits/walmart/arcade_walmart/__init__.py create mode 100644 toolkits/walmart/arcade_walmart/enums.py create mode 100644 toolkits/walmart/arcade_walmart/tools/__init__.py create mode 100644 toolkits/walmart/arcade_walmart/tools/walmart.py create mode 100644 toolkits/walmart/arcade_walmart/utils.py create mode 100644 toolkits/walmart/pyproject.toml create mode 100644 toolkits/youtube/.pre-commit-config.yaml create mode 100644 toolkits/youtube/.ruff.toml create mode 100644 toolkits/youtube/LICENSE create mode 100644 toolkits/youtube/Makefile create mode 100644 toolkits/youtube/arcade_youtube/__init__.py create mode 100644 toolkits/youtube/arcade_youtube/constants.py create mode 100644 toolkits/youtube/arcade_youtube/exceptions.py create mode 100644 toolkits/youtube/arcade_youtube/google_data.py create mode 100644 toolkits/youtube/arcade_youtube/tools/__init__.py create mode 100644 toolkits/youtube/arcade_youtube/tools/youtube.py create mode 100644 toolkits/youtube/arcade_youtube/utils.py create mode 100644 toolkits/youtube/pyproject.toml diff --git a/.github/workflows/test-toolkits.yml b/.github/workflows/test-toolkits.yml index f991452b..8f085092 100644 --- a/.github/workflows/test-toolkits.yml +++ b/.github/workflows/test-toolkits.yml @@ -50,4 +50,13 @@ jobs: - name: Test toolkit working-directory: toolkits/${{ matrix.toolkit }} - run: uv run --active pytest -W ignore -v --cov=arcade_${{ matrix.toolkit }} --cov-report=xml + run: | + # Run pytest and capture exit code + uv run --active pytest -W ignore -v --cov=arcade_${{ matrix.toolkit }} --cov-report=xml || EXIT_CODE=$? + + if [ "${EXIT_CODE:-0}" -eq 5 ]; then + echo "No tests found for toolkit ${{ matrix.toolkit }}, skipping..." + exit 0 + elif [ "${EXIT_CODE:-0}" -ne 0 ]; then + exit ${EXIT_CODE} + fi diff --git a/docker/toolkits.txt b/docker/toolkits.txt index b3f6bfd7..38dcb149 100644 --- a/docker/toolkits.txt +++ b/docker/toolkits.txt @@ -19,3 +19,23 @@ arcade-stripe arcade-web arcade-x arcade-zoom +arcade-e2b +arcade-firecrawl +arcade-gmail +arcade-google-calendar +arcade-google-contacts +arcade-google-docs +arcade-google-drive +arcade-google-finance +arcade-google-flights +arcade-google-hotels +arcade-google-jobs +arcade-google-maps +arcade-google-news +arcade-google-search +arcade-google-sheets +arcade-google-shopping +arcade-outlook-calendar +arcade-outlook-mail +arcade-walmart +arcade-youtube diff --git a/toolkits/e2b/.pre-commit-config.yaml b/toolkits/e2b/.pre-commit-config.yaml new file mode 100644 index 00000000..fd19ccd8 --- /dev/null +++ b/toolkits/e2b/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +files: ^.*/e2b/.* +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.4.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/toolkits/e2b/.ruff.toml b/toolkits/e2b/.ruff.toml new file mode 100644 index 00000000..19364180 --- /dev/null +++ b/toolkits/e2b/.ruff.toml @@ -0,0 +1,47 @@ +target-version = "py310" +line-length = 100 +fix = true + +[lint] +select = [ + # flake8-2020 + "YTT", + # flake8-bandit + "S", + # flake8-bugbear + "B", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-simplify + "SIM", + # isort + "I", + # mccabe + "C90", + # pycodestyle + "E", "W", + # pyflakes + "F", + # pygrep-hooks + "PGH", + # pyupgrade + "UP", + # ruff + "RUF", + # tryceratops + "TRY", +] + +[lint.per-file-ignores] +"*" = ["TRY003", "B904"] +"**/tests/*" = ["S101", "E501"] +"**/evals/*" = ["S101", "E501"] + + +[format] +preview = true +skip-magic-trailing-comma = false diff --git a/toolkits/e2b/LICENSE b/toolkits/e2b/LICENSE new file mode 100644 index 00000000..dfbb8b76 --- /dev/null +++ b/toolkits/e2b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025, Arcade AI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/toolkits/e2b/Makefile b/toolkits/e2b/Makefile new file mode 100644 index 00000000..0a8969be --- /dev/null +++ b/toolkits/e2b/Makefile @@ -0,0 +1,55 @@ +.PHONY: help + +help: + @echo "🛠️ github Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the uv environment and install all packages with dependencies + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras --no-sources + @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: install-local +install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras + @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + uv build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + @uv run --no-sources coverage report + @echo "Generating coverage report" + @uv run --no-sources coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file by a patch version + @echo "🚀 Bumping version in pyproject.toml" + uv version --no-sources --bump patch + +.PHONY: check +check: ## Run code quality tools. + @if [ -f .pre-commit-config.yaml ]; then\ + echo "🚀 Linting code: Running pre-commit";\ + uv run --no-sources pre-commit run -a;\ + fi + @echo "🚀 Static type checking: Running mypy" + @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/e2b/arcade_e2b/__init__.py b/toolkits/e2b/arcade_e2b/__init__.py new file mode 100644 index 00000000..4a13c163 --- /dev/null +++ b/toolkits/e2b/arcade_e2b/__init__.py @@ -0,0 +1,3 @@ +from arcade_e2b.tools import create_static_matplotlib_chart, run_code + +__all__ = ["create_static_matplotlib_chart", "run_code"] diff --git a/toolkits/e2b/arcade_e2b/enums.py b/toolkits/e2b/arcade_e2b/enums.py new file mode 100644 index 00000000..34ff05bc --- /dev/null +++ b/toolkits/e2b/arcade_e2b/enums.py @@ -0,0 +1,10 @@ +from enum import Enum + + +# Models and enums for the e2b code interpreter +class E2BSupportedLanguage(str, Enum): + PYTHON = "python" + JAVASCRIPT = "js" + R = "r" + JAVA = "java" + BASH = "bash" diff --git a/toolkits/e2b/arcade_e2b/tools/__init__.py b/toolkits/e2b/arcade_e2b/tools/__init__.py new file mode 100644 index 00000000..2ceb050e --- /dev/null +++ b/toolkits/e2b/arcade_e2b/tools/__init__.py @@ -0,0 +1,4 @@ +from arcade_e2b.tools.create_chart import create_static_matplotlib_chart +from arcade_e2b.tools.run_code import run_code + +__all__ = ["create_static_matplotlib_chart", "run_code"] diff --git a/toolkits/e2b/arcade_e2b/tools/create_chart.py b/toolkits/e2b/arcade_e2b/tools/create_chart.py new file mode 100644 index 00000000..003d8fc8 --- /dev/null +++ b/toolkits/e2b/arcade_e2b/tools/create_chart.py @@ -0,0 +1,31 @@ +from typing import Annotated + +from arcade_tdk import ToolContext, tool +from e2b_code_interpreter import Sandbox + +# See https://e2b.dev/docs to learn more about E2B + + +# Note: Not recommended to use tool_choice='generate' with this tool +# since it contains base64 encoded image. +@tool(requires_secrets=["E2B_API_KEY"]) +def create_static_matplotlib_chart( + context: ToolContext, + code: Annotated[str, "The Python code to run"], +) -> Annotated[dict, "A dictionary with the following keys: base64_image, logs, error"]: + """ + Run the provided Python code to generate a static matplotlib chart. + The resulting chart is returned as a base64 encoded image. + """ + api_key = context.get_secret("E2B_API_KEY") + + with Sandbox(api_key=api_key) as sbx: + execution = sbx.run_code(code=code) + + result = { + "base64_image": execution.results[0].png if execution.results else None, + "logs": execution.logs.to_json(), + "error": execution.error.to_json() if execution.error else None, + } + + return result diff --git a/toolkits/e2b/arcade_e2b/tools/run_code.py b/toolkits/e2b/arcade_e2b/tools/run_code.py new file mode 100644 index 00000000..83ac729f --- /dev/null +++ b/toolkits/e2b/arcade_e2b/tools/run_code.py @@ -0,0 +1,27 @@ +from typing import Annotated + +from arcade_tdk import ToolContext, tool +from e2b_code_interpreter import Sandbox + +from arcade_e2b.enums import E2BSupportedLanguage + +# See https://e2b.dev/docs to learn more about E2B + + +@tool(requires_secrets=["E2B_API_KEY"]) +def run_code( + context: ToolContext, + code: Annotated[str, "The code to run"], + language: Annotated[ + E2BSupportedLanguage, "The language of the code" + ] = E2BSupportedLanguage.PYTHON, +) -> Annotated[str, "The sandbox execution as a JSON string"]: + """ + Run code in a sandbox and return the output. + """ + api_key = context.get_secret("E2B_API_KEY") + + with Sandbox(api_key=api_key) as sbx: + execution = sbx.run_code(code=code, language=language) + + return str(execution.to_json()) diff --git a/toolkits/e2b/evals/eval_e2b.py b/toolkits/e2b/evals/eval_e2b.py new file mode 100644 index 00000000..1be1028d --- /dev/null +++ b/toolkits/e2b/evals/eval_e2b.py @@ -0,0 +1,120 @@ +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + SimilarityCritic, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_e2b +from arcade_e2b.enums import E2BSupportedLanguage +from arcade_e2b.tools.create_chart import create_static_matplotlib_chart +from arcade_e2b.tools.run_code import run_code + +merge_sort_code = """ +def merge_sort(arr): + if len(arr) <= 1: + return arr + + mid = len(arr) // 2 + left = merge_sort(arr[:mid]) + right = merge_sort(arr[mid:]) + + return merge(left, right) + +def merge(left, right): + result = [] + i, j = 0, 0 + + while i < len(left) and j < len(right): + if left[i] < right[j]: + result.append(left[i]) + i += 1 + else: + result.append(right[j]) + j += 1 + + result.extend(left[i:]) + result.extend(right[j:]) + + return result + +sample_list = ["banana", "apple", "cherry", "date", "elderberry"] + +sorted_list = merge_sort(sample_list) +print("Sorted list:", sorted_list) +""" + +matplotlib_chart_code = """ +import matplotlib.pyplot as plt + +labels = ['Apples', 'Bananas', 'Cherries', 'Dates'] +sizes = [30, 25, 20, 25] +colors = ['red', 'yellow', 'purple', 'brown'] + +plt.pie(sizes, labels=labels, colors=colors, autopct='%1.1f%%', startangle=90) + +plt.axis('equal') + +plt.title('Fruit Distribution') + +plt.savefig('fruit_pie_chart.png') +""" + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.85, + warn_threshold=0.95, +) + + +catalog = ToolCatalog() +catalog.add_module(arcade_e2b) + + +@tool_eval() +def e2b_eval_suite(): + suite = EvalSuite( + name="E2B Tools Evaluation", + system_message="You are an AI assistant with access to E2B tools. Use them to help the user with their tasks.", + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="Run code", + user_message=f"Can you please run my merge sort algo?\n\n{merge_sort_code}", + expected_tool_calls=[ + ExpectedToolCall( + func=run_code, + args={ + "code": merge_sort_code, + "language": E2BSupportedLanguage.PYTHON, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="code", weight=0.8), + BinaryCritic(critic_field="language", weight=0.2), + ], + ) + + suite.add_case( + name="Create static matplotlib chart", + user_message=f"Run this code:\n\n{matplotlib_chart_code}", + expected_tool_calls=[ + ExpectedToolCall( + func=create_static_matplotlib_chart, + args={ + "code": matplotlib_chart_code, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="code", weight=1.0), + ], + ) + + return suite diff --git a/toolkits/e2b/pyproject.toml b/toolkits/e2b/pyproject.toml new file mode 100644 index 00000000..e672d343 --- /dev/null +++ b/toolkits/e2b/pyproject.toml @@ -0,0 +1,57 @@ +[build-system] +requires = [ "hatchling",] +build-backend = "hatchling.build" + +[project] +name = "arcade_e2b" +version = "2.0.0" +description = "Arcade.dev LLM tools for running code in a sandbox using E2B" +requires-python = ">=3.10" +dependencies = [ + "arcade-tdk>=2.0.0,<3.0.0", + "e2b-code-interpreter>=1.0.1,<2.0.0", +] +[[project.authors]] +name = "Arcade" +email = "dev@arcade.dev" + +[project.optional-dependencies] +dev = [ + "arcade-ai[evals]>=2.0.0,<3.0.0", + "arcade-serve>=2.0.0,<3.0.0", + "pytest>=8.3.0,<8.4.0", + "pytest-cov>=4.0.0,<4.1.0", + "pytest-asyncio>=0.24.0,<0.25.0", + "pytest-mock>=3.11.1,<3.12.0", + "mypy>=1.5.1,<1.6.0", + "pre-commit>=3.4.0,<3.5.0", + "tox>=4.11.1,<4.12.0", + "ruff>=0.7.4,<0.8.0", +] + +# Use local path sources for arcade libs when working locally +[tool.uv.sources] +arcade-ai = {path = "../../", editable = true} +arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } +arcade-serve = { path = "../../libs/arcade-serve/", editable = true } + +[tool.mypy] +files = [ "arcade_e2b/**/*.py",] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes = "True" +ignore_missing_imports = "True" + +[tool.pytest.ini_options] +testpaths = [ "tests",] + +[tool.coverage.report] +skip_empty = true + +[tool.hatch.build.targets.wheel] +packages = [ "arcade_e2b",] diff --git a/toolkits/e2b/tests/__init__.py b/toolkits/e2b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/toolkits/e2b/tests/test_e2b.py b/toolkits/e2b/tests/test_e2b.py new file mode 100644 index 00000000..3d8ee7d6 --- /dev/null +++ b/toolkits/e2b/tests/test_e2b.py @@ -0,0 +1,74 @@ +from unittest.mock import MagicMock, patch + +import pytest +from arcade_tdk import ToolContext, ToolSecretItem +from arcade_tdk.errors import ToolExecutionError + +from arcade_e2b.enums import E2BSupportedLanguage +from arcade_e2b.tools.create_chart import create_static_matplotlib_chart +from arcade_e2b.tools.run_code import run_code + + +@pytest.fixture +def mock_run_code_sandbox(): + with patch("arcade_e2b.tools.run_code.Sandbox") as mock: + yield mock.return_value.__enter__.return_value + + +@pytest.fixture +def mock_create_chart_sandbox(): + with patch("arcade_e2b.tools.create_chart.Sandbox") as mock: + yield mock.return_value.__enter__.return_value + + +@pytest.fixture +def mock_context(): + return ToolContext(secrets=[ToolSecretItem(key="e2b_api_key", value="fake_api_key")]) + + +def test_run_code_success(mock_run_code_sandbox, mock_context): + mock_execution = MagicMock() + mock_execution.to_json.return_value = '{"result": "success"}' + mock_run_code_sandbox.run_code.return_value = mock_execution + + result = run_code(mock_context, "print('Hello, World!')", E2BSupportedLanguage.PYTHON) + assert result == '{"result": "success"}' + + +def test_run_code_error(mock_run_code_sandbox, mock_context): + mock_execution = MagicMock() + mock_execution.to_json.side_effect = ToolExecutionError("Execution failed") + mock_run_code_sandbox.run_code.return_value = mock_execution + + with pytest.raises(ToolExecutionError, match="Execution failed"): + run_code(mock_context, "print('Hello, World!')", E2BSupportedLanguage.PYTHON) + + +def test_create_static_matplotlib_chart_success(mock_create_chart_sandbox, mock_context): + mock_execution = MagicMock() + mock_execution.results = [MagicMock(png="base64encodedimage")] + mock_execution.logs.to_json.return_value = '{"logs": "log data"}' + mock_execution.error = None + mock_create_chart_sandbox.run_code.return_value = mock_execution + + result = create_static_matplotlib_chart(mock_context, "import matplotlib.pyplot as plt") + assert result == { + "base64_image": "base64encodedimage", + "logs": '{"logs": "log data"}', + "error": None, + } + + +def test_create_static_matplotlib_chart_error(mock_create_chart_sandbox, mock_context): + mock_execution = MagicMock() + mock_execution.results = [] + mock_execution.logs.to_json.return_value = '{"logs": "log data"}' + mock_execution.error.to_json.return_value = '{"error": "some error"}' + mock_create_chart_sandbox.run_code.return_value = mock_execution + + result = create_static_matplotlib_chart(mock_context, "import matplotlib.pyplot as plt") + assert result == { + "base64_image": None, + "logs": '{"logs": "log data"}', + "error": '{"error": "some error"}', + } diff --git a/toolkits/firecrawl/.pre-commit-config.yaml b/toolkits/firecrawl/.pre-commit-config.yaml new file mode 100644 index 00000000..031d52b9 --- /dev/null +++ b/toolkits/firecrawl/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +files: ^.*/firecrawl/.* +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.4.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/toolkits/firecrawl/.ruff.toml b/toolkits/firecrawl/.ruff.toml new file mode 100644 index 00000000..19364180 --- /dev/null +++ b/toolkits/firecrawl/.ruff.toml @@ -0,0 +1,47 @@ +target-version = "py310" +line-length = 100 +fix = true + +[lint] +select = [ + # flake8-2020 + "YTT", + # flake8-bandit + "S", + # flake8-bugbear + "B", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-simplify + "SIM", + # isort + "I", + # mccabe + "C90", + # pycodestyle + "E", "W", + # pyflakes + "F", + # pygrep-hooks + "PGH", + # pyupgrade + "UP", + # ruff + "RUF", + # tryceratops + "TRY", +] + +[lint.per-file-ignores] +"*" = ["TRY003", "B904"] +"**/tests/*" = ["S101", "E501"] +"**/evals/*" = ["S101", "E501"] + + +[format] +preview = true +skip-magic-trailing-comma = false diff --git a/toolkits/firecrawl/LICENSE b/toolkits/firecrawl/LICENSE new file mode 100644 index 00000000..dfbb8b76 --- /dev/null +++ b/toolkits/firecrawl/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025, Arcade AI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/toolkits/firecrawl/Makefile b/toolkits/firecrawl/Makefile new file mode 100644 index 00000000..0a8969be --- /dev/null +++ b/toolkits/firecrawl/Makefile @@ -0,0 +1,55 @@ +.PHONY: help + +help: + @echo "🛠️ github Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the uv environment and install all packages with dependencies + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras --no-sources + @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: install-local +install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras + @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + uv build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + @uv run --no-sources coverage report + @echo "Generating coverage report" + @uv run --no-sources coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file by a patch version + @echo "🚀 Bumping version in pyproject.toml" + uv version --no-sources --bump patch + +.PHONY: check +check: ## Run code quality tools. + @if [ -f .pre-commit-config.yaml ]; then\ + echo "🚀 Linting code: Running pre-commit";\ + uv run --no-sources pre-commit run -a;\ + fi + @echo "🚀 Static type checking: Running mypy" + @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/firecrawl/arcade_firecrawl/__init__.py b/toolkits/firecrawl/arcade_firecrawl/__init__.py new file mode 100644 index 00000000..f0c5a745 --- /dev/null +++ b/toolkits/firecrawl/arcade_firecrawl/__init__.py @@ -0,0 +1,17 @@ +from arcade_firecrawl.tools import ( + cancel_crawl, + crawl_website, + get_crawl_data, + get_crawl_status, + map_website, + scrape_url, +) + +__all__ = [ + "cancel_crawl", + "crawl_website", + "get_crawl_data", + "get_crawl_status", + "map_website", + "scrape_url", +] diff --git a/toolkits/firecrawl/arcade_firecrawl/enums.py b/toolkits/firecrawl/arcade_firecrawl/enums.py new file mode 100644 index 00000000..2e823940 --- /dev/null +++ b/toolkits/firecrawl/arcade_firecrawl/enums.py @@ -0,0 +1,11 @@ +from enum import Enum + + +# Models and enums for firecrawl web tools +class Formats(str, Enum): + MARKDOWN = "markdown" + HTML = "html" + RAW_HTML = "rawHtml" + LINKS = "links" + SCREENSHOT = "screenshot" + SCREENSHOT_AT_FULL_PAGE = "screenshot@fullPage" diff --git a/toolkits/firecrawl/arcade_firecrawl/tools/__init__.py b/toolkits/firecrawl/arcade_firecrawl/tools/__init__.py new file mode 100644 index 00000000..dd3c5cbe --- /dev/null +++ b/toolkits/firecrawl/arcade_firecrawl/tools/__init__.py @@ -0,0 +1,17 @@ +from arcade_firecrawl.tools.crawl import ( + cancel_crawl, + crawl_website, + get_crawl_data, + get_crawl_status, +) +from arcade_firecrawl.tools.map import map_website +from arcade_firecrawl.tools.scrape import scrape_url + +__all__ = [ + "cancel_crawl", + "crawl_website", + "get_crawl_data", + "get_crawl_status", + "map_website", + "scrape_url", +] diff --git a/toolkits/firecrawl/arcade_firecrawl/tools/crawl.py b/toolkits/firecrawl/arcade_firecrawl/tools/crawl.py new file mode 100644 index 00000000..e7812e87 --- /dev/null +++ b/toolkits/firecrawl/arcade_firecrawl/tools/crawl.py @@ -0,0 +1,121 @@ +from typing import Annotated, Any + +from arcade_tdk import ToolContext, tool +from firecrawl import FirecrawlApp + + +# TODO: Support scrapeOptions. +@tool(requires_secrets=["FIRECRAWL_API_KEY"]) +async def crawl_website( + context: ToolContext, + url: Annotated[str, "URL to crawl"], + exclude_paths: Annotated[list[str] | None, "URL patterns to exclude from the crawl"] = None, + include_paths: Annotated[list[str] | None, "URL patterns to include in the crawl"] = None, + max_depth: Annotated[int, "Maximum depth to crawl relative to the entered URL"] = 2, + ignore_sitemap: Annotated[bool, "Ignore the website sitemap when crawling"] = True, + limit: Annotated[int, "Limit the number of pages to crawl"] = 10, + allow_backward_links: Annotated[ + bool, + "Enable navigation to previously linked pages and enable crawling " + "sublinks that are not children of the 'url' input parameter.", + ] = False, + allow_external_links: Annotated[bool, "Allow following links to external websites"] = False, + webhook: Annotated[ + str | None, + "The URL to send a POST request to when the crawl is started, updated and completed.", + ] = None, + async_crawl: Annotated[bool, "Run the crawl asynchronously"] = True, +) -> Annotated[dict[str, Any], "Crawl status and data"]: + """ + Crawl a website using Firecrawl. If the crawl is asynchronous, then returns the crawl ID. + If the crawl is synchronous, then returns the crawl data. + """ + + api_key = context.get_secret("FIRECRAWL_API_KEY") + + app = FirecrawlApp(api_key=api_key) + params = { + "limit": limit, + "excludePaths": exclude_paths or [], + "includePaths": include_paths or [], + "maxDepth": max_depth, + "ignoreSitemap": ignore_sitemap, + "allowBackwardLinks": allow_backward_links, + "allowExternalLinks": allow_external_links, + } + if webhook: + params["webhook"] = webhook + + if async_crawl: + response = app.async_crawl_url(url, params=params) + response.pop("url", None) # Remove 'url' as it's an API endpoint + + if response["success"]: + response["status"] = await get_crawl_status(context, response["id"]) + response["llm_instructions"] = ( + "You have the ability to get crawl status, cancel a crawl, " + "and get a crawl's data. Inform the user that you have these capabilities. " + "Inform the user that they should let you know if they want you to perform any " + "of these actions." + ) + + else: + response = app.crawl_url(url, params=params) + + return dict(response) + + +@tool(requires_secrets=["FIRECRAWL_API_KEY"]) +async def get_crawl_status( + context: ToolContext, + crawl_id: Annotated[str, "The ID of the crawl job"], +) -> Annotated[dict[str, Any], "Crawl status information"]: + """ + Get the status of a Firecrawl 'crawl' that is either in progress or recently completed. + """ + + api_key = context.get_secret("FIRECRAWL_API_KEY") + + app = FirecrawlApp(api_key=api_key) + crawl_status = app.check_crawl_status(crawl_id) + + crawl_status.pop("data", None) # Remove 'data' if it exists + crawl_status.pop("next", None) # Remove 'next' as it's an API endpoint + + return dict(crawl_status) + + +# TODO: Support responses greater than 10 MB. If the response is greater than 10 MB, +# then the Firecrawl API response will have a next_url field. +@tool(requires_secrets=["FIRECRAWL_API_KEY"]) +async def get_crawl_data( + context: ToolContext, + crawl_id: Annotated[str, "The ID of the crawl job"], +) -> Annotated[dict[str, Any], "Crawl data information"]: + """ + Get the data of a Firecrawl 'crawl' that is either in progress or recently completed. + """ + + api_key = context.get_secret("FIRECRAWL_API_KEY") + + app = FirecrawlApp(api_key=api_key) + crawl_data = app.check_crawl_status(crawl_id) + + return dict(crawl_data) + + +@tool(requires_secrets=["FIRECRAWL_API_KEY"]) +async def cancel_crawl( + context: ToolContext, + crawl_id: Annotated[str, "The ID of the asynchronous crawl job to cancel"], +) -> Annotated[dict[str, Any], "Cancellation status information"]: + """ + Cancel an asynchronous crawl job that is in progress using the Firecrawl API. + """ + + api_key = context.get_secret("FIRECRAWL_API_KEY") + + app = FirecrawlApp(api_key=api_key) + cancellation_status = app.cancel_crawl(crawl_id) + + return dict(cancellation_status) diff --git a/toolkits/firecrawl/arcade_firecrawl/tools/map.py b/toolkits/firecrawl/arcade_firecrawl/tools/map.py new file mode 100644 index 00000000..32d460af --- /dev/null +++ b/toolkits/firecrawl/arcade_firecrawl/tools/map.py @@ -0,0 +1,33 @@ +from typing import Annotated, Any + +from arcade_tdk import ToolContext, tool +from firecrawl import FirecrawlApp + + +@tool(requires_secrets=["FIRECRAWL_API_KEY"]) +async def map_website( + context: ToolContext, + url: Annotated[str, "The base URL to start crawling from"], + search: Annotated[str | None, "Search query to use for mapping"] = None, + ignore_sitemap: Annotated[bool, "Ignore the website sitemap when crawling"] = True, + include_subdomains: Annotated[bool, "Include subdomains of the website"] = False, + limit: Annotated[int, "Maximum number of links to return"] = 5000, +) -> Annotated[dict[str, Any], "Website map data"]: + """ + Map a website from a single URL to a map of the entire website. + """ + + api_key = context.get_secret("FIRECRAWL_API_KEY") + + app = FirecrawlApp(api_key=api_key) + params: dict[str, Any] = { + "ignoreSitemap": ignore_sitemap, + "includeSubdomains": include_subdomains, + "limit": limit, + } + if search: + params["search"] = search + + map_result = app.map_url(url, params=params) + + return dict(map_result) diff --git a/toolkits/firecrawl/arcade_firecrawl/tools/scrape.py b/toolkits/firecrawl/arcade_firecrawl/tools/scrape.py new file mode 100644 index 00000000..148b98d3 --- /dev/null +++ b/toolkits/firecrawl/arcade_firecrawl/tools/scrape.py @@ -0,0 +1,49 @@ +from typing import Annotated, Any + +from arcade_tdk import ToolContext, tool +from firecrawl import FirecrawlApp + +from arcade_firecrawl.enums import Formats + + +# TODO: Support actions. This would enable clicking, scrolling, screenshotting, etc. +# TODO: Support extract. +# TODO: Support headers param? +@tool(requires_secrets=["FIRECRAWL_API_KEY"]) +async def scrape_url( + context: ToolContext, + url: Annotated[str, "URL to scrape"], + formats: Annotated[ + list[Formats] | None, "Formats to retrieve. Defaults to ['markdown']." + ] = None, + only_main_content: Annotated[ + bool | None, + "Only return the main content of the page excluding headers, navs, footers, etc.", + ] = True, + include_tags: Annotated[list[str] | None, "List of tags to include in the output"] = None, + exclude_tags: Annotated[list[str] | None, "List of tags to exclude from the output"] = None, + wait_for: Annotated[ + int | None, + "Specify a delay in milliseconds before fetching the content, allowing the page " + "sufficient time to load.", + ] = 10, + timeout: Annotated[int | None, "Timeout in milliseconds for the request"] = 30000, +) -> Annotated[dict[str, Any], "Scraped data in specified formats"]: + """Scrape a URL using Firecrawl and return the data in specified formats.""" + + api_key = context.get_secret("FIRECRAWL_API_KEY") + + formats = formats or [Formats.MARKDOWN] + + app = FirecrawlApp(api_key=api_key) + params = { + "formats": formats, + "onlyMainContent": only_main_content, + "includeTags": include_tags or [], + "excludeTags": exclude_tags or [], + "waitFor": wait_for, + "timeout": timeout, + } + response = app.scrape_url(url, params=params) + + return dict(response) diff --git a/toolkits/firecrawl/evals/eval_firecrawl.py b/toolkits/firecrawl/evals/eval_firecrawl.py new file mode 100644 index 00000000..cbaa7fd6 --- /dev/null +++ b/toolkits/firecrawl/evals/eval_firecrawl.py @@ -0,0 +1,244 @@ +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + NumericCritic, + SimilarityCritic, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_firecrawl +from arcade_firecrawl.tools import ( + cancel_crawl, + crawl_website, + get_crawl_data, + get_crawl_status, + map_website, + scrape_url, +) + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.9, + warn_threshold=0.95, +) + +catalog = ToolCatalog() +# Register the Firecrawl tools +catalog.add_module(arcade_firecrawl) + + +@tool_eval() +def firecrawl_eval_suite() -> EvalSuite: + """Evaluation suite for Firecrawl tools.""" + suite = EvalSuite( + name="Firecrawl Tools Evaluation Suite", + system_message="You are an AI assistant that helps users interact with web scraping and crawling tools using the provided tools.", + catalog=catalog, + rubric=rubric, + ) + + # Scrape URL + suite.add_case( + name="Scrape a URL", + user_message="Scrape https://foobar.com/howto/tutorials/join-discord-server in markdown format please. Wait for 10 seconds before fetching the content.", + expected_tool_calls=[ + ExpectedToolCall( + func=scrape_url, + args={ + "url": "https://foobar.com/howto/tutorials/join-discord-server", + "formats": ["markdown"], + "wait_for": 10000, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="url", weight=0.4), + BinaryCritic(critic_field="formats", weight=0.4), + NumericCritic(critic_field="wait_for", weight=0.2, value_range=(9000, 11000)), + ], + ) + + # Crawl Website + suite.add_case( + name="Crawl a website", + user_message="Crawl the website at https://wikipedia.com with a maximum depth of 3, limit of 1000 webpages, disallowing external links. Updates should be sent to http://example.com/crawl-updates. Oh and do it in the background. THanks", + expected_tool_calls=[ + ExpectedToolCall( + func=crawl_website, + args={ + "url": "https://wikipedia.com", + "max_depth": 3, + "limit": 1000, + "allow_external_links": False, + "webhook": "http://example.com/crawl-updates", + "async_crawl": True, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="url", weight=0.2), + BinaryCritic(critic_field="max_depth", weight=0.1), + BinaryCritic(critic_field="limit", weight=0.1), + BinaryCritic(critic_field="allow_external_links", weight=0.1), + BinaryCritic(critic_field="webhook", weight=0.2), + BinaryCritic(critic_field="async_crawl", weight=0.2), + ], + ) + + # Get Crawl Status + suite.add_case( + name="Get crawl status", + user_message="Check the status of my crawl", + expected_tool_calls=[ + ExpectedToolCall( + func=get_crawl_status, + args={ + "crawl_id": "2ee7ba77-4ba0-4a45-9e2f-1c9e9a56f29b", + }, + ) + ], + critics=[ + BinaryCritic(critic_field="crawl_id", weight=1.0), + ], + additional_messages=[ + {"role": "user", "content": "crawl asynchronously https://www.google.com"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_QklpRSDmHdvM3ZZfzOqCKWRN", + "type": "function", + "function": { + "name": "Firecrawl_CrawlWebsite", + "arguments": '{"url":"https://www.google.com","async_crawl":true}', + }, + } + ], + }, + { + "role": "tool", + "content": '{"id":"2ee7ba77-4ba0-4a45-9e2f-1c9e9a56f29b","success":true,"url":"https://api.firecrawl.dev/v1/crawl/2ee7ba77-4ba0-4a45-9e2f-1c9e9a56f29b"}', + "tool_call_id": "call_QklpRSDmHdvM3ZZfzOqCKWRN", + "name": "Firecrawl_CrawlWebsite", + }, + { + "role": "assistant", + "content": "The asynchronous web crawl request for [Google](https://www.google.com) has been successfully initiated. You can track the status or fetch the results using the following [link](https://api.firecrawl.dev/v1/crawl/2ee7ba77-4ba0-4a45-9e2f-1c9e9a56f29b).", + }, + ], + ) + + # # Get Crawl Data + suite.add_case( + name="Get crawl status", + user_message="Ok looks like the crawl is done, can I get the result please?", + expected_tool_calls=[ + ExpectedToolCall( + func=get_crawl_data, + args={ + "crawl_id": "2ee7ba77-4ba0-4a45-9e2f-1c9e9a56f29b", + }, + ) + ], + critics=[ + BinaryCritic(critic_field="crawl_id", weight=1.0), + ], + additional_messages=[ + {"role": "user", "content": "crawl asynchronously https://www.google.com"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_QklpRSDmHdvM3ZZfzOqCKWRN", + "type": "function", + "function": { + "name": "Firecrawl_CrawlWebsite", + "arguments": '{"url":"https://www.google.com","async_crawl":true}', + }, + } + ], + }, + { + "role": "tool", + "content": '{"id":"2ee7ba77-4ba0-4a45-9e2f-1c9e9a56f29b","success":true,"url":"https://api.firecrawl.dev/v1/crawl/2ee7ba77-4ba0-4a45-9e2f-1c9e9a56f29b"}', + "tool_call_id": "call_QklpRSDmHdvM3ZZfzOqCKWRN", + "name": "Firecrawl_CrawlWebsite", + }, + { + "role": "assistant", + "content": "The asynchronous web crawl request for [Google](https://www.google.com) has been successfully initiated. You can track the status or fetch the results using the following [link](https://api.firecrawl.dev/v1/crawl/2ee7ba77-4ba0-4a45-9e2f-1c9e9a56f29b).", + }, + ], + ) + + # Cancel Crawl + suite.add_case( + name="Get crawl status", + user_message="Actually cancel it.", + expected_tool_calls=[ + ExpectedToolCall( + func=cancel_crawl, + args={ + "crawl_id": "2ee7ba77-4ba0-4a45-9e2f-1c9e9a56f29b", + }, + ) + ], + critics=[ + BinaryCritic(critic_field="crawl_id", weight=1.0), + ], + additional_messages=[ + {"role": "user", "content": "crawl asynchronously https://www.google.com"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_QklpRSDmHdvM3ZZfzOqCKWRN", + "type": "function", + "function": { + "name": "Firecrawl_CrawlWebsite", + "arguments": '{"url":"https://www.google.com","async_crawl":true}', + }, + } + ], + }, + { + "role": "tool", + "content": '{"id":"2ee7ba77-4ba0-4a45-9e2f-1c9e9a56f29b","success":true,"url":"https://api.firecrawl.dev/v1/crawl/2ee7ba77-4ba0-4a45-9e2f-1c9e9a56f29b"}', + "tool_call_id": "call_QklpRSDmHdvM3ZZfzOqCKWRN", + "name": "Firecrawl_CrawlWebsite", + }, + { + "role": "assistant", + "content": "The asynchronous web crawl request for [Google](https://www.google.com) has been successfully initiated. You can track the status or fetch the results using the following [link](https://api.firecrawl.dev/v1/crawl/2ee7ba77-4ba0-4a45-9e2f-1c9e9a56f29b).", + }, + ], + ) + + # Map Website + suite.add_case( + name="Map a website", + user_message="Map the website at https://wikipedia.com with a limit of 100000 links. Only the links that are about the topic of AI", + expected_tool_calls=[ + ExpectedToolCall( + func=map_website, + args={ + "url": "https://wikipedia.com", + "search": "AI", + "limit": 100000, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="url", weight=0.4), + SimilarityCritic(critic_field="search", weight=0.2), + NumericCritic(critic_field="limit", weight=0.4, value_range=(90000, 110000)), + ], + ) + + return suite diff --git a/toolkits/firecrawl/pyproject.toml b/toolkits/firecrawl/pyproject.toml new file mode 100644 index 00000000..79b3790d --- /dev/null +++ b/toolkits/firecrawl/pyproject.toml @@ -0,0 +1,54 @@ +[build-system] +requires = [ "hatchling",] +build-backend = "hatchling.build" + +[project] +name = "arcade_firecrawl" +version = "2.0.0" +description = "Arcade.dev LLM tools for web scraping related tasks via Firecrawl" +requires-python = ">=3.10" +dependencies = [ "arcade-tdk>=2.0.0,<3.0.0", "firecrawl-py>=1.3.1,<2.0.0",] +[[project.authors]] +name = "Arcade" +email = "dev@arcade.dev" + +[project.optional-dependencies] +dev = [ + "arcade-ai[evals]>=2.0.0,<3.0.0", + "arcade-serve>=2.0.0,<3.0.0", + "pytest>=8.3.0,<8.4.0", + "pytest-cov>=4.0.0,<4.1.0", + "pytest-asyncio>=0.24.0,<0.25.0", + "pytest-mock>=3.11.1,<3.12.0", + "mypy>=1.5.1,<1.6.0", + "pre-commit>=3.4.0,<3.5.0", + "tox>=4.11.1,<4.12.0", + "ruff>=0.7.4,<0.8.0", +] + +# Use local path sources for arcade libs when working locally +[tool.uv.sources] +arcade-ai = {path = "../../", editable = true} +arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } +arcade-serve = { path = "../../libs/arcade-serve/", editable = true } + +[tool.mypy] +files = [ "arcade_firecrawl/**/*.py",] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes = "True" +ignore_missing_imports = "True" + +[tool.pytest.ini_options] +testpaths = [ "tests",] + +[tool.coverage.report] +skip_empty = true + +[tool.hatch.build.targets.wheel] +packages = [ "arcade_firecrawl",] diff --git a/toolkits/firecrawl/tests/__init__.py b/toolkits/firecrawl/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/toolkits/firecrawl/tests/test_firecrawl.py b/toolkits/firecrawl/tests/test_firecrawl.py new file mode 100644 index 00000000..ca248855 --- /dev/null +++ b/toolkits/firecrawl/tests/test_firecrawl.py @@ -0,0 +1,129 @@ +from unittest.mock import patch + +import pytest +from arcade_tdk import ToolContext, ToolSecretItem +from arcade_tdk.errors import ToolExecutionError + +from arcade_firecrawl.tools import ( + cancel_crawl, + crawl_website, + get_crawl_data, + get_crawl_status, + map_website, + scrape_url, +) + + +@pytest.fixture +def mock_context(): + return ToolContext(secrets=[ToolSecretItem(key="firecrawl_api_key", value="fake_api_key")]) + + +@pytest.fixture +def mock_firecrawl_app_for_scrape(): + with patch("arcade_firecrawl.tools.scrape.FirecrawlApp") as app: + yield app.return_value + + +@pytest.fixture +def mock_firecrawl_app_for_crawl(): + with patch("arcade_firecrawl.tools.crawl.FirecrawlApp") as app: + yield app.return_value + + +@pytest.fixture +def mock_firecrawl_app_for_map(): + with patch("arcade_firecrawl.tools.map.FirecrawlApp") as app: + yield app.return_value + + +@pytest.mark.asyncio +async def test_scrape_url_success(mock_firecrawl_app_for_scrape, mock_context): + expected_response = { + "success": True, + "data": {"scraped_content": "scraped content"}, + } + mock_firecrawl_app_for_scrape.scrape_url.return_value = expected_response + + result = await scrape_url(mock_context, "http://example.com") + assert result == expected_response + + +@pytest.mark.asyncio +async def test_crawl_website_success(mock_firecrawl_app_for_crawl, mock_context): + expected_response = { + "id": "12345", + "success": True, + } + mock_firecrawl_app_for_crawl.async_crawl_url.return_value = expected_response + mock_firecrawl_app_for_crawl.check_crawl_status.return_value = expected_response + + result = await crawl_website(mock_context, "http://example.com") + assert result == expected_response + + +@pytest.mark.asyncio +async def test_get_crawl_status_success(mock_firecrawl_app_for_crawl, mock_context): + expected_response = {"status": "completed"} + mock_firecrawl_app_for_crawl.check_crawl_status.return_value = expected_response + + result = await get_crawl_status(mock_context, "12345") + assert result == expected_response + + +@pytest.mark.asyncio +async def test_get_crawl_data_success(mock_firecrawl_app_for_crawl, mock_context): + expected_response = {"data": "crawl data"} + mock_firecrawl_app_for_crawl.check_crawl_status.return_value = expected_response + + result = await get_crawl_data(mock_context, "12345") + assert result == expected_response + + +@pytest.mark.asyncio +async def test_cancel_crawl_success(mock_firecrawl_app_for_crawl, mock_context): + expected_response = {"status": "cancelled"} + mock_firecrawl_app_for_crawl.cancel_crawl.return_value = expected_response + + result = await cancel_crawl(mock_context, "12345") + assert result == expected_response + + +@pytest.mark.asyncio +async def test_map_website_success(mock_firecrawl_app_for_map, mock_context): + expected_response = {"map": "website map"} + mock_firecrawl_app_for_map.map_url.return_value = expected_response + + result = await map_website(mock_context, "http://example.com") + assert result == expected_response + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "method,params,error_message", + [ + (scrape_url, ("http://example.com",), "Error scraping URL"), + (crawl_website, ("http://example.com",), "Error crawling website"), + (get_crawl_status, ("12345",), "Error getting crawl status"), + (get_crawl_data, ("12345",), "Error getting crawl data"), + (cancel_crawl, ("12345",), "Error cancelling crawl"), + (map_website, ("http://example.com",), "Error mapping website"), + ], +) +async def test_firecrawl_error( + mock_firecrawl_app_for_scrape, + mock_firecrawl_app_for_crawl, + mock_firecrawl_app_for_map, + mock_context, + method, + params, + error_message, +): + mock_firecrawl_app_for_scrape.scrape_url.side_effect = Exception(error_message) + mock_firecrawl_app_for_crawl.async_crawl_url.side_effect = Exception(error_message) + mock_firecrawl_app_for_crawl.check_crawl_status.side_effect = Exception(error_message) + mock_firecrawl_app_for_crawl.cancel_crawl.side_effect = Exception(error_message) + mock_firecrawl_app_for_map.map_url.side_effect = Exception(error_message) + + with pytest.raises(ToolExecutionError): + await method(mock_context, *params) diff --git a/toolkits/gmail/.pre-commit-config.yaml b/toolkits/gmail/.pre-commit-config.yaml new file mode 100644 index 00000000..d46cb6e2 --- /dev/null +++ b/toolkits/gmail/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +files: ^.*/gmail/.* +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.4.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/toolkits/gmail/.ruff.toml b/toolkits/gmail/.ruff.toml new file mode 100644 index 00000000..f1aed90f --- /dev/null +++ b/toolkits/gmail/.ruff.toml @@ -0,0 +1,46 @@ +target-version = "py310" +line-length = 100 +fix = true + +[lint] +select = [ + # flake8-2020 + "YTT", + # flake8-bandit + "S", + # flake8-bugbear + "B", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-simplify + "SIM", + # isort + "I", + # mccabe + "C90", + # pycodestyle + "E", "W", + # pyflakes + "F", + # pygrep-hooks + "PGH", + # pyupgrade + "UP", + # ruff + "RUF", + # tryceratops + "TRY", +] + +[lint.per-file-ignores] +"*" = ["TRY003", "B904"] +"**/tests/*" = ["S101", "E501"] +"**/evals/*" = ["S101", "E501"] + +[format] +preview = true +skip-magic-trailing-comma = false diff --git a/toolkits/gmail/Makefile b/toolkits/gmail/Makefile new file mode 100644 index 00000000..0a8969be --- /dev/null +++ b/toolkits/gmail/Makefile @@ -0,0 +1,55 @@ +.PHONY: help + +help: + @echo "🛠️ github Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the uv environment and install all packages with dependencies + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras --no-sources + @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: install-local +install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras + @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + uv build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + @uv run --no-sources coverage report + @echo "Generating coverage report" + @uv run --no-sources coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file by a patch version + @echo "🚀 Bumping version in pyproject.toml" + uv version --no-sources --bump patch + +.PHONY: check +check: ## Run code quality tools. + @if [ -f .pre-commit-config.yaml ]; then\ + echo "🚀 Linting code: Running pre-commit";\ + uv run --no-sources pre-commit run -a;\ + fi + @echo "🚀 Static type checking: Running mypy" + @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/gmail/arcade_gmail/__init__.py b/toolkits/gmail/arcade_gmail/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/toolkits/gmail/arcade_gmail/constants.py b/toolkits/gmail/arcade_gmail/constants.py new file mode 100644 index 00000000..4755797e --- /dev/null +++ b/toolkits/gmail/arcade_gmail/constants.py @@ -0,0 +1,18 @@ +import os + +from arcade_gmail.enums import GmailReplyToWhom + +# The default reply in Gmail is to only the sender. Since Gmail also offers the possibility of +# changing the default to 'reply to all', we support both options through an env variable. +# https://support.google.com/mail/answer/6585?hl=en&sjid=15399867888091633568-SA#null +try: + GMAIL_DEFAULT_REPLY_TO = GmailReplyToWhom( + # Values accepted are defined in the arcade_google.tools.models.GmailReplyToWhom Enum + os.getenv("ARCADE_GMAIL_DEFAULT_REPLY_TO", GmailReplyToWhom.ONLY_THE_SENDER.value).lower() + ) +except ValueError as e: + raise ValueError( + "Invalid value for ARCADE_GMAIL_DEFAULT_REPLY_TO: " + f"'{os.getenv('ARCADE_GMAIL_DEFAULT_REPLY_TO')}'. Expected one of " + f"{list(GmailReplyToWhom.__members__.keys())}" + ) from e diff --git a/toolkits/gmail/arcade_gmail/enums.py b/toolkits/gmail/arcade_gmail/enums.py new file mode 100644 index 00000000..16b186d7 --- /dev/null +++ b/toolkits/gmail/arcade_gmail/enums.py @@ -0,0 +1,11 @@ +from enum import Enum + + +class GmailReplyToWhom(str, Enum): + EVERY_RECIPIENT = "every_recipient" + ONLY_THE_SENDER = "only_the_sender" + + +class GmailAction(str, Enum): + SEND = "send" + DRAFT = "draft" diff --git a/toolkits/gmail/arcade_gmail/exceptions.py b/toolkits/gmail/arcade_gmail/exceptions.py new file mode 100644 index 00000000..12c9aa65 --- /dev/null +++ b/toolkits/gmail/arcade_gmail/exceptions.py @@ -0,0 +1,19 @@ +class GmailToolError(Exception): + """Base exception for Google tool errors.""" + + def __init__(self, message: str, developer_message: str | None = None): + self.message = message + self.developer_message = developer_message + super().__init__(self.message) + + def __str__(self) -> str: + base_message = self.message + if self.developer_message: + return f"{base_message} (Developer: {self.developer_message})" + return base_message + + +class GmailServiceError(GmailToolError): + """Raised when there's an error building or using the Google service.""" + + pass diff --git a/toolkits/gmail/arcade_gmail/tools/__init__.py b/toolkits/gmail/arcade_gmail/tools/__init__.py new file mode 100644 index 00000000..6ed7ba8e --- /dev/null +++ b/toolkits/gmail/arcade_gmail/tools/__init__.py @@ -0,0 +1,39 @@ +from arcade_gmail.tools.gmail import ( + change_email_labels, + create_label, + delete_draft_email, + get_thread, + list_draft_emails, + list_emails, + list_emails_by_header, + list_labels, + list_threads, + reply_to_email, + search_threads, + send_draft_email, + send_email, + trash_email, + update_draft_email, + write_draft_email, + write_draft_reply_email, +) + +__all__ = [ + "change_email_labels", + "create_label", + "delete_draft_email", + "get_thread", + "list_draft_emails", + "list_emails", + "list_emails_by_header", + "list_labels", + "list_threads", + "reply_to_email", + "search_threads", + "send_draft_email", + "send_email", + "trash_email", + "update_draft_email", + "write_draft_email", + "write_draft_reply_email", +] diff --git a/toolkits/gmail/arcade_gmail/tools/gmail.py b/toolkits/gmail/arcade_gmail/tools/gmail.py new file mode 100644 index 00000000..23a87572 --- /dev/null +++ b/toolkits/gmail/arcade_gmail/tools/gmail.py @@ -0,0 +1,664 @@ +import base64 +from email.mime.text import MIMEText +from typing import Annotated, Any + +from arcade_tdk import ToolContext, tool +from arcade_tdk.auth import Google +from arcade_tdk.errors import RetryableToolError +from googleapiclient.errors import HttpError + +from arcade_gmail.constants import GMAIL_DEFAULT_REPLY_TO +from arcade_gmail.enums import GmailAction, GmailReplyToWhom +from arcade_gmail.exceptions import GmailToolError +from arcade_gmail.utils import ( + DateRange, + _build_gmail_service, + build_email_message, + build_gmail_query_string, + build_reply_recipients, + fetch_messages, + get_draft_url, + get_email_details, + get_email_in_trash_url, + get_label_ids, + get_sent_email_url, + parse_draft_email, + parse_multipart_email, + parse_plain_text_email, + remove_none_values, +) + + +# Email sending tools +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/gmail.send"], + ) +) +async def send_email( + context: ToolContext, + subject: Annotated[str, "The subject of the email"], + body: Annotated[str, "The body of the email"], + recipient: Annotated[str, "The recipient of the email"], + cc: Annotated[list[str] | None, "CC recipients of the email"] = None, + bcc: Annotated[list[str] | None, "BCC recipients of the email"] = None, +) -> Annotated[dict, "A dictionary containing the sent email details"]: + """ + Send an email using the Gmail API. + """ + service = _build_gmail_service(context) + email = build_email_message(recipient, subject, body, cc, bcc) + + sent_message = service.users().messages().send(userId="me", body=email).execute() + + email = parse_plain_text_email(sent_message) + email["url"] = get_sent_email_url(sent_message["id"]) + return email + + +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/gmail.send"], + ) +) +async def send_draft_email( + context: ToolContext, email_id: Annotated[str, "The ID of the draft to send"] +) -> Annotated[dict, "A dictionary containing the sent email details"]: + """ + Send a draft email using the Gmail API. + """ + + service = _build_gmail_service(context) + + # Send the draft email + sent_message = service.users().drafts().send(userId="me", body={"id": email_id}).execute() + + email = parse_plain_text_email(sent_message) + email["url"] = get_sent_email_url(sent_message["id"]) + return email + + +# Note: in the Gmail UI, a user can customize the recipient and cc fields before replying. +# We decided not to support this feature, since we'd need a way for LLMs to tell apart between +# adding or removing recipients/cc, or replacing with an entirely new list of addresses, +# which would make the tool more complex to call. +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/gmail.send"], + ) +) +async def reply_to_email( + context: ToolContext, + body: Annotated[str, "The body of the email"], + reply_to_message_id: Annotated[str, "The ID of the message to reply to"], + reply_to_whom: Annotated[ + GmailReplyToWhom, + "Whether to reply to every recipient (including cc) or only to the original sender. " + f"Defaults to '{GMAIL_DEFAULT_REPLY_TO}'.", + ] = GMAIL_DEFAULT_REPLY_TO, + bcc: Annotated[list[str] | None, "BCC recipients of the email"] = None, +) -> Annotated[dict, "A dictionary containing the sent email details"]: + """ + Send a reply to an email message. + """ + if isinstance(reply_to_whom, str): + reply_to_whom = GmailReplyToWhom(reply_to_whom) + + service = _build_gmail_service(context) + + current_user = service.users().getProfile(userId="me").execute() + + try: + replying_to_email = ( + service.users().messages().get(userId="me", id=reply_to_message_id).execute() + ) + except HttpError as e: + raise RetryableToolError( + message=f"Could not retrieve the message with id {reply_to_message_id}.", + developer_message=( + f"Could not retrieve the message with id {reply_to_message_id}. " + f"Reason: '{e.reason}'. Error details: '{e.error_details}'" + ), + ) from e + + replying_to_email = parse_multipart_email(replying_to_email) + + recipients = build_reply_recipients( + replying_to_email, current_user["emailAddress"], reply_to_whom + ) + + email = build_email_message( + recipient=recipients, + subject=f"Re: {replying_to_email['subject']}", + body=body, + cc=None + if reply_to_whom == GmailReplyToWhom.ONLY_THE_SENDER + else replying_to_email["cc"].split(","), + bcc=bcc, + replying_to=replying_to_email, + ) + + sent_message = service.users().messages().send(userId="me", body=email).execute() + + email = parse_plain_text_email(sent_message) + email["url"] = get_sent_email_url(sent_message["id"]) + return email + + +# Draft Management Tools +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/gmail.compose"], + ) +) +async def write_draft_email( + context: ToolContext, + subject: Annotated[str, "The subject of the draft email"], + body: Annotated[str, "The body of the draft email"], + recipient: Annotated[str, "The recipient of the draft email"], + cc: Annotated[list[str] | None, "CC recipients of the draft email"] = None, + bcc: Annotated[list[str] | None, "BCC recipients of the draft email"] = None, +) -> Annotated[dict, "A dictionary containing the created draft email details"]: + """ + Compose a new email draft using the Gmail API. + """ + # Set up the Gmail API client + service = _build_gmail_service(context) + + draft = { + "message": build_email_message(recipient, subject, body, cc, bcc, action=GmailAction.DRAFT) + } + + draft_message = service.users().drafts().create(userId="me", body=draft).execute() + email = parse_draft_email(draft_message) + email["url"] = get_draft_url(draft_message["id"]) + return email + + +# Note: in the Gmail UI, a user can customize the recipient and cc fields before replying. +# We decided not to support this feature, since we'd need a way for LLMs to tell apart between +# adding or removing recipients/cc, or replacing with an entirely new list of addresses, +# which would make the tool more complex to call. +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/gmail.compose"], + ) +) +async def write_draft_reply_email( + context: ToolContext, + body: Annotated[str, "The body of the draft reply email"], + reply_to_message_id: Annotated[str, "The Gmail message ID of the message to draft a reply to"], + reply_to_whom: Annotated[ + GmailReplyToWhom, + "Whether to reply to every recipient (including cc) or only to the original sender. " + f"Defaults to '{GMAIL_DEFAULT_REPLY_TO}'.", + ] = GMAIL_DEFAULT_REPLY_TO, + bcc: Annotated[list[str] | None, "BCC recipients of the draft reply email"] = None, +) -> Annotated[dict, "A dictionary containing the created draft reply email details"]: + """ + Compose a draft reply to an email message. + """ + if isinstance(reply_to_whom, str): + reply_to_whom = GmailReplyToWhom(reply_to_whom) + + service = _build_gmail_service(context) + + current_user = service.users().getProfile(userId="me").execute() + + try: + replying_to_email = ( + service.users().messages().get(userId="me", id=reply_to_message_id).execute() + ) + except HttpError as e: + raise RetryableToolError( + message="Could not retrieve the message to respond to.", + developer_message=( + "Could not retrieve the message to respond to. " + f"Reason: '{e.reason}'. Error details: '{e.error_details}'" + ), + ) + + replying_to_email = parse_multipart_email(replying_to_email) + + recipients = build_reply_recipients( + replying_to_email, current_user["emailAddress"], reply_to_whom + ) + + draft_message = { + "message": build_email_message( + recipient=recipients, + subject=f"Re: {replying_to_email['subject']}", + body=body, + cc=None + if reply_to_whom == GmailReplyToWhom.ONLY_THE_SENDER + else replying_to_email["cc"].split(","), + bcc=bcc, + replying_to=replying_to_email, + action=GmailAction.DRAFT, + ), + } + + draft = service.users().drafts().create(userId="me", body=draft_message).execute() + + email = parse_draft_email(draft) + email["url"] = get_draft_url(draft["id"]) + return email + + +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/gmail.compose"], + ) +) +async def update_draft_email( + context: ToolContext, + draft_email_id: Annotated[str, "The ID of the draft email to update."], + subject: Annotated[str, "The subject of the draft email"], + body: Annotated[str, "The body of the draft email"], + recipient: Annotated[str, "The recipient of the draft email"], + cc: Annotated[list[str] | None, "CC recipients of the draft email"] = None, + bcc: Annotated[list[str] | None, "BCC recipients of the draft email"] = None, +) -> Annotated[dict, "A dictionary containing the updated draft email details"]: + """ + Update an existing email draft using the Gmail API. + """ + service = _build_gmail_service(context) + + message = MIMEText(body) + message["to"] = recipient + message["subject"] = subject + if cc: + message["Cc"] = ", ".join(cc) + if bcc: + message["Bcc"] = ", ".join(bcc) + + # Encode the message in base64 + raw_message = base64.urlsafe_b64encode(message.as_bytes()).decode() + + # Update the draft + draft = {"id": draft_email_id, "message": {"raw": raw_message}} + + updated_draft_message = ( + service.users().drafts().update(userId="me", id=draft_email_id, body=draft).execute() + ) + + email = parse_draft_email(updated_draft_message) + email["url"] = get_draft_url(updated_draft_message["id"]) + + return email + + +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/gmail.compose"], + ) +) +async def delete_draft_email( + context: ToolContext, + draft_email_id: Annotated[str, "The ID of the draft email to delete"], +) -> Annotated[str, "A confirmation message indicating successful deletion"]: + """ + Delete a draft email using the Gmail API. + """ + service = _build_gmail_service(context) + + # Delete the draft + service.users().drafts().delete(userId="me", id=draft_email_id).execute() + return f"Draft email with ID {draft_email_id} deleted successfully." + + +# Email Management Tools +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/gmail.modify"], + ) +) +async def trash_email( + context: ToolContext, email_id: Annotated[str, "The ID of the email to trash"] +) -> Annotated[dict, "A dictionary containing the trashed email details"]: + """ + Move an email to the trash folder using the Gmail API. + """ + + service = _build_gmail_service(context) + + # Trash the email + trashed_email = service.users().messages().trash(userId="me", id=email_id).execute() + + email = parse_plain_text_email(trashed_email) + email["url"] = get_email_in_trash_url(trashed_email["id"]) + return email + + +# Draft Search Tools +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/gmail.readonly"], + ) +) +async def list_draft_emails( + context: ToolContext, + n_drafts: Annotated[int, "Number of draft emails to read"] = 5, +) -> Annotated[dict, "A dictionary containing a list of draft email details"]: + """ + Lists draft emails in the user's draft mailbox using the Gmail API. + """ + service = _build_gmail_service(context) + + listed_drafts = service.users().drafts().list(userId="me").execute() + + if not listed_drafts: + return {"emails": []} + + draft_ids = [draft["id"] for draft in listed_drafts.get("drafts", [])][:n_drafts] + + emails = [] + for draft_id in draft_ids: + try: + draft_data = service.users().drafts().get(userId="me", id=draft_id).execute() + draft_details = parse_draft_email(draft_data) + if draft_details: + emails.append(draft_details) + except Exception as e: + raise GmailToolError( + message=f"Error reading draft email {draft_id}.", developer_message=str(e) + ) + + return {"emails": emails} + + +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/gmail.readonly"], + ) +) +async def list_emails_by_header( + context: ToolContext, + sender: Annotated[str | None, "The name or email address of the sender of the email"] = None, + recipient: Annotated[str | None, "The name or email address of the recipient"] = None, + subject: Annotated[str | None, "Words to find in the subject of the email"] = None, + body: Annotated[str | None, "Words to find in the body of the email"] = None, + date_range: Annotated[DateRange | None, "The date range of the email"] = None, + label: Annotated[str | None, "The label name to filter by"] = None, + max_results: Annotated[int, "The maximum number of emails to return"] = 25, +) -> Annotated[ + dict, "A dictionary containing a list of email details matching the search criteria" +]: + """ + Search for emails by header using the Gmail API. + + At least one of the following parameters MUST be provided: sender, recipient, + subject, date_range, label, or body. + """ + service = _build_gmail_service(context) + # Ensure at least one search parameter is provided + if not any([sender, recipient, subject, body, label, date_range]): + raise RetryableToolError( + message=( + "At least one of sender, recipient, subject, body, label, query, " + "or date_range must be provided." + ), + developer_message=( + "At least one of sender, recipient, subject, body, label, query, " + "or date_range must be provided." + ), + ) + + # Check if label is valid + if label: + label_ids = get_label_ids(service, [label]) + + if not label_ids: + labels = service.users().labels().list(userId="me").execute().get("labels", []) + label_names = [label["name"] for label in labels] + raise RetryableToolError( + message=f"Invalid label: {label}", + developer_message=f"Invalid label: {label}", + additional_prompt_content=f"List of valid labels: {label_names}", + ) + + # Build a Gmail-style query string based on the filters + query = build_gmail_query_string(sender, recipient, subject, body, date_range, label) + + # Fetch matching messages. This fetches message metadata from Gmail + messages = fetch_messages(service, query, max_results) + + # If no messages found, return an empty list + if not messages: + return {"emails": []} + + # Process each message into a structured email object + emails = get_email_details(service, messages) + + # Return the list of emails in a dictionary with key "emails" + return {"emails": emails} + + +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/gmail.readonly"], + ) +) +async def list_emails( + context: ToolContext, + n_emails: Annotated[int, "Number of emails to read"] = 5, +) -> Annotated[dict, "A dictionary containing a list of email details"]: + """ + Read emails from a Gmail account and extract plain text content. + """ + service = _build_gmail_service(context) + + messages = service.users().messages().list(userId="me").execute().get("messages", []) + + if not messages: + return {"emails": []} + + emails = [] + for msg in messages[:n_emails]: + try: + email_data = service.users().messages().get(userId="me", id=msg["id"]).execute() + email_details = parse_plain_text_email(email_data) + if email_details: + emails.append(email_details) + except Exception as e: + raise GmailToolError( + message=f"Error reading email {msg['id']}.", developer_message=str(e) + ) + return {"emails": emails} + + +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/gmail.readonly"], + ) +) +async def search_threads( + context: ToolContext, + page_token: Annotated[ + str | None, "Page token to retrieve a specific page of results in the list" + ] = None, + max_results: Annotated[int, "The maximum number of threads to return"] = 10, + include_spam_trash: Annotated[bool, "Whether to include spam and trash in the results"] = False, + label_ids: Annotated[list[str] | None, "The IDs of labels to filter by"] = None, + sender: Annotated[str | None, "The name or email address of the sender of the email"] = None, + recipient: Annotated[str | None, "The name or email address of the recipient"] = None, + subject: Annotated[str | None, "Words to find in the subject of the email"] = None, + body: Annotated[str | None, "Words to find in the body of the email"] = None, + date_range: Annotated[DateRange | None, "The date range of the email"] = None, +) -> Annotated[dict, "A dictionary containing a list of thread details"]: + """Search for threads in the user's mailbox""" + service = _build_gmail_service(context) + + query = ( + build_gmail_query_string(sender, recipient, subject, body, date_range) + if any([sender, recipient, subject, body, date_range]) + else None + ) + + params = { + "userId": "me", + "maxResults": min(max_results, 500), + "pageToken": page_token, + "includeSpamTrash": include_spam_trash, + "labelIds": label_ids, + "q": query, + } + params = remove_none_values(params) + + threads: list[dict[str, Any]] = [] + next_page_token = None + # Paginate through thread pages until we have the desired number of threads + while len(threads) < max_results: + response = service.users().threads().list(**params).execute() + + threads.extend(response.get("threads", [])) + next_page_token = response.get("nextPageToken") + + if not next_page_token: + break + + params["pageToken"] = next_page_token + params["maxResults"] = min(max_results - len(threads), 500) + + return { + "threads": threads, + "num_threads": len(threads), + "next_page_token": next_page_token, + } + + +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/gmail.readonly"], + ) +) +async def list_threads( + context: ToolContext, + page_token: Annotated[ + str | None, "Page token to retrieve a specific page of results in the list" + ] = None, + max_results: Annotated[int, "The maximum number of threads to return"] = 10, + include_spam_trash: Annotated[bool, "Whether to include spam and trash in the results"] = False, +) -> Annotated[dict, "A dictionary containing a list of thread details"]: + """List threads in the user's mailbox.""" + threads: dict[str, Any] = await search_threads( + context, page_token, max_results, include_spam_trash + ) + return threads + + +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/gmail.readonly"], + ) +) +async def get_thread( + context: ToolContext, + thread_id: Annotated[str, "The ID of the thread to retrieve"], +) -> Annotated[dict, "A dictionary containing the thread details"]: + """Get the specified thread by ID.""" + params = { + "userId": "me", + "id": thread_id, + "format": "full", + } + params = remove_none_values(params) + + service = _build_gmail_service(context) + + thread = service.users().threads().get(**params).execute() + thread["messages"] = [parse_plain_text_email(message) for message in thread.get("messages", [])] + + return dict(thread) + + +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/gmail.modify"], + ) +) +async def change_email_labels( + context: ToolContext, + email_id: Annotated[str, "The ID of the email to modify labels for"], + labels_to_add: Annotated[list[str], "List of label names to add"], + labels_to_remove: Annotated[list[str], "List of label names to remove"], +) -> Annotated[dict, "List of labels that were added, removed, and not found"]: + """ + Add and remove labels from an email using the Gmail API. + """ + service = _build_gmail_service(context) + + add_labels = get_label_ids(service, labels_to_add) + remove_labels = get_label_ids(service, labels_to_remove) + + invalid_labels = ( + set(labels_to_add + labels_to_remove) - set(add_labels.keys()) - set(remove_labels.keys()) + ) + + if invalid_labels: + # prepare the list of valid labels + labels = service.users().labels().list(userId="me").execute().get("labels", []) + label_names = [label["name"] for label in labels] + + # raise a retryable error with the list of valid labels + raise RetryableToolError( + message=f"Invalid labels: {invalid_labels}", + developer_message=f"Invalid labels: {invalid_labels}", + additional_prompt_content=f"List of valid labels: {label_names}", + ) + + # Prepare the modification body with label IDs. + body = { + "addLabelIds": list(add_labels.values()), + "removeLabelIds": list(remove_labels.values()), + } + + try: # Modify the email labels. + service.users().messages().modify(userId="me", id=email_id, body=body).execute() + + except Exception as e: + raise GmailToolError( + message=f"Error modifying labels for email {email_id}", developer_message=str(e) + ) + + # Confirmation JSON with lists for added and removed labels. + confirmation = { + "addedLabels": list(add_labels.keys()), + "removedLabels": list(remove_labels.keys()), + } + + return {"confirmation": dict(confirmation)} + + +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/gmail.readonly"], + ) +) +async def list_labels( + context: ToolContext, +) -> Annotated[dict, "A dictionary containing a list of label details"]: + """List all the labels in the user's mailbox.""" + + service = _build_gmail_service(context) + + labels = service.users().labels().list(userId="me").execute().get("labels", []) + + return {"labels": labels} + + +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/gmail.labels"], + ) +) +async def create_label( + context: ToolContext, + label_name: Annotated[str, "The name of the label to create"], +) -> Annotated[dict, "The details of the created label"]: + """Create a new label in the user's mailbox.""" + + service = _build_gmail_service(context) + label = service.users().labels().create(userId="me", body={"name": label_name}).execute() + + return {"label": label} diff --git a/toolkits/gmail/arcade_gmail/utils.py b/toolkits/gmail/arcade_gmail/utils.py new file mode 100644 index 00000000..48aa5ca2 --- /dev/null +++ b/toolkits/gmail/arcade_gmail/utils.py @@ -0,0 +1,509 @@ +import logging +import re +from base64 import urlsafe_b64decode, urlsafe_b64encode +from datetime import datetime, timedelta +from email.message import EmailMessage +from email.mime.text import MIMEText +from enum import Enum +from typing import Any + +from arcade_tdk import ToolContext +from bs4 import BeautifulSoup +from google.oauth2.credentials import Credentials +from googleapiclient.discovery import build + +from arcade_gmail.enums import ( + GmailAction, + GmailReplyToWhom, +) +from arcade_gmail.exceptions import GmailServiceError, GmailToolError + +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) + +logger = logging.getLogger(__name__) + + +class DateRange(Enum): + TODAY = "today" + YESTERDAY = "yesterday" + LAST_7_DAYS = "last_7_days" + LAST_30_DAYS = "last_30_days" + THIS_MONTH = "this_month" + LAST_MONTH = "last_month" + THIS_YEAR = "this_year" + + def to_date_query(self) -> str: + today = datetime.now() + result = "after:" + comparison_date = today + + if self == DateRange.YESTERDAY: + comparison_date = today - timedelta(days=1) + elif self == DateRange.LAST_7_DAYS: + comparison_date = today - timedelta(days=7) + elif self == DateRange.LAST_30_DAYS: + comparison_date = today - timedelta(days=30) + elif self == DateRange.THIS_MONTH: + comparison_date = today.replace(day=1) + elif self == DateRange.LAST_MONTH: + comparison_date = (today.replace(day=1) - timedelta(days=1)).replace(day=1) + elif self == DateRange.THIS_YEAR: + comparison_date = today.replace(month=1, day=1) + elif self == DateRange.LAST_MONTH: + comparison_date = (today.replace(month=1, day=1) - timedelta(days=1)).replace( + month=1, day=1 + ) + + return result + comparison_date.strftime("%Y/%m/%d") + + +def build_email_message( + recipient: str, + subject: str, + body: str, + cc: list[str] | None = None, + bcc: list[str] | None = None, + replying_to: dict[str, Any] | None = None, + action: GmailAction = GmailAction.SEND, +) -> dict[str, Any]: + if replying_to: + body = build_reply_body(body, replying_to) + + message: EmailMessage | MIMEText + + if action == GmailAction.SEND: + message = EmailMessage() + message.set_content(body) + elif action == GmailAction.DRAFT: + message = MIMEText(body) + + message["To"] = recipient + message["Subject"] = subject + + if cc: + message["Cc"] = ",".join(cc) + if bcc: + message["Bcc"] = ",".join(bcc) + if replying_to: + message["In-Reply-To"] = replying_to["header_message_id"] + message["References"] = f"{replying_to['header_message_id']}, {replying_to['references']}" + + encoded_message = urlsafe_b64encode(message.as_bytes()).decode() + + data = {"raw": encoded_message} + + if replying_to: + data["threadId"] = replying_to["thread_id"] + + return data + + +def _build_gmail_service(context: ToolContext) -> Any: + """ + Private helper function to build and return the Gmail service client. + + Args: + context (ToolContext): The context containing authorization details. + + Returns: + googleapiclient.discovery.Resource: An authorized Gmail API service instance. + """ + try: + credentials = Credentials( + context.authorization.token + if context.authorization and context.authorization.token + else "" + ) + except Exception as e: + raise GmailServiceError(message="Failed to build Gmail service.", developer_message=str(e)) + + return build("gmail", "v1", credentials=credentials) + + +def build_gmail_query_string( + sender: str | None = None, + recipient: str | None = None, + subject: str | None = None, + body: str | None = None, + date_range: DateRange | None = None, + label: str | None = None, +) -> str: + """Helper function to build a query string + for Gmail list_emails_by_header and search_threads tools. + """ + query = [] + if sender: + query.append(f"from:{sender}") + if recipient: + query.append(f"to:{recipient}") + if subject: + query.append(f"subject:{subject}") + if body: + query.append(body) + if date_range: + query.append(date_range.to_date_query()) + if label: + query.append(f"label:{label}") + return " ".join(query) + + +def get_label_ids(service: Any, label_names: list[str]) -> dict[str, str]: + """ + Retrieve label IDs for given label names. + Returns a dictionary mapping label names to their IDs. + + Args: + service: Authenticated Gmail API service instance. + label_names: List of label names to retrieve IDs for. + + Returns: + A dictionary mapping found label names to their corresponding IDs. + """ + try: + # Fetch all existing labels from Gmail + labels = service.users().labels().list(userId="me").execute().get("labels", []) + except Exception as e: + raise GmailToolError(message="Failed to list labels.", developer_message=str(e)) from e + + # Create a mapping from label names to their IDs + label_id_map = {label["name"]: label["id"] for label in labels} + + found_labels = {} + for name in label_names: + label_id = label_id_map.get(name) + if label_id: + found_labels[name] = label_id + else: + logger.warning(f"Label '{name}' does not exist") + + return found_labels + + +def fetch_messages(service: Any, query_string: str, limit: int) -> list[dict[str, Any]]: + """ + Helper function to fetch messages from Gmail API for the list_emails_by_header tool. + """ + response = ( + service.users() + .messages() + .list(userId="me", q=query_string, maxResults=limit or 100) + .execute() + ) + return response.get("messages", []) # type: ignore[no-any-return] + + +def remove_none_values(params: dict) -> dict: + """ + Remove None values from a dictionary. + :param params: The dictionary to clean + :return: A new dictionary with None values removed + """ + return {k: v for k, v in params.items() if v is not None} + + +def build_reply_recipients( + replying_to: dict[str, Any], current_user_email_address: str, reply_to_whom: GmailReplyToWhom +) -> str: + if reply_to_whom == GmailReplyToWhom.ONLY_THE_SENDER: + recipients = [replying_to["from"]] + elif reply_to_whom == GmailReplyToWhom.EVERY_RECIPIENT: + recipients = [replying_to["from"], *replying_to["to"].split(",")] + else: + raise ValueError(f"Unsupported reply_to_whom value: {reply_to_whom}") + + recipients = [ + email_address.strip() + for email_address in recipients + if email_address.strip().lower() != current_user_email_address.lower().strip() + ] + + return ", ".join(recipients) + + +def get_draft_url(draft_id: str) -> str: + return f"https://mail.google.com/mail/u/0/#drafts/{draft_id}" + + +def get_sent_email_url(sent_email_id: str) -> str: + return f"https://mail.google.com/mail/u/0/#sent/{sent_email_id}" + + +def get_email_details(service: Any, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """ + Retrieves full message data for each message ID in the given list and extracts email details. + + :param service: Authenticated Gmail API service instance. + :param messages: A list of dictionaries, each representing a message with an 'id' key. + :return: A list of dictionaries, each containing parsed email details. + """ + + emails = [] + for msg in messages: + try: + # Fetch the full message data from Gmail using the message ID + email_data = service.users().messages().get(userId="me", id=msg["id"]).execute() + # Parse the raw email data into a structured form + email_details = parse_plain_text_email(email_data) + # Only add the details if parsing was successful + if email_details: + emails.append(email_details) + except Exception as e: + # Log any errors encountered while trying to fetch or parse a message + raise GmailToolError( + message=f"Error reading email {msg['id']}.", developer_message=str(e) + ) + return emails + + +def get_email_in_trash_url(email_id: str) -> str: + return f"https://mail.google.com/mail/u/0/#trash/{email_id}" + + +def parse_draft_email(draft_email_data: dict[str, Any]) -> dict[str, str]: + """ + Parse draft email data and extract relevant information. + + Args: + draft_email_data (Dict[str, Any]): Raw draft email data from Gmail API. + + Returns: + dict[str, str]: Parsed draft email details + """ + message = draft_email_data.get("message", {}) + payload = message.get("payload", {}) + headers = {d["name"].lower(): d["value"] for d in payload.get("headers", [])} + + body_data = _get_email_plain_text_body(payload) + + return { + "id": draft_email_data.get("id", ""), + "thread_id": draft_email_data.get("threadId", ""), + "from": headers.get("from", ""), + "date": headers.get("internaldate", ""), + "subject": headers.get("subject", ""), + "body": _clean_email_body(body_data) if body_data else "", + } + + +def _clean_email_body(body: str | None) -> str: + """ + Remove HTML tags and clean up email body text while preserving most content. + + Args: + body (str): The raw email body text. + + Returns: + str: Cleaned email body text. + """ + if not body: + return "" + + try: + # Remove HTML tags using BeautifulSoup + soup = BeautifulSoup(body, "html.parser") + text = soup.get_text(separator=" ") + + # Clean up the text + cleaned_text = _clean_text(text) + + return cleaned_text.strip() + except Exception: + logger.exception("Error cleaning email body") + return body + + +def _get_email_plain_text_body(payload: dict[str, Any]) -> str | None: + """ + Extract email body from payload, handling 'multipart/alternative' parts. + + Args: + payload (Dict[str, Any]): Email payload data. + + Returns: + str | None: Decoded email body or None if not found. + """ + # Direct body extraction + if "body" in payload and payload["body"].get("data"): + return _clean_email_body(urlsafe_b64decode(payload["body"]["data"]).decode()) + + # Handle multipart and alternative parts + return _clean_email_body(_extract_plain_body(payload.get("parts", []))) + + +def _extract_plain_body(parts: list) -> str | None: + """ + Recursively extract the email body from parts, handling both plain text and HTML. + + Args: + parts (List[Dict[str, Any]]): List of email parts. + + Returns: + str | None: Decoded and cleaned email body or None if not found. + """ + for part in parts: + mime_type = part.get("mimeType") + + if mime_type == "text/plain" and "data" in part.get("body", {}): + return urlsafe_b64decode(part["body"]["data"]).decode() + + elif mime_type.startswith("multipart/"): + subparts = part.get("parts", []) + body = _extract_plain_body(subparts) + if body: + return body + + return _extract_html_body(parts) + + +def _extract_html_body(parts: list) -> str | None: + """ + Recursively extract the email body from parts, handling only HTML. + + Args: + parts (List[Dict[str, Any]]): List of email parts. + + Returns: + str | None: Decoded and cleaned email body or None if not found. + """ + for part in parts: + mime_type = part.get("mimeType") + + if mime_type == "text/html" and "data" in part.get("body", {}): + html_content = urlsafe_b64decode(part["body"]["data"]).decode() + return html_content + + elif mime_type.startswith("multipart/"): + subparts = part.get("parts", []) + body = _extract_html_body(subparts) + if body: + return body + + return None + + +def _clean_text(text: str) -> str: + """ + Clean up the text while preserving most content. + + Args: + text (str): The input text. + + Returns: + str: Cleaned text. + """ + # Replace multiple newlines with a single newline + text = re.sub(r"\n+", "\n", text) + + # Replace multiple spaces with a single space + text = re.sub(r"\s+", " ", text) + + # Remove leading/trailing whitespace from each line + text = "\n".join(line.strip() for line in text.split("\n")) + + return text + + +def parse_plain_text_email(email_data: dict[str, Any]) -> dict[str, Any]: + """ + Parse email data and extract relevant information. + Only returns the plain text body. + + Args: + email_data (dict[str, Any]): Raw email data from Gmail API. + + Returns: + dict[str, str]: Parsed email details + """ + payload = email_data.get("payload", {}) + headers = {d["name"].lower(): d["value"] for d in payload.get("headers", [])} + + body_data = _get_email_plain_text_body(payload) + + email_details = { + "id": email_data.get("id", ""), + "thread_id": email_data.get("threadId", ""), + "label_ids": email_data.get("labelIds", []), + "history_id": email_data.get("historyId", ""), + "snippet": email_data.get("snippet", ""), + "to": headers.get("to", ""), + "cc": headers.get("cc", ""), + "from": headers.get("from", ""), + "reply_to": headers.get("reply-to", ""), + "in_reply_to": headers.get("in-reply-to", ""), + "references": headers.get("references", ""), + "header_message_id": headers.get("message-id", ""), + "date": headers.get("date", ""), + "subject": headers.get("subject", ""), + "body": body_data or "", + } + + return email_details + + +def build_reply_body(body: str, replying_to: dict[str, Any]) -> str: + attribution = f"On {replying_to['date']}, {replying_to['from']} wrote:" + lines = replying_to["plain_text_body"].split("\n") + quoted_plain = "\n".join([f"> {line}" for line in lines]) + return f"{body}\n\n{attribution}\n\n{quoted_plain}" + + +def parse_multipart_email(email_data: dict[str, Any]) -> dict[str, Any]: + """ + Parse email data and extract relevant information. + Returns the plain text and HTML body along with the images. + + Args: + email_data (Dict[str, Any]): Raw email data from Gmail API. + + Returns: + dict[str, Any]: Parsed email details + """ + + payload = email_data.get("payload", {}) + headers = {d["name"].lower(): d["value"] for d in payload.get("headers", [])} + + # Extract different parts of the email + plain_text_body = _get_email_plain_text_body(payload) + html_body = _get_email_html_body(payload) + + email_details = { + "id": email_data.get("id", ""), + "thread_id": email_data.get("threadId", ""), + "label_ids": email_data.get("labelIds", []), + "history_id": email_data.get("historyId", ""), + "snippet": email_data.get("snippet", ""), + "to": headers.get("to", ""), + "cc": headers.get("cc", ""), + "from": headers.get("from", ""), + "reply_to": headers.get("reply-to", ""), + "in_reply_to": headers.get("in-reply-to", ""), + "references": headers.get("references", ""), + "header_message_id": headers.get("message-id", ""), + "date": headers.get("date", ""), + "subject": headers.get("subject", ""), + "plain_text_body": plain_text_body or _clean_email_body(html_body), + "html_body": html_body or "", + } + + return email_details + + +def _get_email_html_body(payload: dict[str, Any]) -> str | None: + """ + Extract email html body from payload, handling 'multipart/alternative' parts. + + Args: + payload (Dict[str, Any]): Email payload data. + + Returns: + str | None: Decoded email body or None if not found. + """ + # Direct body extraction + if "body" in payload and payload["body"].get("data"): + return urlsafe_b64decode(payload["body"]["data"]).decode() + + # Handle multipart and alternative parts + return _extract_html_body(payload.get("parts", [])) diff --git a/toolkits/gmail/evals/eval_google_gmail.py b/toolkits/gmail/evals/eval_google_gmail.py new file mode 100644 index 00000000..7271ce78 --- /dev/null +++ b/toolkits/gmail/evals/eval_google_gmail.py @@ -0,0 +1,431 @@ +import json + +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + SimilarityCritic, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_gmail +from arcade_gmail.enums import GmailReplyToWhom +from arcade_gmail.tools import ( + get_thread, + list_emails_by_header, + list_threads, + reply_to_email, + search_threads, + send_email, + write_draft_reply_email, +) +from arcade_gmail.utils import DateRange + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.9, + warn_threshold=0.95, +) + + +catalog = ToolCatalog() +catalog.add_module(arcade_gmail) + + +@tool_eval() +def gmail_eval_suite() -> EvalSuite: + """Create an evaluation suite for Gmail tools.""" + suite = EvalSuite( + name="Gmail Tools Evaluation", + system_message="You are an AI assistant that can send and manage emails using the provided tools.", + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="Send email to user with clear username", + user_message="Send a email to johndoe@example.com saying 'Hello, can we meet at 3 PM?'. CC his boss janedoe@example.com", + expected_tool_calls=[ + ExpectedToolCall( + func=send_email, + args={ + "subject": "Meeting Request", + "body": "Hello, can we meet at 3 PM?", + "recipient": "johndoe@example.com", + "cc": ["janedoe@example.com"], + "bcc": None, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="subject", weight=0.125), + SimilarityCritic(critic_field="body", weight=0.25), + BinaryCritic(critic_field="recipient", weight=0.25), + BinaryCritic(critic_field="cc", weight=0.25), + BinaryCritic(critic_field="bcc", weight=0.125), + ], + ) + + suite.add_case( + name="Simple list threads", + user_message="Get 42 threads like right now i even wanna see the ones in my trash", + expected_tool_calls=[ + ExpectedToolCall( + func=list_threads, + args={"max_results": 42, "include_spam_trash": True}, + ) + ], + critics=[ + BinaryCritic(critic_field="max_results", weight=0.5), + BinaryCritic(critic_field="include_spam_trash", weight=0.5), + ], + ) + + history = [ + {"role": "user", "content": "list 1 thread"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_X8V5Hw9iJ3wfB8WMZf8omAMi", + "type": "function", + "function": {"name": "Google_ListThreads", "arguments": '{"max_results":1}'}, + } + ], + }, + { + "role": "tool", + "content": '{"next_page_token":"10321400718999360131","num_threads":1,"threads":[{"historyId":"61691","id":"1934a8f8deccb749","snippet":"Hi Joe, I hope this email finds you well. Thank you for being a part of our community."}]}', + "tool_call_id": "call_X8V5Hw9iJ3wfB8WMZf8omAMi", + "name": "Google_ListThreads", + }, + { + "role": "assistant", + "content": "Here is one email thread:\n\n- **Snippet:** Hi Joe, I hope this email finds you well. Thank you for being a part of our community.\n- **Thread ID:** 1934a8f8deccb749\n- **History ID:** 61691", + }, + ] + suite.add_case( + name="List threads with history", + user_message="Get the next 5 threads", + additional_messages=history, + expected_tool_calls=[ + ExpectedToolCall( + func=list_threads, + args={ + "max_results": 5, + "page_token": "10321400718999360131", + }, + ) + ], + critics=[ + BinaryCritic(critic_field="max_results", weight=0.2), + BinaryCritic(critic_field="page_token", weight=0.8), + ], + ) + + suite.add_case( + name="Search threads", + user_message="Search for threads from johndoe@example.com to janedoe@example.com about that talk about 'Arcade AI' from yesterday", + expected_tool_calls=[ + ExpectedToolCall( + func=search_threads, + args={ + "sender": "johndoe@example.com", + "recipient": "janedoe@example.com", + "body": "Arcade AI", + "date_range": DateRange.YESTERDAY, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="sender", weight=0.25), + BinaryCritic(critic_field="recipient", weight=0.25), + SimilarityCritic(critic_field="body", weight=0.25), + BinaryCritic(critic_field="date_range", weight=0.25), + ], + ) + + suite.add_case( + name="Get a thread by ID", + user_message="Get the thread r-124325435467568867667878874565464564563523424323524235242412", + expected_tool_calls=[ + ExpectedToolCall( + func=get_thread, + args={ + "thread_id": "r-124325435467568867667878874565464564563523424323524235242412", + }, + ) + ], + critics=[ + BinaryCritic(critic_field="thread_id", weight=1.0), + ], + ) + + return suite + + +@tool_eval() +def gmail_reply_eval_suite() -> EvalSuite: + """Create an evaluation suite for Gmail reply tools.""" + suite = EvalSuite( + name="Gmail Reply Tools Evaluation", + system_message="You are an AI assistant that can send and manage emails using the provided tools.", + catalog=catalog, + rubric=rubric, + ) + + email_history = [ + {"role": "user", "content": "get the latest emails I received from johndoe@gmail.com"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_jowMD7aB9sVPClOfvNof7Llu", + "type": "function", + "function": { + "name": "Google_ListEmailsByHeader", + "arguments": json.dumps({ + "sender": "johndoe@gmail.com", + "max_results": 5, + }), + }, + } + ], + }, + { + "role": "tool", + "content": json.dumps({ + "emails": [ + { + "body": "test 1", + "cc": "", + "date": "Tue, 11 Feb 2025 11:33:08 -0300", + "from": "John Doe ", + "header_message_id": "", + "history_id": "123456", + "id": "q34759q435nv", + "in_reply_to": "", + "label_ids": ["INBOX"], + "references": "", + "reply_to": "", + "snippet": "test 1", + "subject": "test 1", + "thread_id": "345y6v3596", + "to": "myself@gmail.com", + }, + { + "body": "test 2", + "cc": "", + "date": "Mon, 20 Jan 2025 13:04:42 -0800", + "from": "John Doe ", + "header_message_id": "<28745ytvw8745ct4@mail.gmail.com>", + "history_id": "3456758", + "id": "9475tvy24578yx", + "in_reply_to": "", + "label_ids": [], + "references": "", + "reply_to": "", + "snippet": "test 2", + "subject": "test 2", + "thread_id": "249576v3496", + "to": "myself@gmail.com", + }, + ] + }), + "tool_call_id": "call_jowMD7aB9sVPClOfvNof7Llu", + "name": "Google_ListEmailsByHeader", + }, + { + "role": "assistant", + "content": "Here are the latest emails you received from johndoe@gmail.com:\n\n1. **Subject**: test 1\n - **Date**: Tue, 11 Feb 2025 11:33:08 -0300\n - **Snippet**: test 1\n\n2. **Subject**: test 2\n - **Date**: Mon, 20 Jan 2025 13:04:42 -0800\n - **Snippet**: test 2\n\nIf you need further details from any specific email, let me know!", + }, + ] + + suite.add_case( + name="Reply to an email", + user_message="Reply to the email from johndoe@example.com about 'test 2' saying 'tested and working well'", + expected_tool_calls=[ + ExpectedToolCall( + func=reply_to_email, + args={ + "reply_to_message_id": "9475tvy24578yx", + "body": "tested and working well", + "reply_to_whom": GmailReplyToWhom.ONLY_THE_SENDER.value, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="subject", weight=1 / 7), + SimilarityCritic(critic_field="body", weight=1 / 7), + BinaryCritic(critic_field="recipient", weight=1 / 7), + BinaryCritic(critic_field="cc", weight=1 / 7), + BinaryCritic(critic_field="bcc", weight=1 / 7), + BinaryCritic(critic_field="reply_to_whom", weight=1 / 7), + BinaryCritic(critic_field="reply_to_message_id", weight=1 / 7), + ], + additional_messages=email_history, + ) + + suite.add_case( + name="Reply to an email with every recipient", + user_message="Reply to every recipient in the email from johndoe@example.com about 'test 2' saying 'tested and working well'", + expected_tool_calls=[ + ExpectedToolCall( + func=reply_to_email, + args={ + "reply_to_message_id": "9475tvy24578yx", + "body": "tested and working well", + "reply_to_whom": GmailReplyToWhom.EVERY_RECIPIENT.value, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="subject", weight=1 / 7), + SimilarityCritic(critic_field="body", weight=1 / 7), + BinaryCritic(critic_field="recipient", weight=1 / 7), + BinaryCritic(critic_field="cc", weight=1 / 7), + BinaryCritic(critic_field="bcc", weight=1 / 7), + BinaryCritic(critic_field="reply_to_whom", weight=1 / 7), + BinaryCritic(critic_field="reply_to_message_id", weight=1 / 7), + ], + additional_messages=email_history, + ) + + suite.add_case( + name="Reply to an email with bcc", + user_message="Reply to the email from johndoe@example.com about 'test 2' saying 'tested and working well' and send it to janedoe@example.com as bcc as well", + expected_tool_calls=[ + ExpectedToolCall( + func=reply_to_email, + args={ + "reply_to_message_id": "9475tvy24578yx", + "body": "tested and working well", + "bcc": ["janedoe@example.com"], + "reply_to_whom": GmailReplyToWhom.ONLY_THE_SENDER.value, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="subject", weight=1 / 7), + SimilarityCritic(critic_field="body", weight=1 / 7), + BinaryCritic(critic_field="recipient", weight=1 / 7), + BinaryCritic(critic_field="cc", weight=1 / 7), + BinaryCritic(critic_field="bcc", weight=1 / 7), + BinaryCritic(critic_field="reply_to_whom", weight=1 / 7), + BinaryCritic(critic_field="reply_to_message_id", weight=1 / 7), + ], + additional_messages=email_history, + ) + + suite.add_case( + name="Write draft reply", + user_message="Write a draft reply to the email from johndoe@example.com about 'test 2' saying 'tested and working well'", + expected_tool_calls=[ + ExpectedToolCall( + func=write_draft_reply_email, + args={ + "reply_to_message_id": "9475tvy24578yx", + "body": "tested and working well", + "reply_to_whom": GmailReplyToWhom.ONLY_THE_SENDER.value, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="subject", weight=1 / 7), + SimilarityCritic(critic_field="body", weight=1 / 7), + BinaryCritic(critic_field="recipient", weight=1 / 7), + BinaryCritic(critic_field="cc", weight=1 / 7), + BinaryCritic(critic_field="bcc", weight=1 / 7), + BinaryCritic(critic_field="reply_to_message_id", weight=1 / 7), + BinaryCritic(critic_field="reply_to_whom", weight=1 / 7), + ], + additional_messages=email_history, + ) + + suite.add_case( + name="Write draft reply to every recipient", + user_message="Write a draft reply to every recipient in the email from johndoe@example.com about 'test 2' saying 'tested and working well'", + expected_tool_calls=[ + ExpectedToolCall( + func=write_draft_reply_email, + args={ + "reply_to_message_id": "9475tvy24578yx", + "body": "tested and working well", + "reply_to_whom": GmailReplyToWhom.EVERY_RECIPIENT.value, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="subject", weight=1 / 7), + SimilarityCritic(critic_field="body", weight=1 / 7), + BinaryCritic(critic_field="recipient", weight=1 / 7), + BinaryCritic(critic_field="cc", weight=1 / 7), + BinaryCritic(critic_field="bcc", weight=1 / 7), + BinaryCritic(critic_field="reply_to_whom", weight=0.125), + BinaryCritic(critic_field="reply_to_message_id", weight=1 / 7), + ], + additional_messages=email_history, + ) + + return suite + + +@tool_eval() +def gmail_list_emails_by_header_eval_suite() -> EvalSuite: + """Create an evaluation suite for Gmail tools.""" + suite = EvalSuite( + name="Gmail list_emails_by_header tool evaluation", + system_message="You are an AI assistant that can send and manage emails using the provided tools.", + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="List emails by header using date-range", + user_message="List all emails from johndoe@example.com to janedoe@example.com about 'Arcade AI' from yesterday", + expected_tool_calls=[ + ExpectedToolCall( + func=list_emails_by_header, + args={ + "sender": "johndoe@example.com", + "recipient": "janedoe@example.com", + "subject": "Arcade AI", + "date_range": DateRange.YESTERDAY.value, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="sender", weight=1 / 4), + BinaryCritic(critic_field="recipient", weight=1 / 4), + SimilarityCritic(critic_field="subject", weight=1 / 4), + BinaryCritic(critic_field="date_range", weight=1 / 4), + ], + ) + + suite.add_case( + name="List emails by header using date-range", + user_message="List all emails from johndoe@example.com to janedoe@example.com about 'Arcade AI' from the last month", + expected_tool_calls=[ + ExpectedToolCall( + func=list_emails_by_header, + args={ + "sender": "johndoe@example.com", + "recipient": "janedoe@example.com", + "subject": "Arcade AI", + "date_range": DateRange.LAST_MONTH.value, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="sender", weight=1 / 4), + BinaryCritic(critic_field="recipient", weight=1 / 4), + SimilarityCritic(critic_field="subject", weight=1 / 4), + BinaryCritic(critic_field="date_range", weight=1 / 4), + ], + ) + + return suite diff --git a/toolkits/gmail/pyproject.toml b/toolkits/gmail/pyproject.toml new file mode 100644 index 00000000..f34b4d98 --- /dev/null +++ b/toolkits/gmail/pyproject.toml @@ -0,0 +1,64 @@ +[build-system] +requires = [ "hatchling",] +build-backend = "hatchling.build" + +[project] +name = "arcade_gmail" +version = "2.0.0" +description = "Arcade.dev LLM tools for Gmail" +requires-python = ">=3.10" +dependencies = [ + "arcade-tdk>=2.0.0,<3.0.0", + "beautifulsoup4>=4.10.0,<5.0.0", + "google-api-core>=2.19.1,<3.0.0", + "google-api-python-client>=2.137.0,<3.0.0", + "google-auth>=2.32.0,<3.0.0", + "google-auth-httplib2>=0.2.0,<1.0.0", + "googleapis-common-protos>=1.63.2,<2.0.0", +] +[[project.authors]] +name = "Arcade" +email = "dev@arcade.dev" + + +[project.optional-dependencies] +dev = [ + "arcade-ai[evals]>=2.0.0rc1,<3.0.0", + "arcade-serve>=2.0.0,<3.0.0", + "pytest>=8.3.0,<8.4.0", + "pytest-cov>=4.0.0,<4.1.0", + "pytest-mock>=3.11.1,<3.12.0", + "pytest-asyncio>=0.24.0,<0.25.0", + "mypy>=1.5.1,<1.6.0", + "pre-commit>=3.4.0,<3.5.0", + "tox>=4.11.1,<4.12.0", + "ruff>=0.7.4,<0.8.0", +] + +# Use local path sources for arcade libs when working locally +[tool.uv.sources] +arcade-ai = { path = "../../", editable = true } +arcade-serve = { path = "../../libs/arcade-serve/", editable = true } +arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } + + +[tool.mypy] +files = [ "arcade_gmail/**/*.py",] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes = "True" +ignore_missing_imports = "True" + +[tool.pytest.ini_options] +testpaths = [ "tests",] + +[tool.coverage.report] +skip_empty = true + +[tool.hatch.build.targets.wheel] +packages = [ "arcade_gmail",] diff --git a/toolkits/gmail/tests/__init__.py b/toolkits/gmail/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/toolkits/gmail/tests/test_gmail.py b/toolkits/gmail/tests/test_gmail.py new file mode 100644 index 00000000..a1873c2b --- /dev/null +++ b/toolkits/gmail/tests/test_gmail.py @@ -0,0 +1,951 @@ +from base64 import urlsafe_b64encode +from email.message import EmailMessage +from unittest.mock import MagicMock, patch + +import pytest +from arcade_tdk import ToolAuthorizationContext, ToolContext +from arcade_tdk.errors import ToolExecutionError +from googleapiclient.errors import HttpError + +from arcade_gmail.enums import GmailReplyToWhom +from arcade_gmail.tools import ( + delete_draft_email, + get_thread, + list_draft_emails, + list_emails, + list_emails_by_header, + list_threads, + reply_to_email, + search_threads, + send_draft_email, + send_email, + trash_email, + update_draft_email, + write_draft_email, +) +from arcade_gmail.utils import ( + build_reply_body, + parse_draft_email, + parse_multipart_email, + parse_plain_text_email, +) + + +@pytest.fixture +def mock_context(): + mock_auth = ToolAuthorizationContext(token="fake-token") # noqa: S106 + return ToolContext(authorization=mock_auth) + + +@pytest.mark.asyncio +@patch("arcade_gmail.tools.gmail._build_gmail_service") +async def test_send_email(mock_build, mock_context): + mock_service = MagicMock() + mock_build.return_value = mock_service + + # Test happy path + result = await send_email( + context=mock_context, + subject="Test Subject", + body="Test Body", + recipient="test@example.com", + ) + + assert isinstance(result, dict) + assert "id" in result + assert "thread_id" in result + assert "subject" in result + assert "body" in result + + # Test http error + mock_service.users().messages().send().execute.side_effect = HttpError( + resp=MagicMock(status=400), + content=b'{"error": {"message": "Invalid recipient"}}', + ) + + with pytest.raises(ToolExecutionError): + await send_email( + context=mock_context, + subject="Test Subject", + body="Test Body", + recipient="invalid@example.com", + ) + + +@pytest.mark.asyncio +@patch("arcade_gmail.tools.gmail._build_gmail_service") +async def test_write_draft_email(mock_build, mock_context): + mock_service = MagicMock() + mock_build.return_value = mock_service + + # Test happy path + result = await write_draft_email( + context=mock_context, + subject="Test Draft Subject", + body="Test Draft Body", + recipient="draft@example.com", + ) + + assert isinstance(result, dict) + assert "id" in result + assert "thread_id" in result + assert "subject" in result + assert "body" in result + + # Test http error + mock_service.users().drafts().create().execute.side_effect = HttpError( + resp=MagicMock(status=400), + content=b'{"error": {"message": "Invalid request"}}', + ) + + with pytest.raises(ToolExecutionError): + await write_draft_email( + context=mock_context, + subject="Test Draft Subject", + body="Test Draft Body", + recipient="draft@example.com", + ) + + +@pytest.mark.asyncio +@patch("arcade_gmail.tools.gmail._build_gmail_service") +async def test_update_draft_email(mock_build, mock_context): + mock_service = MagicMock() + mock_build.return_value = mock_service + + # Test happy path + result = await update_draft_email( + context=mock_context, + draft_email_id="draft123", + subject="Updated Subject", + body="Updated Body", + recipient="updated@example.com", + ) + + assert isinstance(result, dict) + assert "id" in result + assert "thread_id" in result + assert "subject" in result + assert "body" in result + + # Test http error + mock_service.users().drafts().update().execute.side_effect = HttpError( + resp=MagicMock(status=400), + content=b'{"error": {"message": "Draft not found"}}', + ) + + with pytest.raises(ToolExecutionError): + await update_draft_email( + context=mock_context, + draft_email_id="nonexistent_draft", + subject="Updated Subject", + body="Updated Body", + recipient="updated@example.com", + ) + + +@pytest.mark.asyncio +@patch("arcade_gmail.tools.gmail._build_gmail_service") +async def test_send_draft_email(mock_build, mock_context): + mock_service = MagicMock() + mock_build.return_value = mock_service + + # Test happy path + result = await send_draft_email(context=mock_context, email_id="draft456") + + assert isinstance(result, dict) + assert "id" in result + assert "thread_id" in result + assert "subject" in result + assert "body" in result + + # Test http error + mock_service.users().drafts().send().execute.side_effect = HttpError( + resp=MagicMock(status=400), + content=b'{"error": {"message": "Draft not found"}}', + ) + + with pytest.raises(ToolExecutionError): + await send_draft_email(context=mock_context, email_id="nonexistent_draft") + + +@pytest.mark.asyncio +@patch("arcade_gmail.tools.gmail._build_gmail_service") +async def test_delete_draft_email(mock_build, mock_context): + mock_service = MagicMock() + mock_build.return_value = mock_service + + # Test happy path + result = await delete_draft_email(context=mock_context, draft_email_id="draft789") + + assert "Draft email with ID" in result + assert "deleted successfully" in result + + # Test http error + mock_service.users().drafts().delete().execute.side_effect = HttpError( + resp=MagicMock(status=400), + content=b'{"error": {"message": "Draft not found"}}', + ) + + with pytest.raises(ToolExecutionError): + await delete_draft_email(context=mock_context, draft_email_id="nonexistent_draft") + + +@pytest.mark.asyncio +@patch("arcade_gmail.tools.gmail._build_gmail_service") +@patch("arcade_gmail.tools.gmail.parse_draft_email") +async def test_get_draft_emails(mock_parse_draft_email, mock_build, mock_context): + # Setup test data + mock_drafts_list_response = { + "drafts": [ + { + "id": "r9999999999999999999", + "message": {"id": "0000000000000000", "threadId": "0000000000000000"}, + } + ], + "resultSizeEstimate": 1, + } + mock_drafts_get_response = { + "id": "r9999999999999999999", + "message": { + "id": "0000000000000000", + "threadId": "0000000000000000", + "labelIds": ["DRAFT"], + "snippet": "Hello! This is a test. Best regards, John", + "payload": { + "partId": "", + "mimeType": "text/plain", + "filename": "", + "headers": [ + {"name": "to", "value": "test@arcade-ai.com"}, + {"name": "subject", "value": "New Draft"}, + {"name": "Date", "value": "Mon, 16 Sep 2024 13:02:10 -0400"}, + {"name": "From", "value": "john-doe@arcade-ai.com"}, + ], + "body": { + "size": 41, + "data": "SGVsbG8hIFRoaXMgaXMgYSB0ZXN0LgoKQmVzdCByZWdhcmRzLApCb2I=", + }, + }, + "sizeEstimate": 453, + "historyId": "7061", + "internalDate": "1726506130000", + }, + } + + # Setup mocking + mock_service = MagicMock() + mock_build.return_value = mock_service + + # Mock the response from the Gmail list drafts API + mock_service.users().drafts().list().execute.return_value = mock_drafts_list_response + + # Mock the response from the Gmail get drafts API + mock_service.users().drafts().get().execute.return_value = mock_drafts_get_response + + # Mock the parse_draft_email function since parse_draft_email doesn't accept object of type MagicMock + mock_parse_draft_email.return_value = parse_draft_email(mock_drafts_get_response) + + # Test happy path + result = await list_draft_emails(context=mock_context, n_drafts=2) + + assert isinstance(result, dict) + assert "emails" in result + assert len(result["emails"]) == 1 + assert all("id" in draft and "subject" in draft for draft in result["emails"]) + + # Test http error + mock_service.users().drafts().list().execute.side_effect = HttpError( + resp=MagicMock(status=400), + content=b'{"error": {"message": "Invalid request"}}', + ) + + with pytest.raises(ToolExecutionError): + await list_draft_emails(context=mock_context, n_drafts=2) + + +@pytest.mark.asyncio +@patch("arcade_gmail.tools.gmail._build_gmail_service") +@patch("arcade_gmail.tools.gmail.parse_plain_text_email") +async def test_search_emails_by_header(mock_parse_plain_text_email, mock_build, mock_context): + # Setup test data + mock_messages_list_response = { + "messages": [ + {"id": "191fbc8ddce0f433", "threadId": "191fbc8ddce0f433"}, + {"id": "191fbc0ea11efa90", "threadId": "191fbc0ea11efa90"}, + ], + "nextPageToken": "00755945214480102915", + "resultSizeEstimate": 201, + } + mock_messages_get_response = { + "id": "191f2cf4d24bf23d", + "threadId": "191f2cf4d24bf23d", + "labelIds": ["UNREAD", "IMPORTANT", "CATEGORY_UPDATES", "INBOX"], + "snippet": "Hey User, Your personal access token (classic) "ArcadeAI" with admin:enterprise, admin:gpg_key, admin:org, admin:org_hook, admin:public_key, admin:repo_hook, admin:ssh_signing_key,", + "payload": { + "partId": "", + "mimeType": "text/plain", + "filename": "", + "headers": [ + {"name": "Delivered-To", "value": "example@arcade-ai.com"}, + {"name": "Date", "value": "Sat, 14 Sep 2024 16:12:37 -0700"}, + {"name": "From", "value": "GitHub \u003cnoreply@github.com\u003e"}, + {"name": "To", "value": "example@arcade-ai.com"}, + { + "name": "Subject", + "value": "[GitHub] Your personal access token (classic) has expired", + }, + ], + "body": { + "size": 605, + "data": "SGV5IEBFcmljR3VzdGluLA0KDQpZb3VyIHBlcnNvbmFsIGFjY2VzcyB0b2tlbiAoY2xhc3NpYykgIkFyY2FkZUFJIiB3aXRoIGFkbWluOmVudGVycHJpc2UsIGFkbWluOmdwZ19rZXksIGFkbWluOm9yZywgYWRtaW46b3JnX2hvb2ssIGFkbWluOnB1YmxpY19rZXksIGFkbWluOnJlcG9faG9vaywgYWRtaW46c3NoX3NpZ25pbmdfa2V5LCBhdWRpdF9sb2csIGNvZGVzcGFjZSwgY29waWxvdCwgZGVsZXRlOnBhY2thZ2VzLCBkZWxldGVfcmVwbywgZ2lzdCwgbm90aWZpY2F0aW9ucywgcHJvamVjdCwgcmVwbywgdXNlciwgd29ya2Zsb3csIHdyaXRlOmRpc2N1c3Npb24sIGFuZCB3cml0ZTpwYWNrYWdlcyBzY29wZXMgaGFzIGV4cGlyZWQuDQoNCklmIHRoaXMgdG9rZW4gaXMgc3RpbGwgbmVlZGVkLCB2aXNpdCBodHRwczovL2dpdGh1Yi5jb20vc2V0dGluZ3MvdG9rZW5zLzE3MTM2OTg2MTMvcmVnZW5lcmF0ZSB0byBnZW5lcmF0ZSBhbiBlcXVpdmFsZW50Lg0KDQpJZiB5b3UgcnVuIGludG8gcHJvYmxlbXMsIHBsZWFzZSBjb250YWN0IHN1cHBvcnQgYnkgdmlzaXRpbmcgaHR0cHM6Ly9naXRodWIuY29tL2NvbnRhY3QNCg0KVGhhbmtzLA0KVGhlIEdpdEh1YiBUZWFtDQo=", + }, + }, + "sizeEstimate": 4512, + "historyId": "5508", + "internalDate": "1726355557000", + } + + # Setup mocking + mock_service = MagicMock() + mock_build.return_value = mock_service + + # Mock the response from the Gmail list messages API + mock_service.users().messages().list().execute.return_value = mock_messages_list_response + + # Mock the response from the Gmail get messages API + mock_service.users().messages().get().execute.return_value = mock_messages_get_response + + # Mock the parse_plain_text_email function since parse_plain_text_email doesn't accept object of type MagicMock + mock_parse_plain_text_email.return_value = parse_plain_text_email(mock_messages_get_response) + + # Test happy path + result = await list_emails_by_header( + context=mock_context, sender="noreply@github.com", max_results=2 + ) + + assert isinstance(result, dict) + assert "emails" in result + assert len(result["emails"]) == 2 + assert all("id" in email and "subject" in email for email in result["emails"]) + + # Test http error + mock_service.users().messages().list().execute.side_effect = HttpError( + resp=MagicMock(status=400), + content=b'{"error": {"message": "Invalid request"}}', + ) + + with pytest.raises(ToolExecutionError): + await list_emails_by_header( + context=mock_context, sender="noreply@github.com", max_results=2 + ) + + +@pytest.mark.asyncio +@patch("arcade_gmail.tools.gmail._build_gmail_service") +@patch("arcade_gmail.tools.gmail.parse_plain_text_email") +async def test_get_emails(mock_parse_plain_text_email, mock_build, mock_context): + # Setup test data + mock_messages_list_response = { + "messages": [ + {"id": "191fbc8ddce0f433", "threadId": "191fbc8ddce0f433"}, + ], + "nextPageToken": "00755945214480102915", + "resultSizeEstimate": 1, + } + mock_messages_get_response = { + "id": "191f2cf4d24bf23d", + "threadId": "191f2cf4d24bf23d", + "labelIds": ["UNREAD", "IMPORTANT", "CATEGORY_UPDATES", "INBOX"], + "snippet": "Hey User, Your personal access token (classic) "ArcadeAI" with admin:enterprise, admin:gpg_key, admin:org, admin:org_hook, admin:public_key, admin:repo_hook, admin:ssh_signing_key,", + "payload": { + "partId": "", + "mimeType": "text/plain", + "filename": "", + "headers": [ + {"name": "Delivered-To", "value": "example@arcade-ai.com"}, + {"name": "Date", "value": "Sat, 14 Sep 2024 16:12:37 -0700"}, + {"name": "From", "value": "GitHub \u003cnoreply@github.com\u003e"}, + {"name": "To", "value": "example@arcade-ai.com"}, + { + "name": "Subject", + "value": "[GitHub] Your personal access token (classic) has expired", + }, + ], + "body": { + "size": 605, + "data": "SGV5IEBFcmljR3VzdGluLA0KDQpZb3VyIHBlcnNvbmFsIGFjY2VzcyB0b2tlbiAoY2xhc3NpYykgIkFyY2FkZUFJIiB3aXRoIGFkbWluOmVudGVycHJpc2UsIGFkbWluOmdwZ19rZXksIGFkbWluOm9yZywgYWRtaW46b3JnX2hvb2ssIGFkbWluOnB1YmxpY19rZXksIGFkbWluOnJlcG9faG9vaywgYWRtaW46c3NoX3NpZ25pbmdfa2V5LCBhdWRpdF9sb2csIGNvZGVzcGFjZSwgY29waWxvdCwgZGVsZXRlOnBhY2thZ2VzLCBkZWxldGVfcmVwbywgZ2lzdCwgbm90aWZpY2F0aW9ucywgcHJvamVjdCwgcmVwbywgdXNlciwgd29ya2Zsb3csIHdyaXRlOmRpc2N1c3Npb24sIGFuZCB3cml0ZTpwYWNrYWdlcyBzY29wZXMgaGFzIGV4cGlyZWQuDQoNCklmIHRoaXMgdG9rZW4gaXMgc3RpbGwgbmVlZGVkLCB2aXNpdCBodHRwczovL2dpdGh1Yi5jb20vc2V0dGluZ3MvdG9rZW5zLzE3MTM2OTg2MTMvcmVnZW5lcmF0ZSB0byBnZW5lcmF0ZSBhbiBlcXVpdmFsZW50Lg0KDQpJZiB5b3UgcnVuIGludG8gcHJvYmxlbXMsIHBsZWFzZSBjb250YWN0IHN1cHBvcnQgYnkgdmlzaXRpbmcgaHR0cHM6Ly9naXRodWIuY29tL2NvbnRhY3QNCg0KVGhhbmtzLA0KVGhlIEdpdEh1YiBUZWFtDQo=", + }, + }, + "sizeEstimate": 4512, + "historyId": "5508", + "internalDate": "1726355557000", + } + + # Setup mocking + mock_service = MagicMock() + mock_build.return_value = mock_service + + # Mock the response from the Gmail list messages API + mock_service.users().messages().list().execute.return_value = mock_messages_list_response + + # Mock the Gmail get messages API + mock_service.users().messages().get().execute.return_value = mock_messages_get_response + + # Mock the parse_plain_text_email function since parse_plain_text_email doesn't accept object of type MagicMock + mock_parse_plain_text_email.return_value = parse_plain_text_email(mock_messages_get_response) + + # Test happy path + result = await list_emails(context=mock_context, n_emails=1) + + assert isinstance(result, dict) + assert "emails" in result + assert len(result["emails"]) == 1 + assert "id" in result["emails"][0] + assert "subject" in result["emails"][0] + assert "date" in result["emails"][0] + assert "body" in result["emails"][0] + + # Test http error + mock_service.users().messages().list().execute.side_effect = HttpError( + resp=MagicMock(status=400), + content=b'{"error": {"message": "Invalid request"}}', + ) + + with pytest.raises(ToolExecutionError): + await list_emails(context=mock_context, n_emails=1) + + +@pytest.mark.asyncio +@patch("arcade_gmail.tools.gmail._build_gmail_service") +async def test_trash_email(mock_build, mock_context): + mock_service = MagicMock() + mock_build.return_value = mock_service + + # Test happy path + email_id = "123456" + result = await trash_email(context=mock_context, email_id=email_id) + + assert isinstance(result, dict) + assert "id" in result + assert "thread_id" in result + assert "subject" in result + assert "body" in result + + # Test http error + mock_service.users().messages().trash().execute.side_effect = HttpError( + resp=MagicMock(status=400), + content=b'{"error": {"message": "Email not found"}}', + ) + + with pytest.raises(ToolExecutionError): + await trash_email(context=mock_context, email_id="nonexistent_email") + + +@pytest.mark.asyncio +@patch("arcade_gmail.tools.gmail._build_gmail_service") +async def test_search_threads(mock_build, mock_context): + mock_service = MagicMock() + mock_build.return_value = mock_service + + # Setup mock response data + mock_threads_list_response = { + "threads": [ + { + "id": "thread1", + "snippet": "Thread snippet 1", + }, + { + "id": "thread2", + "snippet": "Thread snippet 2", + }, + ], + "nextPageToken": "next_token_123", + "resultSizeEstimate": 2, + } + + # Mock the Gmail API threads().list() method + mock_service.users().threads().list().execute.return_value = mock_threads_list_response + + # Test happy path + result = await search_threads( + context=mock_context, + sender="test@example.com", + max_results=2, + ) + + assert isinstance(result, dict) + assert "threads" in result + assert len(result["threads"]) == 2 + assert result["threads"][0]["id"] == "thread1" + assert "next_page_token" in result + + # Test error handling + mock_service.users().threads().list().execute.side_effect = HttpError( + resp=MagicMock(status=400), + content=b'{"error": {"message": "Invalid request"}}', + ) + + with pytest.raises(ToolExecutionError): + await search_threads( + context=mock_context, + sender="test@example.com", + max_results=2, + ) + + +@pytest.mark.asyncio +@patch("arcade_gmail.tools.gmail._build_gmail_service") +async def test_list_threads(mock_build, mock_context): + mock_service = MagicMock() + mock_build.return_value = mock_service + + # Setup mock response data + mock_threads_list_response = { + "threads": [ + { + "id": "thread1", + "snippet": "Thread snippet 1", + }, + { + "id": "thread2", + "snippet": "Thread snippet 2", + }, + ], + "nextPageToken": "next_token_123", + "resultSizeEstimate": 2, + } + + # Mock the Gmail API threads().list() method + mock_service.users().threads().list().execute.return_value = mock_threads_list_response + + # Test happy path + result = await list_threads( + context=mock_context, + max_results=2, + ) + + assert isinstance(result, dict) + assert "threads" in result + assert len(result["threads"]) == 2 + assert result["threads"][0]["id"] == "thread1" + assert "next_page_token" in result + + # Test error handling + mock_service.users().threads().list().execute.side_effect = HttpError( + resp=MagicMock(status=400), + content=b'{"error": {"message": "Invalid request"}}', + ) + + with pytest.raises(ToolExecutionError): + await list_threads( + context=mock_context, + max_results=2, + ) + + +@pytest.mark.asyncio +@patch("arcade_gmail.tools.gmail._build_gmail_service") +async def test_get_thread(mock_build, mock_context): + mock_service = MagicMock() + mock_build.return_value = mock_service + + # Setup mock response data + mock_thread_get_response = { + "id": "thread1", + "messages": [ + { + "id": "message1", + "snippet": "Message snippet 1", + }, + { + "id": "message2", + "snippet": "Message snippet 2", + }, + ], + } + + # Mock the Gmail API threads().get() method + mock_service.users().threads().get().execute.return_value = mock_thread_get_response + + # Test happy path + result = await get_thread( + context=mock_context, + thread_id="thread1", + ) + + assert isinstance(result, dict) + assert "id" in result + assert result["id"] == "thread1" + assert "messages" in result + assert len(result["messages"]) == 2 + assert result["messages"][0]["id"] == "message1" + + # Test error handling + mock_service.users().threads().get().execute.side_effect = HttpError( + resp=MagicMock(status=404), + content=b'{"error": {"message": "Thread not found"}}', + ) + + with pytest.raises(ToolExecutionError): + await get_thread( + context=mock_context, + thread_id="invalid_thread", + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "reply_to_whom, expected_to, expected_cc", + [ + ( + GmailReplyToWhom.EVERY_RECIPIENT, + "sender@example.com, to1@example.com, to2@example.com", + "cc1@example.com, cc2@example.com", + ), + ( + GmailReplyToWhom.ONLY_THE_SENDER, + "sender@example.com", + "", + ), + ], +) +@patch("arcade_gmail.tools.gmail._build_gmail_service") +async def test_reply_to_email(mock_build, reply_to_whom, expected_to, expected_cc, mock_context): + mock_service = MagicMock() + mock_build.return_value = mock_service + + original_message = { + "id": "id123456", + "threadId": "thread123456", + "payload": { + "headers": [ + {"name": "Message-ID", "value": "id123456"}, + {"name": "Subject", "value": "test"}, + {"name": "From", "value": "sender@example.com"}, + {"name": "To", "value": "to1@example.com, to2@example.com, test@example.com"}, + {"name": "Cc", "value": "cc1@example.com, cc2@example.com"}, + {"name": "References", "value": "thread123456"}, + ], + }, + } + + mock_service.users().getProfile().execute.return_value = {"emailAddress": "test@example.com"} + mock_service.users().messages().get().execute.return_value = original_message + + result = await reply_to_email( + context=mock_context, + body="test", + reply_to_message_id="id123456", + reply_to_whom=reply_to_whom, + ) + + assert isinstance(result, dict) + assert "url" in result + + replying_to = parse_multipart_email(original_message) + expected_body = build_reply_body("test", replying_to) + + expected_message = EmailMessage() + expected_message.set_content(expected_body) + expected_message["To"] = expected_to + expected_message["Subject"] = "Re: test" + if expected_cc: + expected_message["Cc"] = expected_cc + expected_message["In-Reply-To"] = "id123456" + expected_message["References"] = "id123456, thread123456" + + mock_service.users().messages().send.assert_called_once_with( + userId="me", + body={ + "raw": urlsafe_b64encode(expected_message.as_bytes()).decode(), + "threadId": "thread123456", + }, + ) + + +def test_parse_multipart_email_full(): + """ + Test parsing a multipart email with both plain text and HTML bodies. + """ + email_data = { + "id": "email123", + "threadId": "thread123", + "labelIds": ["INBOX", "UNREAD"], + "historyId": "history123", + "snippet": "This is a test email.", + "payload": { + "headers": [ + {"name": "To", "value": "recipient@example.com"}, + {"name": "From", "value": "sender@example.com"}, + {"name": "Subject", "value": "Test Email"}, + {"name": "Date", "value": "Mon, 1 Jan 2024 10:00:00 -0000"}, + ], + "body": {"size": 100, "data": "VGhpcyBpcyBhIHRlc3QgZW1haWwu"}, + }, + } + + with ( + patch("arcade_gmail.utils._get_email_plain_text_body") as mock_plain, + patch("arcade_gmail.utils._get_email_html_body") as mock_html, + patch("arcade_gmail.utils._clean_email_body") as mock_clean, + ): + # Mock the helper functions + mock_plain.return_value = "This is a test email." + mock_html.return_value = "

This is a test email.

" + mock_clean.return_value = "This is a test email." + + result = parse_multipart_email(email_data) + + assert result["id"] == "email123" + assert result["thread_id"] == "thread123" + assert result["label_ids"] == ["INBOX", "UNREAD"] + assert result["snippet"] == "This is a test email." + assert result["to"] == "recipient@example.com" + assert result["from"] == "sender@example.com" + assert result["subject"] == "Test Email" + assert result["date"] == "Mon, 1 Jan 2024 10:00:00 -0000" + assert result["plain_text_body"] == "This is a test email." + assert result["html_body"] == "

This is a test email.

" + + +def test_parse_multipart_email_plain_only(): + """ + Test parsing an email with only a plain text body. + """ + email_data = { + "id": "email456", + "threadId": "thread456", + "labelIds": ["INBOX"], + "historyId": "history456", + "snippet": "Plain text only email.", + "payload": { + "headers": [ + {"name": "To", "value": "recipient2@example.com"}, + {"name": "From", "value": "sender2@example.com"}, + {"name": "Subject", "value": "Plain Text Email"}, + {"name": "Date", "value": "Tue, 2 Feb 2024 11:00:00 -0000"}, + ], + "body": {"size": 150, "data": "UGxhaW4gdGV4dCBvbmx5IGVtYWlsLg=="}, + }, + } + + with ( + patch("arcade_gmail.utils._get_email_plain_text_body") as mock_plain, + patch("arcade_gmail.utils._get_email_html_body") as mock_html, + patch("arcade_gmail.utils._clean_email_body") as mock_clean, + ): + # Mock the helper functions + mock_plain.return_value = "Plain text only email." + mock_html.return_value = None + mock_clean.return_value = "Plain text only email." + + result = parse_multipart_email(email_data) + + assert result["id"] == "email456" + assert result["thread_id"] == "thread456" + assert result["label_ids"] == ["INBOX"] + assert result["snippet"] == "Plain text only email." + assert result["to"] == "recipient2@example.com" + assert result["from"] == "sender2@example.com" + assert result["subject"] == "Plain Text Email" + assert result["date"] == "Tue, 2 Feb 2024 11:00:00 -0000" + assert result["plain_text_body"] == "Plain text only email." + assert result["html_body"] == "" + + +def test_parse_multipart_email_html_only(): + """ + Test parsing an email with only an HTML body. + """ + email_data = { + "id": "email789", + "threadId": "thread789", + "labelIds": ["SENT"], + "historyId": "history789", + "snippet": "HTML only email.", + "payload": { + "headers": [ + {"name": "To", "value": "recipient3@example.com"}, + {"name": "From", "value": "sender3@example.com"}, + {"name": "Subject", "value": "HTML Email"}, + {"name": "Date", "value": "Wed, 3 Mar 2024 12:00:00 -0000"}, + ], + "body": {"size": 200, "data": "PGh0bWw+VGhpcyBpcyBIVE1MIGVtYWlsLjwvaHRtbD4="}, + }, + } + + with ( + patch("arcade_gmail.utils._get_email_plain_text_body") as mock_plain, + patch("arcade_gmail.utils._get_email_html_body") as mock_html, + patch("arcade_gmail.utils._clean_email_body") as mock_clean, + ): + # Mock the helper functions + mock_plain.return_value = None + mock_html.return_value = "This is HTML email." + mock_clean.return_value = "This is HTML email." + + result = parse_multipart_email(email_data) + + assert result["id"] == "email789" + assert result["thread_id"] == "thread789" + assert result["label_ids"] == ["SENT"] + assert result["snippet"] == "HTML only email." + assert result["to"] == "recipient3@example.com" + assert result["from"] == "sender3@example.com" + assert result["subject"] == "HTML Email" + assert result["date"] == "Wed, 3 Mar 2024 12:00:00 -0000" + assert result["plain_text_body"] == "This is HTML email." + assert result["html_body"] == "This is HTML email." + + +def test_parse_multipart_email_missing_payload(): + """ + Test parsing an email with missing payload. + """ + email_data = { + "id": "email000", + "threadId": "thread000", + "labelIds": ["INBOX"], + "historyId": "history000", + "snippet": "Missing payload email.", + # 'payload' key is missing + } + + result = parse_multipart_email(email_data) + + # Since payload is missing, headers and bodies should be default or empty + assert result["id"] == "email000" + assert result["thread_id"] == "thread000" + assert result["label_ids"] == ["INBOX"] + assert result["snippet"] == "Missing payload email." + assert result["to"] == "" + assert result["from"] == "" + assert result["subject"] == "" + assert result["date"] == "" + assert result["plain_text_body"] == "" + assert result["html_body"] == "" + + +def test_parse_multipart_email_missing_headers(): + """ + Test parsing an email with missing headers in the payload. + """ + email_data = { + "id": "email111", + "threadId": "thread111", + "labelIds": ["INBOX"], + "historyId": "history111", + "snippet": "Missing headers email.", + "payload": { + # 'headers' key is missing + "body": {"size": 100, "data": "VGltZWw="} + }, + } + + with ( + patch("arcade_gmail.utils._get_email_plain_text_body") as mock_plain, + patch("arcade_gmail.utils._get_email_html_body") as mock_html, + patch("arcade_gmail.utils._clean_email_body") as mock_clean, + ): + # Mock the helper functions + mock_plain.return_value = "Timeel" + mock_html.return_value = "

Timeel

" + mock_clean.return_value = "Timeel" + + result = parse_multipart_email(email_data) + + assert result["id"] == "email111" + assert result["thread_id"] == "thread111" + assert result["label_ids"] == ["INBOX"] + assert result["snippet"] == "Missing headers email." + assert result["to"] == "" + assert result["from"] == "" + assert result["subject"] == "" + assert result["date"] == "" + assert result["plain_text_body"] == "Timeel" + assert result["html_body"] == "

Timeel

" + + +def test_parse_multipart_email_missing_fields(): + """ + Test parsing an email with some missing fields in headers. + """ + email_data = { + "id": "email222", + "threadId": "thread222", + "labelIds": ["INBOX"], + "historyId": "history222", + "snippet": "Missing some headers.", + "payload": { + "headers": [ + {"name": "From", "value": "sender4@example.com"}, + {"name": "Subject", "value": "Partial Headers"}, + # 'To' and 'Date' headers are missing + ], + "body": {"size": 100, "data": "TWlzc2luZyBzb21lIGhlYWRlcnMu"}, + }, + } + + with ( + patch("arcade_gmail.utils._get_email_plain_text_body") as mock_plain, + patch("arcade_gmail.utils._get_email_html_body") as mock_html, + patch("arcade_gmail.utils._clean_email_body") as mock_clean, + ): + # Mock the helper functions + mock_plain.return_value = "Missing some headers." + mock_html.return_value = None + mock_clean.return_value = "Missing some headers." + + result = parse_multipart_email(email_data) + + assert result["id"] == "email222" + assert result["thread_id"] == "thread222" + assert result["label_ids"] == ["INBOX"] + assert result["snippet"] == "Missing some headers." + assert result["to"] == "" + assert result["from"] == "sender4@example.com" + assert result["subject"] == "Partial Headers" + assert result["date"] == "" + assert result["plain_text_body"] == "Missing some headers." + assert result["html_body"] == "" + + +def test_parse_multipart_email_empty(): + """ + Test parsing an empty email data. + """ + email_data = {} + + result = parse_multipart_email(email_data) + + assert result["id"] == "" + assert result["thread_id"] == "" + assert result["label_ids"] == [] + assert result["snippet"] == "" + assert result["to"] == "" + assert result["from"] == "" + assert result["subject"] == "" + assert result["date"] == "" + assert result["plain_text_body"] == "" + assert result["html_body"] == "" + + +def test_parse_multipart_email_invalid_payload_structure(): + """ + Test parsing an email with an invalid payload structure. + """ + email_data = { + "id": "email333", + "threadId": "thread333", + "labelIds": ["INBOX"], + "historyId": "history333", + "snippet": "Invalid payload structure.", + "payload": { + "headers": "This should be a list, not a string", + "body": {"size": 100, "data": "SW52YWxpZCBwYXlsb2Fk"}, + }, + } + + with pytest.raises(TypeError): + parse_multipart_email(email_data) diff --git a/toolkits/google_calendar/.pre-commit-config.yaml b/toolkits/google_calendar/.pre-commit-config.yaml new file mode 100644 index 00000000..f714abbc --- /dev/null +++ b/toolkits/google_calendar/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +files: ^.*/google_calendar/.* +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.4.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/toolkits/google_calendar/.ruff.toml b/toolkits/google_calendar/.ruff.toml new file mode 100644 index 00000000..f1aed90f --- /dev/null +++ b/toolkits/google_calendar/.ruff.toml @@ -0,0 +1,46 @@ +target-version = "py310" +line-length = 100 +fix = true + +[lint] +select = [ + # flake8-2020 + "YTT", + # flake8-bandit + "S", + # flake8-bugbear + "B", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-simplify + "SIM", + # isort + "I", + # mccabe + "C90", + # pycodestyle + "E", "W", + # pyflakes + "F", + # pygrep-hooks + "PGH", + # pyupgrade + "UP", + # ruff + "RUF", + # tryceratops + "TRY", +] + +[lint.per-file-ignores] +"*" = ["TRY003", "B904"] +"**/tests/*" = ["S101", "E501"] +"**/evals/*" = ["S101", "E501"] + +[format] +preview = true +skip-magic-trailing-comma = false diff --git a/toolkits/google_calendar/Makefile b/toolkits/google_calendar/Makefile new file mode 100644 index 00000000..0a8969be --- /dev/null +++ b/toolkits/google_calendar/Makefile @@ -0,0 +1,55 @@ +.PHONY: help + +help: + @echo "🛠️ github Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the uv environment and install all packages with dependencies + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras --no-sources + @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: install-local +install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras + @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + uv build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + @uv run --no-sources coverage report + @echo "Generating coverage report" + @uv run --no-sources coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file by a patch version + @echo "🚀 Bumping version in pyproject.toml" + uv version --no-sources --bump patch + +.PHONY: check +check: ## Run code quality tools. + @if [ -f .pre-commit-config.yaml ]; then\ + echo "🚀 Linting code: Running pre-commit";\ + uv run --no-sources pre-commit run -a;\ + fi + @echo "🚀 Static type checking: Running mypy" + @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/google_calendar/arcade_google_calendar/__init__.py b/toolkits/google_calendar/arcade_google_calendar/__init__.py new file mode 100644 index 00000000..d77091d3 --- /dev/null +++ b/toolkits/google_calendar/arcade_google_calendar/__init__.py @@ -0,0 +1,17 @@ +from arcade_google_calendar.tools import ( + create_event, + delete_event, + find_time_slots_when_everyone_is_free, + list_calendars, + list_events, + update_event, +) + +__all__ = [ + "create_event", + "delete_event", + "find_time_slots_when_everyone_is_free", + "list_calendars", + "list_events", + "update_event", +] diff --git a/toolkits/google_calendar/arcade_google_calendar/enums.py b/toolkits/google_calendar/arcade_google_calendar/enums.py new file mode 100644 index 00000000..5002df52 --- /dev/null +++ b/toolkits/google_calendar/arcade_google_calendar/enums.py @@ -0,0 +1,14 @@ +from enum import Enum + + +class EventVisibility(Enum): + DEFAULT = "default" + PUBLIC = "public" + PRIVATE = "private" + CONFIDENTIAL = "confidential" + + +class SendUpdatesOptions(Enum): + NONE = "none" # No notifications are sent + ALL = "all" # Notifications are sent to all guests + EXTERNAL_ONLY = "externalOnly" # Notifications are sent to non-Google Calendar guests only. diff --git a/toolkits/google_calendar/arcade_google_calendar/tools/__init__.py b/toolkits/google_calendar/arcade_google_calendar/tools/__init__.py new file mode 100644 index 00000000..bd347920 --- /dev/null +++ b/toolkits/google_calendar/arcade_google_calendar/tools/__init__.py @@ -0,0 +1,17 @@ +from arcade_google_calendar.tools.calendar import ( + create_event, + delete_event, + find_time_slots_when_everyone_is_free, + list_calendars, + list_events, + update_event, +) + +__all__ = [ + "create_event", + "delete_event", + "find_time_slots_when_everyone_is_free", + "list_calendars", + "list_events", + "update_event", +] diff --git a/toolkits/google_calendar/arcade_google_calendar/tools/calendar.py b/toolkits/google_calendar/arcade_google_calendar/tools/calendar.py new file mode 100644 index 00000000..e6ddb9df --- /dev/null +++ b/toolkits/google_calendar/arcade_google_calendar/tools/calendar.py @@ -0,0 +1,510 @@ +import json +from datetime import datetime, timedelta +from typing import Annotated, Any +from zoneinfo import ZoneInfo, ZoneInfoNotFoundError + +from arcade_tdk import ToolContext, tool +from arcade_tdk.auth import Google +from arcade_tdk.errors import RetryableToolError +from googleapiclient.errors import HttpError + +from arcade_google_calendar.enums import EventVisibility, SendUpdatesOptions +from arcade_google_calendar.utils import ( + build_calendar_service, + build_oauth_service, + compute_free_time_intersection, + parse_datetime, +) + + +@tool( + requires_auth=Google( + scopes=[ + "https://www.googleapis.com/auth/calendar.readonly", + "https://www.googleapis.com/auth/calendar.events", + ] + ) +) +async def list_calendars( + context: ToolContext, + max_results: Annotated[ + int, "The maximum number of calendars to return. Up to 250 calendars, defaults to 10." + ] = 10, + show_deleted: Annotated[bool, "Whether to show deleted calendars. Defaults to False"] = False, + show_hidden: Annotated[bool, "Whether to show hidden calendars. Defaults to False"] = False, + next_page_token: Annotated[ + str | None, "The token to retrieve the next page of calendars. Optional." + ] = None, +) -> Annotated[dict, "A dictionary containing the calendars accessible by the end user"]: + """ + List all calendars accessible by the user. + """ + max_results = max(1, min(max_results, 250)) + service = build_calendar_service(context.get_auth_token_or_empty()) + calendars = ( + service.calendarList() + .list( + pageToken=next_page_token, + showDeleted=show_deleted, + showHidden=show_hidden, + maxResults=max_results, + ) + .execute() + ) + + items = calendars.get("items", []) + keys = ["description", "id", "summary", "timeZone"] + relevant_items = [{k: i.get(k) for k in keys if i.get(k)} for i in items] + return { + "next_page_token": calendars.get("nextPageToken"), + "num_calendars": len(relevant_items), + "calendars": relevant_items, + } + + +@tool( + requires_auth=Google( + scopes=[ + "https://www.googleapis.com/auth/calendar.readonly", + "https://www.googleapis.com/auth/calendar.events", + ], + ) +) +async def create_event( + context: ToolContext, + summary: Annotated[str, "The title of the event"], + start_datetime: Annotated[ + str, + "The datetime when the event starts in ISO 8601 format, e.g., '2024-12-31T15:30:00'.", + ], + end_datetime: Annotated[ + str, + "The datetime when the event ends in ISO 8601 format, e.g., '2024-12-31T17:30:00'.", + ], + calendar_id: Annotated[ + str, "The ID of the calendar to create the event in, usually 'primary'." + ] = "primary", + description: Annotated[str | None, "The description of the event"] = None, + location: Annotated[str | None, "The location of the event"] = None, + visibility: Annotated[EventVisibility, "The visibility of the event"] = EventVisibility.DEFAULT, + attendee_emails: Annotated[ + list[str] | None, + "The list of attendee emails. Must be valid email addresses e.g., username@domain.com.", + ] = None, +) -> Annotated[dict, "A dictionary containing the created event details"]: + """Create a new event/meeting/sync/meetup in the specified calendar.""" + + service = build_calendar_service(context.get_auth_token_or_empty()) + + # Get the calendar's time zone + calendar = service.calendars().get(calendarId=calendar_id).execute() + time_zone = calendar["timeZone"] + + # Parse datetime strings + start_dt = parse_datetime(start_datetime, time_zone) + end_dt = parse_datetime(end_datetime, time_zone) + + event: dict[str, Any] = { + "summary": summary, + "description": description, + "location": location, + "start": {"dateTime": start_dt.isoformat(), "timeZone": time_zone}, + "end": {"dateTime": end_dt.isoformat(), "timeZone": time_zone}, + "visibility": visibility.value, + } + + if attendee_emails: + event["attendees"] = [{"email": email} for email in attendee_emails] + + created_event = service.events().insert(calendarId=calendar_id, body=event).execute() + return {"event": created_event} + + +@tool( + requires_auth=Google( + scopes=[ + "https://www.googleapis.com/auth/calendar.readonly", + "https://www.googleapis.com/auth/calendar.events", + ], + ) +) +async def list_events( + context: ToolContext, + min_end_datetime: Annotated[ + str, + "Filter by events that end on or after this datetime in ISO 8601 format, " + "e.g., '2024-09-15T09:00:00'.", + ], + max_start_datetime: Annotated[ + str, + "Filter by events that start before this datetime in ISO 8601 format, " + "e.g., '2024-09-16T17:00:00'.", + ], + calendar_id: Annotated[str, "The ID of the calendar to list events from"] = "primary", + max_results: Annotated[int, "The maximum number of events to return"] = 10, +) -> Annotated[dict, "A dictionary containing the list of events"]: + """ + List events from the specified calendar within the given datetime range. + + min_end_datetime serves as the lower bound (exclusive) for an event's end time. + max_start_datetime serves as the upper bound (exclusive) for an event's start time. + + For example: + If min_end_datetime is set to 2024-09-15T09:00:00 and max_start_datetime + is set to 2024-09-16T17:00:00, the function will return events that: + 1. End after 09:00 on September 15, 2024 (exclusive) + 2. Start before 17:00 on September 16, 2024 (exclusive) + This means an event starting at 08:00 on September 15 and + ending at 10:00 on September 15 would be included, but an + event starting at 17:00 on September 16 would not be included. + """ + service = build_calendar_service(context.get_auth_token_or_empty()) + + # Get the calendar's time zone + calendar = service.calendars().get(calendarId=calendar_id).execute() + time_zone = calendar["timeZone"] + + # Parse datetime strings + min_end_dt = parse_datetime(min_end_datetime, time_zone) + max_start_dt = parse_datetime(max_start_datetime, time_zone) + + if min_end_dt > max_start_dt: + min_end_dt, max_start_dt = max_start_dt, min_end_dt + + events_result = ( + service.events() + .list( + calendarId=calendar_id, + timeMin=min_end_dt.isoformat(), + timeMax=max_start_dt.isoformat(), + maxResults=max_results, + singleEvents=True, + orderBy="startTime", + ) + .execute() + ) + + items_keys = [ + "attachments", + "attendees", + "creator", + "description", + "end", + "eventType", + "htmlLink", + "id", + "location", + "organizer", + "start", + "summary", + "visibility", + ] + + events = [ + {key: event[key] for key in items_keys if key in event} + for event in events_result.get("items", []) + ] + + return {"events_count": len(events), "events": events} + + +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/calendar.events"], + ) +) +async def update_event( + context: ToolContext, + event_id: Annotated[str, "The ID of the event to update"], + updated_start_datetime: Annotated[ + str | None, + "The updated datetime that the event starts in ISO 8601 format, " + "e.g., '2024-12-31T15:30:00'.", + ] = None, + updated_end_datetime: Annotated[ + str | None, + "The updated datetime that the event ends in ISO 8601 format, e.g., '2024-12-31T17:30:00'.", + ] = None, + updated_calendar_id: Annotated[ + str | None, "The updated ID of the calendar containing the event." + ] = None, + updated_summary: Annotated[str | None, "The updated title of the event"] = None, + updated_description: Annotated[str | None, "The updated description of the event"] = None, + updated_location: Annotated[str | None, "The updated location of the event"] = None, + updated_visibility: Annotated[EventVisibility | None, "The visibility of the event"] = None, + attendee_emails_to_add: Annotated[ + list[str] | None, + "The list of attendee emails to add. Must be valid email addresses " + "e.g., username@domain.com.", + ] = None, + attendee_emails_to_remove: Annotated[ + list[str] | None, + "The list of attendee emails to remove. Must be valid email addresses " + "e.g., username@domain.com.", + ] = None, + send_updates: Annotated[ + SendUpdatesOptions, + "Should attendees be notified of the update? (none, all, external_only)", + ] = SendUpdatesOptions.ALL, +) -> Annotated[ + str, + "A string containing the updated event details, including the event ID, update timestamp, " + "and a link to view the updated event.", +]: + """ + Update an existing event in the specified calendar with the provided details. + Only the provided fields will be updated; others will remain unchanged. + + `updated_start_datetime` and `updated_end_datetime` are + independent and can be provided separately. + """ + service = build_calendar_service(context.get_auth_token_or_empty()) + + calendar = service.calendars().get(calendarId="primary").execute() + time_zone = calendar["timeZone"] + + try: + event = service.events().get(calendarId="primary", eventId=event_id).execute() + except HttpError: + valid_events_with_id = ( + service.events() + .list( + calendarId="primary", + timeMin=(datetime.now() - timedelta(days=2)).isoformat(), + timeMax=(datetime.now() + timedelta(days=365)).isoformat(), + maxResults=50, + singleEvents=True, + orderBy="startTime", + ) + .execute() + ) + raise RetryableToolError( + f"Event with ID {event_id} not found.", + additional_prompt_content=( + f"Here is a list of valid events. The event_id parameter must match one of these: " + f"{valid_events_with_id}" + ), + retry_after_ms=1000, + developer_message=( + f"Event with ID {event_id} not found. Please try again with a valid event ID." + ), + ) + + update_fields = { + "start": {"dateTime": updated_start_datetime, "timeZone": time_zone} + if updated_start_datetime + else None, + "end": {"dateTime": updated_end_datetime, "timeZone": time_zone} + if updated_end_datetime + else None, + "calendarId": updated_calendar_id, + "sendUpdates": send_updates.value if send_updates else None, + "summary": updated_summary, + "description": updated_description, + "location": updated_location, + "visibility": updated_visibility.value if updated_visibility else None, + } + + event.update({k: v for k, v in update_fields.items() if v is not None}) + + if attendee_emails_to_remove: + event["attendees"] = [ + attendee + for attendee in event.get("attendees", []) + if attendee.get("email", "").lower() + not in [email.lower() for email in attendee_emails_to_remove] + ] + + if attendee_emails_to_add: + existing_emails = { + attendee.get("email", "").lower() for attendee in event.get("attendees", []) + } + new_attendees = [ + {"email": email} + for email in attendee_emails_to_add + if email.lower() not in existing_emails + ] + event["attendees"] = event.get("attendees", []) + new_attendees + + updated_event = ( + service.events() + .update( + calendarId="primary", + eventId=event_id, + sendUpdates=send_updates.value, + body=event, + ) + .execute() + ) + return ( + f"Event with ID {event_id} successfully updated at {updated_event['updated']}. " + f"View updated event at {updated_event['htmlLink']}" + ) + + +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/calendar.events"], + ) +) +async def delete_event( + context: ToolContext, + event_id: Annotated[str, "The ID of the event to delete"], + calendar_id: Annotated[str, "The ID of the calendar containing the event"] = "primary", + send_updates: Annotated[ + SendUpdatesOptions, "Specifies which attendees to notify about the deletion" + ] = SendUpdatesOptions.ALL, +) -> Annotated[str, "A string containing the deletion confirmation message"]: + """Delete an event from Google Calendar.""" + service = build_calendar_service(context.get_auth_token_or_empty()) + + service.events().delete( + calendarId=calendar_id, eventId=event_id, sendUpdates=send_updates.value + ).execute() + + notification_message = "" + if send_updates == SendUpdatesOptions.ALL: + notification_message = "Notifications were sent to all attendees." + elif send_updates == SendUpdatesOptions.EXTERNAL_ONLY: + notification_message = "Notifications were sent to external attendees only." + elif send_updates == SendUpdatesOptions.NONE: + notification_message = "No notifications were sent to attendees." + + return ( + f"Event with ID '{event_id}' successfully deleted from calendar '{calendar_id}'. " + f"{notification_message}" + ) + + +# TODO: would be nice to have a "min_slot_duration" parameter +# TODO: find a way to have "include_weekends" parameter without confusing LLMs +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/calendar.readonly"], + ), +) +async def find_time_slots_when_everyone_is_free( + context: ToolContext, + email_addresses: Annotated[ + list[str] | None, + "The list of email addresses from people in the same organization domain (apart from the " + "currently logged in user) to search for free time slots. Defaults to None, which will " + "return free time slots for the current user only.", + ] = None, + start_date: Annotated[ + str | None, + "The start date to search for time slots in the format 'YYYY-MM-DD'. Defaults to today's " + "date. It will search starting from this date at the time 00:00:00.", + ] = None, + end_date: Annotated[ + str | None, + "The end date to search for time slots in the format 'YYYY-MM-DD'. Defaults to seven days " + "from the start date. It will search until this date at the time 23:59:59.", + ] = None, + start_time_boundary: Annotated[ + str, + "Will return free slots in any given day starting from this time in the format 'HH:MM'. " + "Defaults to '08:00', which is a usual business hour start time.", + ] = "08:00", + end_time_boundary: Annotated[ + str, + "Will return free slots in any given day until this time in the format 'HH:MM'. " + "Defaults to '18:00', which is a usual business hour end time.", + ] = "18:00", +) -> Annotated[ + dict, + "A dictionary with the free slots and the timezone in which time slots are represented.", +]: + """ + Provides time slots when everyone is free within a given date range and time boundaries. + """ + + # Build google api services + oauth_service = build_oauth_service(context.get_auth_token_or_empty()) + calendar_service = build_calendar_service(context.get_auth_token_or_empty()) + + email_addresses = email_addresses or [] + + if isinstance(email_addresses, str): + email_addresses = [email_addresses] + + # Add the currently logged in user to the list of email addresses + user_info = oauth_service.userinfo().get().execute() + if user_info["email"] not in email_addresses: + email_addresses.append(user_info["email"]) + + # Get the timezone of the currently logged in user + calendar = calendar_service.calendars().get(calendarId="primary").execute() + timezone_name = calendar.get("timeZone") + + try: + tz = ZoneInfo(timezone_name) + # If the calendar timezone name is not supported by Python's zoneinfo, use UTC + except ZoneInfoNotFoundError: + timezone_name = "UTC" + tz = ZoneInfo("UTC") + + # Set default start and end dates, if not provided by the caller + start_date = start_date or datetime.now(tz=tz).date().isoformat() + end_date = end_date or (datetime.now(tz=tz).date() + timedelta(days=7)).isoformat() + + # Parse start and end dates to datetime objects + start_datetime = datetime.strptime(start_date, "%Y-%m-%d").replace( + hour=0, minute=0, second=0, microsecond=0, tzinfo=tz + ) + end_datetime = datetime.strptime(end_date, "%Y-%m-%d").replace( + hour=23, minute=59, second=59, microsecond=0, tzinfo=tz + ) + + # Get the busy slots from the calendars of the users + freebusy_response = ( + calendar_service.freebusy() + .query( + body={ + "timeMin": start_datetime.isoformat(), + "timeMax": end_datetime.isoformat(), + "timeZone": timezone_name, + "items": [{"id": email_address} for email_address in email_addresses], + } + ) + .execute() + ) + busy_slots = freebusy_response["calendars"] + + response_errors = [] + + for email in email_addresses: + if "errors" not in busy_slots[email]: + continue + errors = busy_slots[email]["errors"] + for error in errors: + response_errors.append( + f"Error retrieving free slots from calendar of '{email}': " + f"{error.get('reason', 'not determined')}" + ) + + if response_errors: + raise RetryableToolError( + "Error retrieving free slots from calendars of one or more users.", + additional_prompt_content=json.dumps(response_errors), + retry_after_ms=1000, + developer_message="Error retrieving free slots from calendars of one or more users.", + ) + + # Compute the free slots + free_slots = compute_free_time_intersection( + busy_data=busy_slots, + global_start=start_datetime, + global_end=end_datetime, + start_time_boundary=datetime.strptime(start_time_boundary, "%H:%M") + .time() + .replace(tzinfo=tz), + end_time_boundary=datetime.strptime(end_time_boundary, "%H:%M").time().replace(tzinfo=tz), + include_weekends=True, + tz=tz, + ) + + return { + "free_slots": free_slots, + "timezone": timezone_name, + } diff --git a/toolkits/google_calendar/arcade_google_calendar/utils.py b/toolkits/google_calendar/arcade_google_calendar/utils.py new file mode 100644 index 00000000..d70da2bb --- /dev/null +++ b/toolkits/google_calendar/arcade_google_calendar/utils.py @@ -0,0 +1,249 @@ +import logging +from datetime import date, datetime, time, timedelta, timezone +from typing import Any +from zoneinfo import ZoneInfo + +from google.oauth2.credentials import Credentials +from googleapiclient.discovery import Resource, build + +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) + +logger = logging.getLogger(__name__) + + +def parse_datetime(datetime_str: str, time_zone: str) -> datetime: + """ + Parse a datetime string in ISO 8601 format and ensure it is timezone-aware. + + Args: + datetime_str (str): The datetime string to parse. Expected format: 'YYYY-MM-DDTHH:MM:SS'. + time_zone (str): The timezone to apply if the datetime string is naive. + + Returns: + datetime: A timezone-aware datetime object. + + Raises: + ValueError: If the datetime string is not in the correct format. + """ + datetime_str = datetime_str.upper().strip().rstrip("Z") + try: + dt = datetime.fromisoformat(datetime_str) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=ZoneInfo(time_zone)) + except ValueError as e: + raise ValueError( + f"Invalid datetime format: '{datetime_str}'. " + "Expected ISO 8601 format, e.g., '2024-12-31T15:30:00'." + ) from e + return dt + + +def build_oauth_service(auth_token: str | None) -> Resource: # type: ignore[no-any-unimported] + """ + Build an OAuth2 service object. + """ + auth_token = auth_token or "" + return build("oauth2", "v2", credentials=Credentials(auth_token)) + + +def build_calendar_service(auth_token: str | None) -> Resource: # type: ignore[no-any-unimported] + """ + Build a Calendar service object. + """ + auth_token = auth_token or "" + return build("calendar", "v3", credentials=Credentials(auth_token)) + + +def weekday_to_name(weekday: int) -> str: + return ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"][weekday] + + +def get_time_boundaries_for_date( + current_date: date, + global_start: datetime, + global_end: datetime, + start_time_boundary: time, + end_time_boundary: time, + tz: ZoneInfo, +) -> tuple[datetime, datetime]: + """Compute the allowed start and end times for the given day, adjusting for global bounds.""" + day_start_time = datetime.combine(current_date, start_time_boundary).replace(tzinfo=tz) + day_end_time = datetime.combine(current_date, end_time_boundary).replace(tzinfo=tz) + + if current_date == global_start.date(): + day_start_time = max(day_start_time, global_start) + + if current_date == global_end.date(): + day_end_time = min(day_end_time, global_end) + + return day_start_time, day_end_time + + +def gather_busy_intervals( + busy_data: dict[str, Any], + day_start: datetime, + day_end: datetime, + business_tz: ZoneInfo, +) -> list[tuple[datetime, datetime]]: + """ + Collect busy intervals from all calendars that intersect with the day's business hours. + Busy intervals are clipped to lie within [day_start, day_end]. + """ + busy_intervals = [] + for calendar in busy_data: + for slot in busy_data[calendar].get("busy", []): + slot_start = parse_rfc3339_datetime_str(slot["start"]).astimezone(business_tz) + slot_end = parse_rfc3339_datetime_str(slot["end"]).astimezone(business_tz) + if slot_end > day_start and slot_start < day_end: + busy_intervals.append((max(slot_start, day_start), min(slot_end, day_end))) + return busy_intervals + + +def subtract_busy_intervals( + business_start: datetime, + business_end: datetime, + busy_intervals: list[tuple[datetime, datetime]], +) -> list[dict[str, Any]]: + """ + Subtract the merged busy intervals from the business hours and return free time slots. + """ + free_slots = [] + merged_busy = merge_intervals(busy_intervals) + + # If there are no busy intervals, return the entire business window as free. + if not merged_busy: + return [ + { + "start": { + "datetime": business_start.isoformat(), + "weekday": weekday_to_name(business_start.weekday()), + }, + "end": { + "datetime": business_end.isoformat(), + "weekday": weekday_to_name(business_end.weekday()), + }, + } + ] + + current_free_start = business_start + for busy_start, busy_end in merged_busy: + if current_free_start < busy_start: + free_slots.append({ + "start": { + "datetime": current_free_start.isoformat(), + "weekday": weekday_to_name(current_free_start.weekday()), + }, + "end": { + "datetime": busy_start.isoformat(), + "weekday": weekday_to_name(busy_start.weekday()), + }, + }) + current_free_start = max(current_free_start, busy_end) + if current_free_start < business_end: + free_slots.append({ + "start": { + "datetime": current_free_start.isoformat(), + "weekday": weekday_to_name(current_free_start.weekday()), + }, + "end": { + "datetime": business_end.isoformat(), + "weekday": weekday_to_name(business_end.weekday()), + }, + }) + return free_slots + + +def compute_free_time_intersection( + busy_data: dict[str, Any], + global_start: datetime, + global_end: datetime, + start_time_boundary: time, + end_time_boundary: time, + include_weekends: bool, + tz: ZoneInfo, +) -> list[dict[str, Any]]: + """ + Returns the free time slots across all calendars within the global bounds, + ensuring that the global start is not in the past. + + Only considers business days (Monday to Friday) and business hours (08:00-19:00) + in the provided timezone. + """ + # Ensure global_start is never in the past relative to now. + now = get_now(tz) + + if now > global_start: + global_start = now + + # If after adjusting the start, there's no interval left, return empty. + if global_start >= global_end: + return [] + + free_slots = [] + current_date = global_start.date() + + while current_date <= global_end.date(): + if not include_weekends and current_date.weekday() >= 5: + current_date += timedelta(days=1) + continue + + day_start, day_end = get_time_boundaries_for_date( + current_date=current_date, + global_start=global_start, + global_end=global_end, + start_time_boundary=start_time_boundary, + end_time_boundary=end_time_boundary, + tz=tz, + ) + + # Skip if the day's allowed time window is empty. + if day_start >= day_end: + current_date += timedelta(days=1) + continue + + busy_intervals = gather_busy_intervals(busy_data, day_start, day_end, tz) + free_slots.extend(subtract_busy_intervals(day_start, day_end, busy_intervals)) + + current_date += timedelta(days=1) + + return free_slots + + +def get_now(tz: ZoneInfo | None = None) -> datetime: + if not tz: + tz = ZoneInfo("UTC") + return datetime.now(tz) + + +def parse_rfc3339_datetime_str(dt_str: str, tz: timezone = timezone.utc) -> datetime: + """ + Parse an RFC3339 datetime string into a timezone-aware datetime. + Converts a trailing 'Z' (UTC) into +00:00. + If the parsed datetime is naive, assume it is in the provided timezone. + """ + if dt_str.endswith("Z"): + dt_str = dt_str[:-1] + "+00:00" + dt = datetime.fromisoformat(dt_str) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=tz) + return dt + + +def merge_intervals(intervals: list[tuple[datetime, datetime]]) -> list[tuple[datetime, datetime]]: + """ + Given a list of (start, end) tuples, merge overlapping or adjacent intervals. + """ + merged: list[tuple[datetime, datetime]] = [] + for start, end in sorted(intervals, key=lambda x: x[0]): + if not merged: + merged.append((start, end)) + else: + last_start, last_end = merged[-1] + if start <= last_end: + merged[-1] = (last_start, max(last_end, end)) + else: + merged.append((start, end)) + return merged diff --git a/toolkits/google_calendar/evals/eval_google_calendar.py b/toolkits/google_calendar/evals/eval_google_calendar.py new file mode 100644 index 00000000..665b8adf --- /dev/null +++ b/toolkits/google_calendar/evals/eval_google_calendar.py @@ -0,0 +1,215 @@ +from datetime import timedelta + +from arcade_evals import ( + BinaryCritic, + DatetimeCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_google_calendar +from arcade_google_calendar.enums import EventVisibility, SendUpdatesOptions +from arcade_google_calendar.tools import ( + create_event, + delete_event, + list_calendars, + list_events, + update_event, +) + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.9, + warn_threshold=0.95, +) + +catalog = ToolCatalog() +catalog.add_module(arcade_google_calendar) + +history_after_list_events = [ + {"role": "user", "content": "do i have any events on my calendar for today?"}, + { + "role": "assistant", + "content": "Please go to this URL and authorize the action: \n[Link](https://accounts.google.com/o/oauth2/v2/auth?)", + }, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_uHdRlr4z7sFm39ZrPsE5wcdT", + "type": "function", + "function": { + "name": "GoogleCalendar_ListEvents", + "arguments": '{"min_end_datetime":"2024-09-26T00:00:00-07:00","max_start_datetime":"2024-09-27T00:00:00-07:00"}', + }, + } + ], + }, + { + "role": "tool", + "content": '{"events_count": 3, "events": [{"creator": {"email": "john@example.com", "self": true}, "description": "1:1 meeting with Joe", "end": {"dateTime": "2024-09-26T00:15:00-07:00", "timeZone": "America/Los_Angeles"}, "eventType": "default", "htmlLink": "https://www.google.com/calendar/event?eid=01234", "id": "10009199283838467", "location": "622 Rainbow Ave, South San Francisco, CA 94080, USA", "organizer": {"email": "john@example.com", "self": true}, "start": {"dateTime": "2024-09-25T23:15:00-07:00", "timeZone": "America/Los_Angeles"}, "summary": "1:1 meeting"}, {"attendees": [{"email": "joe@example.com", "responseStatus": "accepted"}], "creator": {"email": "john@example.com", "self": true}, "description": "This is just a test", "end": {"dateTime": "2024-09-26T14:00:00-07:00", "timeZone": "America/Los_Angeles"}, "eventType": "default", "htmlLink": "https://www.google.com/calendar/event?eid=OXB2OGFwcmZraWk1N234324", "id": "00099992228181818181", "organizer": {"email": "john@example.com", "self": true}, "start": {"dateTime": "2024-09-26T12:00:00-07:00", "timeZone": "America/Los_Angeles"}, "summary": "API test"}, {"attendees": [{"email": "henry@example.com", "responseStatus": "needsAction"}], "creator": {"email": "john@example.com", "self": true}, "end": {"dateTime": "2024-09-26T17:00:00-07:00", "timeZone": "America/Los_Angeles"}, "eventType": "default", "htmlLink": "https://www.google.com/calendar/event?eid=Z3I1ZzE4b324534556", "id": "gr5g18lf88tfpp3vkareukkc7g", "location": "611 Rainbow Road", "organizer": {"email": "john@example.com", "self": true}, "start": {"dateTime": "2024-09-26T15:00:00-07:00", "timeZone": "America/Los_Angeles"}, "summary": "Focus Time", "visibility": "public"}]}', + "tool_call_id": "call_uHdRlr4z7sFm39ZrPsE5wcdT", + "name": "GoogleCalendar_ListEvents", + }, + { + "role": "assistant", + "content": "Yes, you have three events on your calendar for today:\n\n1. **Event:** Test2\n - **Time:** 23:15 - 00:15 (PST)\n - **Location:** 611 Gateway Blvd, South San Francisco, CA 94080, USA\n - **Description:** 1:1 meeting with Joe\n 2. **Event:** API Test\n - **Time:** 12:00 - 14:00 (PST)\n **Description:** This is just a test\n - [View Event](https://www.google.com/calendar/event?eid=OXB2OGFwcmZraWk1NGUwa24xaTNya2lvZDggZXJpY0BhcmNhZGUtYWkuY29t)\n\n3. **Event:** Focus Time\n - **Time:** 15:00 - 17:00 (PST)\n - **Location:** 611 Rainbow Road\n - **Visibility:** Public\n - [View Event](https://www.google.com/calendar/event?eid=Z3I1ZzE4bGY4OHRmcHAzdmthcmV1a2tjN2cgZXJpY0BhcmNhZGUtYWkuY29t)\n\nIf you need more details or help with anything else, feel free to ask!", + }, +] + + +@tool_eval() +def calendar_eval_suite() -> EvalSuite: + """Create an evaluation suite for Calendar tools.""" + suite = EvalSuite( + name="Calendar Tools Evaluation", + system_message=( + "You are an AI assistant that can create, list, update, and delete events using the provided tools. Today is 2024-09-26" + ), + catalog=catalog, + rubric=rubric, + ) + + # Cases for list_calendars + suite.add_case( + name="List Calendars", + user_message=("What calendars do I have?"), + expected_tool_calls=[ + ExpectedToolCall( + func=list_calendars, + args={}, + ) + ], + critics=[], + ) + + # Cases for create_event + suite.add_case( + name="Create calendar event", + user_message=( + "Create a meeting for 'Team Meeting' starting on September 26, 2024, from 11:45pm to 12:15am. Invite johndoe@example.com" + ), + expected_tool_calls=[ + ExpectedToolCall( + func=create_event, + args={ + "summary": "Team Meeting", + "start_datetime": "2024-09-26T23:45:00", + "end_datetime": "2024-09-27T00:15:00", + "calendar_id": "primary", + "attendee_emails": ["johndoe@example.com"], + "visibility": EventVisibility.DEFAULT, + "description": "Team Meeting", + }, + ) + ], + critics=[ + BinaryCritic(critic_field="summary", weight=0.2), + DatetimeCritic( + critic_field="start_datetime", weight=0.2, tolerance=timedelta(seconds=10) + ), + DatetimeCritic( + critic_field="end_datetime", weight=0.2, tolerance=timedelta(seconds=10) + ), + BinaryCritic(critic_field="attendee_emails", weight=0.2), + BinaryCritic(critic_field="description", weight=0.1), + BinaryCritic(critic_field="location", weight=0.1), + ], + ) + + # Cases for list_events + suite.add_case( + name="List calendar events", + user_message="Do I have any events on my calendar today?", + expected_tool_calls=[ + ExpectedToolCall( + func=list_events, + args={ + "min_end_datetime": "2024-09-26T00:00:00", + "max_start_datetime": "2024-09-27T00:00:00", + "calendar_id": "primary", + "max_results": 10, + }, + ) + ], + critics=[ + DatetimeCritic( + critic_field="min_end_datetime", weight=0.3, tolerance=timedelta(hours=1) + ), + DatetimeCritic( + critic_field="max_start_datetime", weight=0.3, tolerance=timedelta(hours=1) + ), + BinaryCritic(critic_field="calendar_id", weight=0.2), + BinaryCritic(critic_field="max_results", weight=0.2), + ], + ) + + # Cases for update_event + suite.add_case( + name="Update a calendar event", + user_message=( + "Oh no! I can't make it to the API Test since I have lunch with an old friend at that time. " + "Change my meeting tomorrow at 3pm to 4pm. Let everyone know." + ), + expected_tool_calls=[ + ExpectedToolCall( + func=update_event, + args={ + "event_id": "00099992228181818181", + "updated_start_datetime": "2024-09-27T16:00:00", + "updated_end_datetime": "2024-09-27T18:00:00", + "updated_calendar_id": "primary", + "updated_summary": "API Test", + "updated_description": "API Test", + "updated_location": "611 Gateway Blvd", + "updated_visibility": EventVisibility.DEFAULT, + "attendee_emails_to_add": None, + "attendee_emails_to_remove": None, + "send_updates": SendUpdatesOptions.ALL, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="event_id", weight=0.4), + DatetimeCritic( + critic_field="updated_start_datetime", weight=0.2, tolerance=timedelta(minutes=15) + ), + DatetimeCritic( + critic_field="updated_end_datetime", + weight=0.2, + tolerance=timedelta(minutes=15), + ), + BinaryCritic(critic_field="send_updates", weight=0.2), + ], + additional_messages=history_after_list_events, + ) + + # Cases for delete_event + suite.add_case( + name="Delete a calendar event", + user_message=( + "I don't need to have focus time today. Please delete it from my calendar. Don't send any notifications." + ), + expected_tool_calls=[ + ExpectedToolCall( + func=delete_event, + args={ + "event_id": "gr5g18lf88tfpp3vkareukkc7g", + "calendar_id": "primary", + "send_updates": SendUpdatesOptions.NONE, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="event_id", weight=0.6), + BinaryCritic(critic_field="calendar_id", weight=0.2), + BinaryCritic(critic_field="send_updates", weight=0.2), + ], + additional_messages=history_after_list_events, + ) + + return suite diff --git a/toolkits/google_calendar/pyproject.toml b/toolkits/google_calendar/pyproject.toml new file mode 100644 index 00000000..38b8f076 --- /dev/null +++ b/toolkits/google_calendar/pyproject.toml @@ -0,0 +1,63 @@ +[build-system] +requires = [ "hatchling",] +build-backend = "hatchling.build" + +[project] +name = "arcade_google_calendar" +version = "2.0.0" +description = "Arcade.dev LLM tools for Google Calendar" +requires-python = ">=3.10" +dependencies = [ + "arcade-tdk>=2.0.0,<3.0.0", + "google-api-core>=2.19.1,<3.0.0", + "google-api-python-client>=2.137.0,<3.0.0", + "google-auth>=2.32.0,<3.0.0", + "google-auth-httplib2>=0.2.0,<1.0.0", + "googleapis-common-protos>=1.63.2,<2.0.0", +] +[[project.authors]] +name = "Arcade" +email = "dev@arcade.dev" + + +[project.optional-dependencies] +dev = [ + "arcade-ai[evals]>=2.0.0rc1,<3.0.0", + "arcade-serve>=2.0.0,<3.0.0", + "pytest>=8.3.0,<8.4.0", + "pytest-cov>=4.0.0,<4.1.0", + "pytest-mock>=3.11.1,<3.12.0", + "pytest-asyncio>=0.24.0,<0.25.0", + "mypy>=1.5.1,<1.6.0", + "pre-commit>=3.4.0,<3.5.0", + "tox>=4.11.1,<4.12.0", + "ruff>=0.7.4,<0.8.0", +] + +# Use local path sources for arcade libs when working locally +[tool.uv.sources] +arcade-ai = { path = "../../", editable = true } +arcade-serve = { path = "../../libs/arcade-serve/", editable = true } +arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } + + +[tool.mypy] +files = [ "arcade_google_calendar/**/*.py",] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes = "True" +ignore_missing_imports = "True" + +[tool.pytest.ini_options] +testpaths = [ "tests",] + +[tool.coverage.report] +skip_empty = true + +[tool.hatch.build.targets.wheel] +packages = [ "arcade_google_calendar",] diff --git a/toolkits/google_calendar/tests/__init__.py b/toolkits/google_calendar/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/toolkits/google_calendar/tests/test_calendar.py b/toolkits/google_calendar/tests/test_calendar.py new file mode 100644 index 00000000..82d1be1c --- /dev/null +++ b/toolkits/google_calendar/tests/test_calendar.py @@ -0,0 +1,582 @@ +from datetime import datetime +from unittest.mock import MagicMock, patch +from zoneinfo import ZoneInfo + +import pytest +from arcade_tdk import ToolAuthorizationContext, ToolContext +from arcade_tdk.errors import RetryableToolError, ToolExecutionError +from googleapiclient.errors import HttpError + +from arcade_google_calendar.enums import EventVisibility, SendUpdatesOptions +from arcade_google_calendar.tools import ( + create_event, + delete_event, + find_time_slots_when_everyone_is_free, + list_calendars, + list_events, + update_event, +) + + +@pytest.fixture +def mock_context(): + mock_auth = ToolAuthorizationContext(token="fake-token") # noqa: S106 + return ToolContext(authorization=mock_auth) + + +@pytest.mark.asyncio +@patch("arcade_google_calendar.tools.calendar.build_calendar_service") +async def test_list_calendars(mock_build_calendar_service, mock_context): + mock_service = MagicMock() + mock_build_calendar_service.return_value = mock_service + + expected_api_response = { + "etag": '"p33for2n0pvc8o0o"', + "items": [ + { + "accessRole": "reader", + "backgroundColor": "#16a765", + "colorId": "8", + "conferenceProperties": {"allowedConferenceSolutionTypes": ["hangoutsMeet"]}, + "defaultReminders": [], + "description": "Holidays and Observances in Brazil", + "etag": '"2347287866334000"', + "foregroundColor": "#000000", + "id": "en.brazilian#holiday@group.v.calendar.google.com", + "kind": "calendar#calendarListEntry", + "selected": True, + "summary": "Holidays in Brazil", + "timeZone": "America/Sao_Paulo", + }, + { + "accessRole": "owner", + "backgroundColor": "#9fe1e7", + "colorId": "14", + "conferenceProperties": {"allowedConferenceSolutionTypes": ["hangoutsMeet"]}, + "defaultReminders": [{"method": "popup", "minutes": 10}], + "etag": '"1743169667849567"', + "foregroundColor": "#000000", + "id": "example@arcade.dev", + "kind": "calendar#calendarListEntry", + "notificationSettings": { + "notifications": [ + {"method": "email", "type": "eventCreation"}, + {"method": "email", "type": "eventChange"}, + {"method": "email", "type": "eventCancellation"}, + {"method": "email", "type": "eventResponse"}, + ] + }, + "primary": True, + "selected": True, + "summary": "example@arcade.dev", + "timeZone": "America/Sao_Paulo", + }, + ], + "kind": "calendar#calendarList", + "nextSyncToken": "XkJ8Hy5mN2pQvL9sR4tW7cA3fE1iU6nB", + } + + expected_tool_response = { + "num_calendars": 2, + "calendars": [ + { + "description": "Holidays and Observances in Brazil", + "id": "en.brazilian#holiday@group.v.calendar.google.com", + "summary": "Holidays in Brazil", + "timeZone": "America/Sao_Paulo", + }, + { + "id": "example@arcade.dev", + "summary": "example@arcade.dev", + "timeZone": "America/Sao_Paulo", + }, + ], + "next_page_token": None, + } + + mock_service.calendarList().list().execute.return_value = expected_api_response + + response = await list_calendars(context=mock_context) + assert response == expected_tool_response + + # Case: HttpError during calendars listing + mock_service.calendarList().list().execute.side_effect = HttpError( + resp=MagicMock(status=400), + content=b'{"error": {"message": "Invalid request"}}', + ) + + with pytest.raises(ToolExecutionError): + await list_calendars(context=mock_context) + + +@pytest.mark.asyncio +@patch("arcade_google_calendar.tools.calendar.build_calendar_service") +async def test_create_event(mock_build, mock_context): + mock_service = MagicMock() + mock_build.return_value = mock_service + + # Mock the calendar's time zone + mock_service.calendars().get().execute.return_value = {"timeZone": "America/Los_Angeles"} + + # Case: HttpError during event creation + mock_service.events().insert().execute.side_effect = HttpError( + resp=MagicMock(status=400), + content=b'{"error": {"message": "Invalid request"}}', + ) + + with pytest.raises(ToolExecutionError): + await create_event( + context=mock_context, + summary="Test Event", + start_datetime="2024-12-31T15:30:00", + end_datetime="2024-12-31T17:30:00", + description="Test Description", + location="Test Location", + visibility=EventVisibility.PRIVATE, + attendee_emails=["test@example.com"], + ) + + +@pytest.mark.asyncio +@patch("arcade_google_calendar.tools.calendar.build_calendar_service") +async def test_list_events(mock_build, mock_context): + mock_service = MagicMock() + mock_build.return_value = mock_service + # Mock the calendar's time zone + mock_service.calendars().get().execute.return_value = {"timeZone": "America/Los_Angeles"} + + # Mock the events list response + mock_events_list_response = { + "items": [ + { + "creator": {"email": "example@arcade-ai.com", "self": True}, + "end": {"dateTime": "2024-09-27T01:00:00-07:00", "timeZone": "America/Los_Angeles"}, + "eventType": "default", + "htmlLink": "https://www.google.com/calendar/event?eid=event1", + "id": "event1", + "organizer": {"email": "example@arcade-ai.com", "self": True}, + "start": { + "dateTime": "2024-09-27T00:00:00-07:00", + "timeZone": "America/Los_Angeles", + }, + "summary": "Event 1", + }, + { + "creator": {"email": "example@arcade-ai.com", "self": True}, + "end": {"dateTime": "2024-09-27T17:00:00-07:00", "timeZone": "America/Los_Angeles"}, + "eventType": "default", + "htmlLink": "https://www.google.com/calendar/event?eid=event2", + "id": "event2", + "organizer": {"email": "example@arcade-ai.com", "self": True}, + "start": { + "dateTime": "2024-09-27T14:00:00-07:00", + "timeZone": "America/Los_Angeles", + }, + "summary": "Event 2", + }, + ] + } + expected_tool_response = { + "events_count": len(mock_events_list_response["items"]), + "events": mock_events_list_response["items"], + } + mock_service.events().list().execute.return_value = mock_events_list_response + response = await list_events( + context=mock_context, + min_end_datetime="2024-09-15T09:00:00", + max_start_datetime="2024-09-16T17:00:00", + ) + assert response == expected_tool_response + + # Case: HttpError during events listing + mock_service.events().list().execute.side_effect = HttpError( + resp=MagicMock(status=400), + content=b'{"error": {"message": "Invalid request"}}', + ) + + with pytest.raises(ToolExecutionError): + await list_events( + context=mock_context, + min_end_datetime="2024-09-15T09:00:00", + max_start_datetime="2024-09-16T17:00:00", + ) + + +@pytest.mark.asyncio +@patch("arcade_google_calendar.tools.calendar.build_calendar_service") +async def test_update_event(mock_build, mock_context): + mock_service = MagicMock() + mock_build.return_value = mock_service + + # Mock retrieval of the event + mock_service.events().get().execute.side_effect = HttpError( + resp=MagicMock(status=404), + content=b'{"error": {"message": "Event not found"}}', + ) + + with pytest.raises(ToolExecutionError): + await update_event( + context=mock_context, + event_id="1234567890", + updated_start_datetime="2024-12-31T00:15:00", + updated_end_datetime="2024-12-31T01:15:00", + updated_summary="Updated Event", + updated_description="Updated Description", + updated_location="Updated Location", + updated_visibility=EventVisibility.PRIVATE, + attendee_emails_to_add=["test@example.com"], + attendee_emails_to_remove=["test@example2.com"], + send_updates=SendUpdatesOptions.ALL, + ) + + +@pytest.mark.asyncio +@patch("arcade_google_calendar.tools.calendar.build_calendar_service") +async def test_delete_event(mock_build, mock_context): + mock_service = MagicMock() + mock_build.return_value = mock_service + mock_service.events().delete().execute.side_effect = HttpError( + resp=MagicMock(status=404), + content=b'{"error": {"message": "Event not found"}}', + ) + + with pytest.raises(ToolExecutionError): + await delete_event( + context=mock_context, + event_id="nonexistent_event", + send_updates=SendUpdatesOptions.ALL, + ) + + +@pytest.mark.asyncio +@patch("arcade_google_calendar.utils.get_now") +@patch("arcade_google_calendar.tools.calendar.build_oauth_service") +@patch("arcade_google_calendar.tools.calendar.build_calendar_service") +async def test_find_free_slots_happiest_path_single_user( + mock_build_calendar_service, mock_build_oauth_service, mock_get_now, mock_context +): + calendar_service = MagicMock() + oauth_service = MagicMock() + + mock_get_now.return_value = datetime( + 2025, 3, 10, 9, 25, 0, tzinfo=ZoneInfo("America/Los_Angeles") + ) + mock_build_oauth_service.return_value = oauth_service + mock_build_calendar_service.return_value = calendar_service + + oauth_service.userinfo().get().execute.return_value = { + "email": "example@arcade.dev", + } + + calendar_service.freebusy().query().execute.return_value = { + "calendars": { + "example@arcade.dev": {"busy": []}, + } + } + + calendar_service.calendars().get().execute.return_value = { + "timeZone": "America/Los_Angeles", + } + + response = await find_time_slots_when_everyone_is_free( + context=mock_context, + email_addresses=["example@arcade.dev"], + start_date="2025-03-10", + end_date="2025-03-11", + start_time_boundary="08:00", + end_time_boundary="18:00", + ) + + assert response == { + "free_slots": [ + { + "start": { + "datetime": "2025-03-10T09:25:00-07:00", + "weekday": "Monday", + }, + "end": { + "datetime": "2025-03-10T18:00:00-07:00", + "weekday": "Monday", + }, + }, + { + "start": { + "datetime": "2025-03-11T08:00:00-07:00", + "weekday": "Tuesday", + }, + "end": { + "datetime": "2025-03-11T18:00:00-07:00", + "weekday": "Tuesday", + }, + }, + ], + "timezone": "America/Los_Angeles", + } + + +@pytest.mark.asyncio +@patch("arcade_google_calendar.utils.get_now") +@patch("arcade_google_calendar.tools.calendar.build_oauth_service") +@patch("arcade_google_calendar.tools.calendar.build_calendar_service") +async def test_find_free_slots_happiest_path_single_user_with_busy_times( + mock_build_calendar_service, mock_build_oauth_service, mock_get_now, mock_context +): + calendar_service = MagicMock() + oauth_service = MagicMock() + + mock_get_now.return_value = datetime( + 2025, 3, 10, 9, 25, 0, tzinfo=ZoneInfo("America/Los_Angeles") + ) + + mock_build_oauth_service.return_value = oauth_service + mock_build_calendar_service.return_value = calendar_service + + oauth_service.userinfo().get().execute.return_value = { + "email": "example@arcade.dev", + } + + calendar_service.freebusy().query().execute.return_value = { + "calendars": { + "example@arcade.dev": { + "busy": [ + { + "start": "2025-03-10T11:00:00-07:00", + "end": "2025-03-10T12:00:00-07:00", + }, + { + "start": "2025-03-10T14:15:00-07:00", + "end": "2025-03-10T14:30:00-07:00", + }, + ] + }, + } + } + + calendar_service.calendars().get().execute.return_value = { + "timeZone": "America/Los_Angeles", + } + + response = await find_time_slots_when_everyone_is_free( + context=mock_context, + email_addresses=["example@arcade.dev"], + start_date="2025-03-10", + end_date="2025-03-11", + start_time_boundary="08:00", + end_time_boundary="18:00", + ) + + assert response == { + "free_slots": [ + { + "start": { + "datetime": "2025-03-10T09:25:00-07:00", + "weekday": "Monday", + }, + "end": { + "datetime": "2025-03-10T11:00:00-07:00", + "weekday": "Monday", + }, + }, + { + "start": { + "datetime": "2025-03-10T12:00:00-07:00", + "weekday": "Monday", + }, + "end": { + "datetime": "2025-03-10T14:15:00-07:00", + "weekday": "Monday", + }, + }, + { + "start": { + "datetime": "2025-03-10T14:30:00-07:00", + "weekday": "Monday", + }, + "end": { + "datetime": "2025-03-10T18:00:00-07:00", + "weekday": "Monday", + }, + }, + { + "start": { + "datetime": "2025-03-11T08:00:00-07:00", + "weekday": "Tuesday", + }, + "end": { + "datetime": "2025-03-11T18:00:00-07:00", + "weekday": "Tuesday", + }, + }, + ], + "timezone": "America/Los_Angeles", + } + + +@pytest.mark.asyncio +@patch("arcade_google_calendar.utils.get_now") +@patch("arcade_google_calendar.tools.calendar.build_oauth_service") +@patch("arcade_google_calendar.tools.calendar.build_calendar_service") +async def test_find_free_slots_happiest_path_multiple_users_with_busy_times( + mock_build_calendar_service, mock_build_oauth_service, mock_get_now, mock_context +): + calendar_service = MagicMock() + oauth_service = MagicMock() + + mock_get_now.return_value = datetime( + 2025, 3, 10, 9, 25, 0, tzinfo=ZoneInfo("America/Los_Angeles") + ) + + mock_build_oauth_service.return_value = oauth_service + mock_build_calendar_service.return_value = calendar_service + + oauth_service.userinfo().get().execute.return_value = { + "email": "example@arcade.dev", + } + + calendar_service.freebusy().query().execute.return_value = { + "calendars": { + "example@arcade.dev": { + "busy": [ + { + "start": "2025-03-10T11:00:00-07:00", + "end": "2025-03-10T12:00:00-07:00", + }, + { + "start": "2025-03-10T14:15:00-07:00", + "end": "2025-03-10T14:30:00-07:00", + }, + ] + }, + "example2@arcade.dev": { + "busy": [ + { + "start": "2025-03-10T11:30:00-07:00", + "end": "2025-03-10T12:45:00-07:00", + }, + { + "start": "2025-03-11T06:00:00-07:00", + "end": "2025-03-11T07:00:00-07:00", + }, + ] + }, + } + } + + calendar_service.calendars().get().execute.return_value = { + "timeZone": "America/Los_Angeles", + } + + response = await find_time_slots_when_everyone_is_free( + context=mock_context, + email_addresses=["example@arcade.dev", "example2@arcade.dev"], + start_date="2025-03-10", + end_date="2025-03-11", + start_time_boundary="08:00", + end_time_boundary="18:00", + ) + + assert response == { + "free_slots": [ + { + "start": { + "datetime": "2025-03-10T09:25:00-07:00", + "weekday": "Monday", + }, + "end": { + "datetime": "2025-03-10T11:00:00-07:00", + "weekday": "Monday", + }, + }, + { + "start": { + "datetime": "2025-03-10T12:45:00-07:00", + "weekday": "Monday", + }, + "end": { + "datetime": "2025-03-10T14:15:00-07:00", + "weekday": "Monday", + }, + }, + { + "start": { + "datetime": "2025-03-10T14:30:00-07:00", + "weekday": "Monday", + }, + "end": { + "datetime": "2025-03-10T18:00:00-07:00", + "weekday": "Monday", + }, + }, + { + "start": { + "datetime": "2025-03-11T08:00:00-07:00", + "weekday": "Tuesday", + }, + "end": { + "datetime": "2025-03-11T18:00:00-07:00", + "weekday": "Tuesday", + }, + }, + ], + "timezone": "America/Los_Angeles", + } + + +@pytest.mark.asyncio +@patch("arcade_google_calendar.utils.get_now") +@patch("arcade_google_calendar.tools.calendar.build_oauth_service") +@patch("arcade_google_calendar.tools.calendar.build_calendar_service") +async def test_find_free_slots_with_google_calendar_error_not_found( + mock_build_calendar_service, mock_build_oauth_service, mock_get_now, mock_context +): + calendar_service = MagicMock() + oauth_service = MagicMock() + + mock_get_now.return_value = datetime( + 2025, 3, 10, 9, 25, 0, tzinfo=ZoneInfo("America/Los_Angeles") + ) + mock_build_oauth_service.return_value = oauth_service + mock_build_calendar_service.return_value = calendar_service + + oauth_service.userinfo().get().execute.return_value = { + "email": "example@arcade.dev", + } + + calendar_service.freebusy().query().execute.return_value = { + "calendars": { + "example@arcade.dev": { + "busy": [ + { + "start": "2025-03-10T11:00:00-07:00", + "end": "2025-03-10T12:00:00-07:00", + }, + { + "start": "2025-03-10T14:15:00-07:00", + "end": "2025-03-10T14:30:00-07:00", + }, + ] + }, + "example2@arcade.dev": { + "errors": [ + { + "reason": "notFound", + "domain": "calendar", + } + ] + }, + } + } + + calendar_service.calendars().get().execute.return_value = { + "timeZone": "America/Los_Angeles", + } + + with pytest.raises(RetryableToolError): + await find_time_slots_when_everyone_is_free( + context=mock_context, + email_addresses=["example@arcade.dev", "example2@arcade.dev"], + start_date="2025-03-10", + end_date="2025-03-11", + start_time_boundary="08:00", + end_time_boundary="18:00", + ) diff --git a/toolkits/google_contacts/.pre-commit-config.yaml b/toolkits/google_contacts/.pre-commit-config.yaml new file mode 100644 index 00000000..0cf8d087 --- /dev/null +++ b/toolkits/google_contacts/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +files: ^.*/google_contacts/.* +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.4.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/toolkits/google_contacts/.ruff.toml b/toolkits/google_contacts/.ruff.toml new file mode 100644 index 00000000..f1aed90f --- /dev/null +++ b/toolkits/google_contacts/.ruff.toml @@ -0,0 +1,46 @@ +target-version = "py310" +line-length = 100 +fix = true + +[lint] +select = [ + # flake8-2020 + "YTT", + # flake8-bandit + "S", + # flake8-bugbear + "B", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-simplify + "SIM", + # isort + "I", + # mccabe + "C90", + # pycodestyle + "E", "W", + # pyflakes + "F", + # pygrep-hooks + "PGH", + # pyupgrade + "UP", + # ruff + "RUF", + # tryceratops + "TRY", +] + +[lint.per-file-ignores] +"*" = ["TRY003", "B904"] +"**/tests/*" = ["S101", "E501"] +"**/evals/*" = ["S101", "E501"] + +[format] +preview = true +skip-magic-trailing-comma = false diff --git a/toolkits/google_contacts/Makefile b/toolkits/google_contacts/Makefile new file mode 100644 index 00000000..0a8969be --- /dev/null +++ b/toolkits/google_contacts/Makefile @@ -0,0 +1,55 @@ +.PHONY: help + +help: + @echo "🛠️ github Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the uv environment and install all packages with dependencies + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras --no-sources + @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: install-local +install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras + @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + uv build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + @uv run --no-sources coverage report + @echo "Generating coverage report" + @uv run --no-sources coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file by a patch version + @echo "🚀 Bumping version in pyproject.toml" + uv version --no-sources --bump patch + +.PHONY: check +check: ## Run code quality tools. + @if [ -f .pre-commit-config.yaml ]; then\ + echo "🚀 Linting code: Running pre-commit";\ + uv run --no-sources pre-commit run -a;\ + fi + @echo "🚀 Static type checking: Running mypy" + @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/google_contacts/arcade_google_contacts/__init__.py b/toolkits/google_contacts/arcade_google_contacts/__init__.py new file mode 100644 index 00000000..32b78c1c --- /dev/null +++ b/toolkits/google_contacts/arcade_google_contacts/__init__.py @@ -0,0 +1,7 @@ +from arcade_google_contacts.tools import ( + create_contact, + search_contacts_by_email, + search_contacts_by_name, +) + +__all__ = ["create_contact", "search_contacts_by_email", "search_contacts_by_name"] diff --git a/toolkits/google_contacts/arcade_google_contacts/constants.py b/toolkits/google_contacts/arcade_google_contacts/constants.py new file mode 100644 index 00000000..580d7f09 --- /dev/null +++ b/toolkits/google_contacts/arcade_google_contacts/constants.py @@ -0,0 +1 @@ +DEFAULT_SEARCH_CONTACTS_LIMIT = 30 diff --git a/toolkits/google_contacts/arcade_google_contacts/tools/__init__.py b/toolkits/google_contacts/arcade_google_contacts/tools/__init__.py new file mode 100644 index 00000000..e38af883 --- /dev/null +++ b/toolkits/google_contacts/arcade_google_contacts/tools/__init__.py @@ -0,0 +1,7 @@ +from arcade_google_contacts.tools.contacts import ( + create_contact, + search_contacts_by_email, + search_contacts_by_name, +) + +__all__ = ["create_contact", "search_contacts_by_email", "search_contacts_by_name"] diff --git a/toolkits/google_contacts/arcade_google_contacts/tools/contacts.py b/toolkits/google_contacts/arcade_google_contacts/tools/contacts.py new file mode 100644 index 00000000..6a0791dd --- /dev/null +++ b/toolkits/google_contacts/arcade_google_contacts/tools/contacts.py @@ -0,0 +1,96 @@ +import asyncio +from typing import Annotated + +from arcade_tdk import ToolContext, tool +from arcade_tdk.auth import Google + +from arcade_google_contacts.constants import DEFAULT_SEARCH_CONTACTS_LIMIT +from arcade_google_contacts.utils import build_people_service, search_contacts + + +async def _warmup_cache(service) -> None: # type: ignore[no-untyped-def] + """ + Warm-up the search cache for contacts by sending a request with an empty query. + This ensures that the lazy cache is updated for both primary contacts and other contacts. + This is unfortunately a real thing: https://developers.google.com/people/v1/contacts#search_the_users_contacts + """ + service.people().searchContacts(query="", pageSize=1, readMask="names,emailAddresses").execute() + await asyncio.sleep(3) # TODO experiment with this value + + +@tool(requires_auth=Google(scopes=["https://www.googleapis.com/auth/contacts.readonly"])) +async def search_contacts_by_email( + context: ToolContext, + email: Annotated[str, "The email address to search for"], + limit: Annotated[ + int | None, + "The maximum number of contacts to return (30 is the max allowed by Google API)", + ] = DEFAULT_SEARCH_CONTACTS_LIMIT, +) -> Annotated[dict, "A dictionary containing the list of matching contacts"]: + """ + Search the user's contacts in Google Contacts by email address. + """ + service = build_people_service(context.get_auth_token_or_empty()) + # Warm-up the cache before performing search. + # TODO: Ideally we should warmup only if this user (or google domain?) hasn't warmed up recently + await _warmup_cache(service) + + return {"contacts": search_contacts(service, email, limit)} + + +@tool(requires_auth=Google(scopes=["https://www.googleapis.com/auth/contacts.readonly"])) +async def search_contacts_by_name( + context: ToolContext, + name: Annotated[str, "The full name to search for"], + limit: Annotated[ + int | None, + "The maximum number of contacts to return (30 is the max allowed by Google API)", + ] = DEFAULT_SEARCH_CONTACTS_LIMIT, +) -> Annotated[dict, "A dictionary containing the list of matching contacts"]: + """ + Search the user's contacts in Google Contacts by name. + """ + service = build_people_service(context.get_auth_token_or_empty()) + # Warm-up the cache before performing search. + # TODO: Ideally we should warmup only if this user (or google domain?) hasn't warmed up recently + await _warmup_cache(service) + return {"contacts": search_contacts(service, name, limit)} + + +@tool(requires_auth=Google(scopes=["https://www.googleapis.com/auth/contacts"])) +async def create_contact( + context: ToolContext, + given_name: Annotated[str, "The given name of the contact"], + family_name: Annotated[str | None, "The optional family name of the contact"], + email: Annotated[str | None, "The optional email address of the contact"], +) -> Annotated[dict, "A dictionary containing the details of the created contact"]: + """ + Create a new contact record in Google Contacts. + + Examples: + ``` + create_contact(given_name="Alice") + create_contact(given_name="Alice", family_name="Smith") + create_contact(given_name="Alice", email="alice@example.com") + ``` + """ + # Build the People API service + service = build_people_service(context.get_auth_token_or_empty()) + + # Construct the person payload with the specified names + name_body = {"givenName": given_name} + if family_name: + name_body["familyName"] = family_name + contact_body = {"names": [name_body]} + if email: + contact_body["emailAddresses"] = [{"value": email, "type": "work"}] + + # Create the contact. The personFields parameter specifies what information + # should be returned. Here, we return names and emailAddresses. + created_contact = ( + service.people() + .createContact(body=contact_body, personFields="names,emailAddresses") + .execute() + ) + + return {"contact": created_contact} diff --git a/toolkits/google_contacts/arcade_google_contacts/utils.py b/toolkits/google_contacts/arcade_google_contacts/utils.py new file mode 100644 index 00000000..159eb0af --- /dev/null +++ b/toolkits/google_contacts/arcade_google_contacts/utils.py @@ -0,0 +1,49 @@ +import logging +from typing import Any, cast + +from google.oauth2.credentials import Credentials +from googleapiclient.discovery import Resource, build + +from arcade_google_contacts.constants import DEFAULT_SEARCH_CONTACTS_LIMIT + +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) + +logger = logging.getLogger(__name__) + + +def build_people_service(auth_token: str | None) -> Resource: # type: ignore[no-any-unimported] + """ + Build a People service object. + """ + auth_token = auth_token or "" + return build("people", "v1", credentials=Credentials(auth_token)) + + +def search_contacts(service: Any, query: str, limit: int | None) -> list[dict[str, Any]]: + """ + Search the user's contacts in Google Contacts. + """ + response = ( + service.people() + .searchContacts( + query=query, + pageSize=limit or DEFAULT_SEARCH_CONTACTS_LIMIT, + readMask=",".join([ + "names", + "nicknames", + "emailAddresses", + "phoneNumbers", + "addresses", + "organizations", + "biographies", + "urls", + "userDefined", + ]), + ) + .execute() + ) + + return cast(list[dict[str, Any]], response.get("results", [])) diff --git a/toolkits/google_contacts/evals/eval_google_contacts.py b/toolkits/google_contacts/evals/eval_google_contacts.py new file mode 100644 index 00000000..83c7d294 --- /dev/null +++ b/toolkits/google_contacts/evals/eval_google_contacts.py @@ -0,0 +1,135 @@ +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_google_contacts +from arcade_google_contacts.tools import ( + create_contact, + search_contacts_by_email, + search_contacts_by_name, +) + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.9, + warn_threshold=0.95, +) + +catalog = ToolCatalog() +catalog.add_module(arcade_google_contacts) + + +@tool_eval() +def contacts_eval_suite() -> EvalSuite: + """Create an evaluation suite for Google Contacts tools.""" + suite = EvalSuite( + name="Google Contacts Tools Evaluation", + system_message="You are an AI assistant that can manage Google Contacts using the provided tools.", + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="Search contacts by name", + user_message="Find my contact Bob", + expected_tool_calls=[ + ExpectedToolCall( + func=search_contacts_by_name, + args={ + "name": "Bob", + }, + ) + ], + ) + + suite.add_case( + name="Search contacts by email", + user_message="Find my contact alice@example.com", + expected_tool_calls=[ + ExpectedToolCall( + func=search_contacts_by_email, + args={ + "email": "alice@example.com", + }, + ) + ], + ) + + suite.add_case( + name="Search contacts with query and limit", + user_message="Find 5 contacts whose names include 'Alice'", + expected_tool_calls=[ + ExpectedToolCall( + func=search_contacts_by_name, + args={ + "name": "Alice", + "limit": 5, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="query", weight=0.5), + BinaryCritic(critic_field="limit", weight=0.5), + ], + ) + + suite.add_case( + name="Create new contact with only given name", + user_message="Create a new contact for Alice", + expected_tool_calls=[ + ExpectedToolCall( + func=create_contact, + args={ + "given_name": "Alice", + }, + ) + ], + critics=[ + BinaryCritic(critic_field="given_name", weight=1.0), + ], + ) + + suite.add_case( + name="Create new contact with only email (infer name from email)", + user_message="Create a new contact for alice@example.com", + expected_tool_calls=[ + ExpectedToolCall( + func=create_contact, + args={ + "given_name": "Alice", + "email": "alice@example.com", + }, + ) + ], + critics=[ + BinaryCritic(critic_field="email", weight=0.5), + BinaryCritic(critic_field="given_name", weight=0.5), + ], + ) + + suite.add_case( + name="Create new contact with full name and email", + user_message="Create a contact for Bob Smith (bob.smith@example.com)", + expected_tool_calls=[ + ExpectedToolCall( + func=create_contact, + args={ + "given_name": "Bob", + "family_name": "Smith", + "email": "bob.smith@example.com", + }, + ) + ], + critics=[ + BinaryCritic(critic_field="given_name", weight=0.33), + BinaryCritic(critic_field="family_name", weight=0.33), + BinaryCritic(critic_field="email", weight=0.34), + ], + ) + + return suite diff --git a/toolkits/google_contacts/pyproject.toml b/toolkits/google_contacts/pyproject.toml new file mode 100644 index 00000000..01d0a71a --- /dev/null +++ b/toolkits/google_contacts/pyproject.toml @@ -0,0 +1,63 @@ +[build-system] +requires = [ "hatchling",] +build-backend = "hatchling.build" + +[project] +name = "arcade_google_contacts" +version = "2.0.0" +description = "Arcade.dev LLM tools for Google Contacts" +requires-python = ">=3.10" +dependencies = [ + "arcade-tdk>=2.0.0,<3.0.0", + "google-api-core>=2.19.1,<3.0.0", + "google-api-python-client>=2.137.0,<3.0.0", + "google-auth>=2.32.0,<3.0.0", + "google-auth-httplib2>=0.2.0,<1.0.0", + "googleapis-common-protos>=1.63.2,<2.0.0", +] +[[project.authors]] +name = "Arcade" +email = "dev@arcade.dev" + + +[project.optional-dependencies] +dev = [ + "arcade-ai[evals]>=2.0.0rc1,<3.0.0", + "arcade-serve>=2.0.0,<3.0.0", + "pytest>=8.3.0,<8.4.0", + "pytest-cov>=4.0.0,<4.1.0", + "pytest-mock>=3.11.1,<3.12.0", + "pytest-asyncio>=0.24.0,<0.25.0", + "mypy>=1.5.1,<1.6.0", + "pre-commit>=3.4.0,<3.5.0", + "tox>=4.11.1,<4.12.0", + "ruff>=0.7.4,<0.8.0", +] + +# Use local path sources for arcade libs when working locally +[tool.uv.sources] +arcade-ai = { path = "../../", editable = true } +arcade-serve = { path = "../../libs/arcade-serve/", editable = true } +arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } + + +[tool.mypy] +files = [ "arcade_google_contacts/**/*.py",] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes = "True" +ignore_missing_imports = "True" + +[tool.pytest.ini_options] +testpaths = [ "tests",] + +[tool.coverage.report] +skip_empty = true + +[tool.hatch.build.targets.wheel] +packages = [ "arcade_google_contacts",] diff --git a/toolkits/google_contacts/tests/__init__.py b/toolkits/google_contacts/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/toolkits/google_contacts/tests/test_contacts.py b/toolkits/google_contacts/tests/test_contacts.py new file mode 100644 index 00000000..04a4d865 --- /dev/null +++ b/toolkits/google_contacts/tests/test_contacts.py @@ -0,0 +1,100 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from arcade_tdk import ToolContext + +from arcade_google_contacts.tools import create_contact + + +@pytest.fixture +def mock_context(): + context = AsyncMock(spec=ToolContext) + context.authorization = MagicMock() + context.authorization.token = "mock_token" # noqa: S105 + return context + + +@pytest.mark.asyncio +async def test_create_contact_success(mock_context): + # Test create_contact with all parameters (given, family names and email) + created_contact_data = {"resourceName": "people/123", "etag": "abc"} + + create_contact_call = MagicMock() + create_contact_call.execute.return_value = created_contact_data + + people_mock = MagicMock() + people_mock.createContact.return_value = create_contact_call + + service_mock = MagicMock() + service_mock.people.return_value = people_mock + + with patch( + "arcade_google_contacts.tools.contacts.build_people_service", return_value=service_mock + ) as mock_build: + result = await create_contact( + mock_context, + given_name="Alice", + family_name="Smith", + email="alice@example.com", + ) + assert "contact" in result + assert result["contact"] == created_contact_data + + # Verify that the createContact API was called with the correct body contents. + expected_body = { + "names": [{"givenName": "Alice", "familyName": "Smith"}], + "emailAddresses": [{"value": "alice@example.com", "type": "work"}], + } + people_mock.createContact.assert_called_once_with( + body=expected_body, personFields="names,emailAddresses" + ) + mock_build.assert_called_once() + + +@pytest.mark.asyncio +async def test_create_contact_success_without_optional(mock_context): + # Test create_contact without optional parameters family_name and email. + created_contact_data = {"resourceName": "people/456", "etag": "def"} + + create_contact_call = MagicMock() + create_contact_call.execute.return_value = created_contact_data + + people_mock = MagicMock() + people_mock.createContact.return_value = create_contact_call + + service_mock = MagicMock() + service_mock.people.return_value = people_mock + + with patch( + "arcade_google_contacts.tools.contacts.build_people_service", return_value=service_mock + ): + result = await create_contact(mock_context, given_name="Bob", family_name=None, email=None) + assert "contact" in result + assert result["contact"] == created_contact_data + + # Expected body should only include the givenName when family_name and email are omitted. + expected_body = {"names": [{"givenName": "Bob"}]} + people_mock.createContact.assert_called_once_with( + body=expected_body, personFields="names,emailAddresses" + ) + + +@pytest.mark.asyncio +async def test_create_contact_error(mock_context): + # Simulate an error thrown by createContact + error_call = MagicMock() + error_call.execute.side_effect = Exception("Create error") + + people_mock = MagicMock() + people_mock.createContact.return_value = error_call + + service_mock = MagicMock() + service_mock.people.return_value = people_mock + + with ( + patch( + "arcade_google_contacts.tools.contacts.build_people_service", return_value=service_mock + ), + pytest.raises(Exception, match="Error in execution of CreateContact"), + ): + await create_contact(mock_context, given_name="Alice", family_name="Doe", email=None) diff --git a/toolkits/google_docs/.pre-commit-config.yaml b/toolkits/google_docs/.pre-commit-config.yaml new file mode 100644 index 00000000..8e74cd01 --- /dev/null +++ b/toolkits/google_docs/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +files: ^.*/google_docs/.* +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.4.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/toolkits/google_docs/.ruff.toml b/toolkits/google_docs/.ruff.toml new file mode 100644 index 00000000..f1aed90f --- /dev/null +++ b/toolkits/google_docs/.ruff.toml @@ -0,0 +1,46 @@ +target-version = "py310" +line-length = 100 +fix = true + +[lint] +select = [ + # flake8-2020 + "YTT", + # flake8-bandit + "S", + # flake8-bugbear + "B", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-simplify + "SIM", + # isort + "I", + # mccabe + "C90", + # pycodestyle + "E", "W", + # pyflakes + "F", + # pygrep-hooks + "PGH", + # pyupgrade + "UP", + # ruff + "RUF", + # tryceratops + "TRY", +] + +[lint.per-file-ignores] +"*" = ["TRY003", "B904"] +"**/tests/*" = ["S101", "E501"] +"**/evals/*" = ["S101", "E501"] + +[format] +preview = true +skip-magic-trailing-comma = false diff --git a/toolkits/google_docs/Makefile b/toolkits/google_docs/Makefile new file mode 100644 index 00000000..0a8969be --- /dev/null +++ b/toolkits/google_docs/Makefile @@ -0,0 +1,55 @@ +.PHONY: help + +help: + @echo "🛠️ github Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the uv environment and install all packages with dependencies + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras --no-sources + @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: install-local +install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras + @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + uv build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + @uv run --no-sources coverage report + @echo "Generating coverage report" + @uv run --no-sources coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file by a patch version + @echo "🚀 Bumping version in pyproject.toml" + uv version --no-sources --bump patch + +.PHONY: check +check: ## Run code quality tools. + @if [ -f .pre-commit-config.yaml ]; then\ + echo "🚀 Linting code: Running pre-commit";\ + uv run --no-sources pre-commit run -a;\ + fi + @echo "🚀 Static type checking: Running mypy" + @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/google_docs/arcade_google_docs/__init__.py b/toolkits/google_docs/arcade_google_docs/__init__.py new file mode 100644 index 00000000..2ef99351 --- /dev/null +++ b/toolkits/google_docs/arcade_google_docs/__init__.py @@ -0,0 +1,17 @@ +from arcade_google_docs.tools import ( + create_blank_document, + create_document_from_text, + get_document_by_id, + insert_text_at_end_of_document, + search_and_retrieve_documents, + search_documents, +) + +__all__ = [ + "create_blank_document", + "create_document_from_text", + "get_document_by_id", + "insert_text_at_end_of_document", + "search_and_retrieve_documents", + "search_documents", +] diff --git a/toolkits/google_docs/arcade_google_docs/decorators.py b/toolkits/google_docs/arcade_google_docs/decorators.py new file mode 100644 index 00000000..ffeb6aa7 --- /dev/null +++ b/toolkits/google_docs/arcade_google_docs/decorators.py @@ -0,0 +1,24 @@ +import functools +from collections.abc import Callable +from typing import Any + +from arcade_tdk import ToolContext +from googleapiclient.errors import HttpError + +from arcade_google_docs.file_picker import generate_google_file_picker_url + + +def with_filepicker_fallback(func: Callable[..., Any]) -> Callable[..., Any]: + """ """ + + @functools.wraps(func) + async def async_wrapper(context: ToolContext, *args: Any, **kwargs: Any) -> Any: + try: + return await func(context, *args, **kwargs) + except HttpError as e: + if e.status_code in [403, 404]: + file_picker_response = generate_google_file_picker_url(context) + return file_picker_response + raise + + return async_wrapper diff --git a/toolkits/google_docs/arcade_google_docs/doc_to_html.py b/toolkits/google_docs/arcade_google_docs/doc_to_html.py new file mode 100644 index 00000000..d54fcef0 --- /dev/null +++ b/toolkits/google_docs/arcade_google_docs/doc_to_html.py @@ -0,0 +1,99 @@ +def convert_document_to_html(document: dict) -> str: + html = ( + "" + f"{document['title']}" + f'' + "" + ) + for element in document["body"]["content"]: + html += convert_structural_element(element) + html += "" + return html + + +def convert_structural_element(element: dict, wrap_paragraphs: bool = True) -> str: + if "sectionBreak" in element or "tableOfContents" in element: + return "" + + elif "paragraph" in element: + paragraph_content = "" + + prepend, append = get_paragraph_style_tags( + style=element["paragraph"]["paragraphStyle"], + wrap_paragraphs=wrap_paragraphs, + ) + + for item in element["paragraph"]["elements"]: + if "textRun" not in item: + continue + paragraph_content += extract_paragraph_content(item["textRun"]) + + if not paragraph_content: + return "" + + return f"{prepend}{paragraph_content.strip()}{append}" + + elif "table" in element: + table = [ + [ + "".join([ + convert_structural_element(element=cell_element, wrap_paragraphs=False) + for cell_element in cell["content"] + ]) + for cell in row["tableCells"] + ] + for row in element["table"]["tableRows"] + ] + return table_list_to_html(table) + + else: + raise ValueError(f"Unknown document body element type: {element}") + + +def extract_paragraph_content(text_run: dict) -> str: + content = text_run["content"] + style = text_run["textStyle"] + return apply_text_style(content, style) + + +def apply_text_style(content: str, style: dict) -> str: + content = content.rstrip("\n") + content = content.replace("\n", "
") + italic = style.get("italic", False) + bold = style.get("bold", False) + if italic: + content = f"{content}" + if bold: + content = f"{content}" + return content + + +def get_paragraph_style_tags(style: dict, wrap_paragraphs: bool = True) -> tuple[str, str]: + named_style = style["namedStyleType"] + if named_style == "NORMAL_TEXT": + return ("

", "

") if wrap_paragraphs else ("", "") + elif named_style == "TITLE": + return "

", "

" + elif named_style == "SUBTITLE": + return "

", "

" + elif named_style.startswith("HEADING_"): + try: + heading_level = int(named_style.split("_")[1]) + except ValueError: + return ("

", "

") if wrap_paragraphs else ("", "") + else: + return f"", f"" + return ("

", "

") if wrap_paragraphs else ("", "") + + +def table_list_to_html(table: list[list[str]]) -> str: + html = "" + for row in table: + html += "" + for cell in row: + if cell.endswith("
"): + cell = cell[:-4] + html += f"" + html += "" + html += "
{cell}
" + return html diff --git a/toolkits/google_docs/arcade_google_docs/doc_to_markdown.py b/toolkits/google_docs/arcade_google_docs/doc_to_markdown.py new file mode 100644 index 00000000..b7be21f8 --- /dev/null +++ b/toolkits/google_docs/arcade_google_docs/doc_to_markdown.py @@ -0,0 +1,64 @@ +import arcade_google_docs.doc_to_html as doc_to_html + + +def convert_document_to_markdown(document: dict) -> str: + md = f"---\ntitle: {document['title']}\ndocumentId: {document['documentId']}\n---\n" + for element in document["body"]["content"]: + md += convert_structural_element(element) + return md + + +def convert_structural_element(element: dict) -> str: + if "sectionBreak" in element or "tableOfContents" in element: + return "" + + elif "paragraph" in element: + md = "" + prepend = get_paragraph_style_prepend_str(element["paragraph"]["paragraphStyle"]) + for item in element["paragraph"]["elements"]: + if "textRun" not in item: + continue + content = extract_paragraph_content(item["textRun"]) + md += f"{prepend}{content}" + return md + + elif "table" in element: + return doc_to_html.convert_structural_element(element) + + else: + raise ValueError(f"Unknown document body element type: {element}") + + +def extract_paragraph_content(text_run: dict) -> str: + content = text_run["content"] + style = text_run["textStyle"] + return apply_text_style(content, style) + + +def apply_text_style(content: str, style: dict) -> str: + append = "\n" if content.endswith("\n") else "" + content = content.rstrip("\n") + italic = style.get("italic", False) + bold = style.get("bold", False) + if italic: + content = f"_{content}_" + if bold: + content = f"**{content}**" + return f"{content}{append}" + + +def get_paragraph_style_prepend_str(style: dict) -> str: + named_style = style["namedStyleType"] + if named_style == "NORMAL_TEXT": + return "" + elif named_style == "TITLE": + return "# " + elif named_style == "SUBTITLE": + return "## " + elif named_style.startswith("HEADING_"): + try: + heading_level = int(named_style.split("_")[1]) + return f"{'#' * heading_level} " + except ValueError: + return "" + return "" diff --git a/toolkits/google_docs/arcade_google_docs/enum.py b/toolkits/google_docs/arcade_google_docs/enum.py new file mode 100644 index 00000000..50202e79 --- /dev/null +++ b/toolkits/google_docs/arcade_google_docs/enum.py @@ -0,0 +1,116 @@ +from enum import Enum + + +class Corpora(str, Enum): + """ + Bodies of items (files/documents) to which the query applies. + Prefer 'user' or 'drive' to 'allDrives' for efficiency. + By default, corpora is set to 'user'. + """ + + USER = "user" + DOMAIN = "domain" + DRIVE = "drive" + ALL_DRIVES = "allDrives" + + +class DocumentFormat(str, Enum): + MARKDOWN = "markdown" + HTML = "html" + GOOGLE_API_JSON = "google_api_json" + + +class OrderBy(str, Enum): + """ + Sort keys for ordering files in Google Drive. + Each key has both ascending and descending options. + """ + + CREATED_TIME = ( + # When the file was created (ascending) + "createdTime" + ) + CREATED_TIME_DESC = ( + # When the file was created (descending) + "createdTime desc" + ) + FOLDER = ( + # The folder ID, sorted using alphabetical ordering (ascending) + "folder" + ) + FOLDER_DESC = ( + # The folder ID, sorted using alphabetical ordering (descending) + "folder desc" + ) + MODIFIED_BY_ME_TIME = ( + # The last time the file was modified by the user (ascending) + "modifiedByMeTime" + ) + MODIFIED_BY_ME_TIME_DESC = ( + # The last time the file was modified by the user (descending) + "modifiedByMeTime desc" + ) + MODIFIED_TIME = ( + # The last time the file was modified by anyone (ascending) + "modifiedTime" + ) + MODIFIED_TIME_DESC = ( + # The last time the file was modified by anyone (descending) + "modifiedTime desc" + ) + NAME = ( + # The name of the file, sorted using alphabetical ordering (e.g., 1, 12, 2, 22) (ascending) + "name" + ) + NAME_DESC = ( + # The name of the file, sorted using alphabetical ordering (e.g., 1, 12, 2, 22) (descending) + "name desc" + ) + NAME_NATURAL = ( + # The name of the file, sorted using natural sort ordering (e.g., 1, 2, 12, 22) (ascending) + "name_natural" + ) + NAME_NATURAL_DESC = ( + # The name of the file, sorted using natural sort ordering (e.g., 1, 2, 12, 22) (descending) + "name_natural desc" + ) + QUOTA_BYTES_USED = ( + # The number of storage quota bytes used by the file (ascending) + "quotaBytesUsed" + ) + QUOTA_BYTES_USED_DESC = ( + # The number of storage quota bytes used by the file (descending) + "quotaBytesUsed desc" + ) + RECENCY = ( + # The most recent timestamp from the file's date-time fields (ascending) + "recency" + ) + RECENCY_DESC = ( + # The most recent timestamp from the file's date-time fields (descending) + "recency desc" + ) + SHARED_WITH_ME_TIME = ( + # When the file was shared with the user, if applicable (ascending) + "sharedWithMeTime" + ) + SHARED_WITH_ME_TIME_DESC = ( + # When the file was shared with the user, if applicable (descending) + "sharedWithMeTime desc" + ) + STARRED = ( + # Whether the user has starred the file (ascending) + "starred" + ) + STARRED_DESC = ( + # Whether the user has starred the file (descending) + "starred desc" + ) + VIEWED_BY_ME_TIME = ( + # The last time the file was viewed by the user (ascending) + "viewedByMeTime" + ) + VIEWED_BY_ME_TIME_DESC = ( + # The last time the file was viewed by the user (descending) + "viewedByMeTime desc" + ) diff --git a/toolkits/google_docs/arcade_google_docs/file_picker.py b/toolkits/google_docs/arcade_google_docs/file_picker.py new file mode 100644 index 00000000..193690ef --- /dev/null +++ b/toolkits/google_docs/arcade_google_docs/file_picker.py @@ -0,0 +1,49 @@ +import base64 +import json + +from arcade_tdk import ToolContext, ToolMetadataKey +from arcade_tdk.errors import ToolExecutionError + + +def generate_google_file_picker_url(context: ToolContext) -> dict: + """Generate a Google File Picker URL for user-driven file selection and authorization. + + Generates a URL that directs the end-user to a Google File Picker interface where + where they can select or upload Google Drive files. Users can grant permission to access their + Drive files, providing a secure and authorized way to interact with their files. + + This is particularly useful when prior tools (e.g., those accessing or modifying + Google Docs, Google Sheets, etc.) encountered failures due to file non-existence + (Requested entity was not found) or permission errors. Once the user completes the file + picker flow, the prior tool can be retried. + + Returns: + A dictionary containing the URL and instructions for the llm to instruct the user. + """ + client_id = context.get_metadata(ToolMetadataKey.CLIENT_ID) + client_id_parts = client_id.split("-") + if not client_id_parts: + raise ToolExecutionError( + message="Invalid Google Client ID", + developer_message=f"Google Client ID '{client_id}' is not valid", + ) + app_id = client_id_parts[0] + cloud_coordinator_url = context.get_metadata(ToolMetadataKey.COORDINATOR_URL).strip("/") + + config = { + "auth": { + "client_id": client_id, + "app_id": app_id, + }, + } + config_json = json.dumps(config) + config_base64 = base64.urlsafe_b64encode(config_json.encode("utf-8")).decode("utf-8") + url = f"{cloud_coordinator_url}/google/drive_picker?config={config_base64}" + + return { + "url": url, + "llm_instructions": ( + "Instruct the user to click the following link to open the Google Drive File Picker. " + f"This will allow them to select files and grant access permissions: {url}" + ), + } diff --git a/toolkits/google_docs/arcade_google_docs/templates.py b/toolkits/google_docs/arcade_google_docs/templates.py new file mode 100644 index 00000000..97d9d2d1 --- /dev/null +++ b/toolkits/google_docs/arcade_google_docs/templates.py @@ -0,0 +1,5 @@ +optional_file_picker_instructions_template = ( + "Ensure the user knows that they have the option to select and grant access permissions to " + "additional documents via the Google Drive File Picker. " + "The user can pick additional documents via the following link: {url}" +) diff --git a/toolkits/google_docs/arcade_google_docs/tools/__init__.py b/toolkits/google_docs/arcade_google_docs/tools/__init__.py new file mode 100644 index 00000000..48437ffa --- /dev/null +++ b/toolkits/google_docs/arcade_google_docs/tools/__init__.py @@ -0,0 +1,19 @@ +from arcade_google_docs.tools.create import ( + create_blank_document, + create_document_from_text, +) +from arcade_google_docs.tools.get import get_document_by_id +from arcade_google_docs.tools.search import ( + search_and_retrieve_documents, + search_documents, +) +from arcade_google_docs.tools.update import insert_text_at_end_of_document + +__all__ = [ + "create_blank_document", + "create_document_from_text", + "get_document_by_id", + "insert_text_at_end_of_document", + "search_and_retrieve_documents", + "search_documents", +] diff --git a/toolkits/google_docs/arcade_google_docs/tools/create.py b/toolkits/google_docs/arcade_google_docs/tools/create.py new file mode 100644 index 00000000..22a6c010 --- /dev/null +++ b/toolkits/google_docs/arcade_google_docs/tools/create.py @@ -0,0 +1,82 @@ +from typing import Annotated + +from arcade_tdk import ToolContext, tool +from arcade_tdk.auth import Google + +from arcade_google_docs.utils import build_docs_service + + +# Uses https://developers.google.com/docs/api/reference/rest/v1/documents/create +# Example `arcade chat` query: `create blank document with title "My New Document"` +@tool( + requires_auth=Google( + scopes=[ + "https://www.googleapis.com/auth/drive.file", + ], + ) +) +async def create_blank_document( + context: ToolContext, title: Annotated[str, "The title of the blank document to create"] +) -> Annotated[dict, "The created document's title, documentId, and documentUrl in a dictionary"]: + """ + Create a blank Google Docs document with the specified title. + """ + service = build_docs_service(context.get_auth_token_or_empty()) + + body = {"title": title} + + # Execute the documents().create() method. Returns a Document object https://developers.google.com/docs/api/reference/rest/v1/documents#Document + request = service.documents().create(body=body) + response = request.execute() + + return { + "title": response["title"], + "documentId": response["documentId"], + "documentUrl": f"https://docs.google.com/document/d/{response['documentId']}/edit", + } + + +# Uses https://developers.google.com/docs/api/reference/rest/v1/documents/batchUpdate +# Example `arcade chat` query: +# `create document with title "My New Document" and text content "Hello, World!"` +@tool( + requires_auth=Google( + scopes=[ + "https://www.googleapis.com/auth/drive.file", + ], + ) +) +async def create_document_from_text( + context: ToolContext, + title: Annotated[str, "The title of the document to create"], + text_content: Annotated[str, "The text content to insert into the document"], +) -> Annotated[dict, "The created document's title, documentId, and documentUrl in a dictionary"]: + """ + Create a Google Docs document with the specified title and text content. + """ + # First, create a blank document + document = await create_blank_document(context, title) + + service = build_docs_service(context.get_auth_token_or_empty()) + + requests = [ + { + "insertText": { + "location": { + "index": 1, + }, + "text": text_content, + } + } + ] + + # Execute the batchUpdate method to insert text + service.documents().batchUpdate( + documentId=document["documentId"], body={"requests": requests} + ).execute() + + return { + "title": document["title"], + "documentId": document["documentId"], + "documentUrl": f"https://docs.google.com/document/d/{document['documentId']}/edit", + } diff --git a/toolkits/google_docs/arcade_google_docs/tools/get.py b/toolkits/google_docs/arcade_google_docs/tools/get.py new file mode 100644 index 00000000..72027aea --- /dev/null +++ b/toolkits/google_docs/arcade_google_docs/tools/get.py @@ -0,0 +1,35 @@ +from typing import Annotated + +from arcade_tdk import ToolContext, ToolMetadataKey, tool +from arcade_tdk.auth import Google + +from arcade_google_docs.decorators import with_filepicker_fallback +from arcade_google_docs.utils import build_docs_service + + +# Uses https://developers.google.com/docs/api/reference/rest/v1/documents/get +# Example `arcade chat` query: `get document with ID 1234567890` +# Note: Document IDs are returned in the response of the Google Drive's `list_documents` tool +@tool( + requires_auth=Google( + scopes=[ + "https://www.googleapis.com/auth/drive.file", + ], + ), + requires_metadata=[ToolMetadataKey.CLIENT_ID, ToolMetadataKey.COORDINATOR_URL], +) +@with_filepicker_fallback +async def get_document_by_id( + context: ToolContext, + document_id: Annotated[str, "The ID of the document to retrieve."], +) -> Annotated[dict, "The document contents as a dictionary"]: + """ + Get the latest version of the specified Google Docs document. + """ + service = build_docs_service(context.get_auth_token_or_empty()) + + # Execute the documents().get() method. Returns a Document object + # https://developers.google.com/docs/api/reference/rest/v1/documents#Document + request = service.documents().get(documentId=document_id) + response = request.execute() + return dict(response) diff --git a/toolkits/google_docs/arcade_google_docs/tools/search.py b/toolkits/google_docs/arcade_google_docs/tools/search.py new file mode 100644 index 00000000..4221cdf1 --- /dev/null +++ b/toolkits/google_docs/arcade_google_docs/tools/search.py @@ -0,0 +1,219 @@ +from typing import Annotated, Any + +from arcade_tdk import ToolContext, ToolMetadataKey, tool +from arcade_tdk.auth import Google + +from arcade_google_docs.doc_to_html import convert_document_to_html +from arcade_google_docs.doc_to_markdown import convert_document_to_markdown +from arcade_google_docs.enum import DocumentFormat, OrderBy +from arcade_google_docs.file_picker import generate_google_file_picker_url +from arcade_google_docs.templates import optional_file_picker_instructions_template +from arcade_google_docs.tools import get_document_by_id +from arcade_google_docs.utils import ( + build_drive_service, + build_files_list_params, +) + + +# Implements: https://googleapis.github.io/google-api-python-client/docs/dyn/drive_v3.files.html#list +# Example `arcade chat` query: `list my 5 most recently modified documents` +# TODO: Support query with natural language. Currently, the tool expects a fully formed query +# string as input with the syntax defined here: https://developers.google.com/drive/api/guides/search-files +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/drive.file"], + ), + requires_metadata=[ToolMetadataKey.CLIENT_ID, ToolMetadataKey.COORDINATOR_URL], +) +async def search_documents( + context: ToolContext, + document_contains: Annotated[ + list[str] | None, + "Keywords or phrases that must be in the document title or body. Provide a list of " + "keywords or phrases if needed.", + ] = None, + document_not_contains: Annotated[ + list[str] | None, + "Keywords or phrases that must NOT be in the document title or body. Provide a list of " + "keywords or phrases if needed.", + ] = None, + search_only_in_shared_drive_id: Annotated[ + str | None, + "The ID of the shared drive to restrict the search to. If provided, the search will only " + "return documents from this drive. Defaults to None, which searches across all drives.", + ] = None, + include_shared_drives: Annotated[ + bool, + "Whether to include documents from shared drives. Defaults to False (searches only in " + "the user's 'My Drive').", + ] = False, + include_organization_domain_documents: Annotated[ + bool, + "Whether to include documents from the organization's domain. This is applicable to admin " + "users who have permissions to view organization-wide documents in a Google Workspace " + "account. Defaults to False.", + ] = False, + order_by: Annotated[ + list[OrderBy] | None, + "Sort order. Defaults to listing the most recently modified documents first", + ] = None, + limit: Annotated[int, "The number of documents to list"] = 50, + pagination_token: Annotated[ + str | None, "The pagination token to continue a previous request" + ] = None, +) -> Annotated[ + dict, + "A dictionary containing 'documents_count' (number of documents returned) and 'documents' " + "(a list of document details including 'kind', 'mimeType', 'id', and 'name' for each document)", +]: + """ + Searches for documents in the user's Google Drive. Excludes documents that are in the trash. + """ + if order_by is None: + order_by = [OrderBy.MODIFIED_TIME_DESC] + elif isinstance(order_by, OrderBy): + order_by = [order_by] + + page_size = min(10, limit) + files: list[dict[str, Any]] = [] + + service = build_drive_service(context.get_auth_token_or_empty()) + + params = build_files_list_params( + mime_type="application/vnd.google-apps.document", + document_contains=document_contains, + document_not_contains=document_not_contains, + page_size=page_size, + order_by=order_by, + pagination_token=pagination_token, + include_shared_drives=include_shared_drives, + search_only_in_shared_drive_id=search_only_in_shared_drive_id, + include_organization_domain_documents=include_organization_domain_documents, + ) + + while len(files) < limit: + if pagination_token: + params["pageToken"] = pagination_token + else: + params.pop("pageToken", None) + + results = service.files().list(**params).execute() + batch = results.get("files", []) + files.extend(batch[: limit - len(files)]) + + pagination_token = results.get("nextPageToken") + if not pagination_token or len(batch) < page_size: + break + + file_picker_response = generate_google_file_picker_url( + context, + ) + + return { + "documents_count": len(files), + "documents": files, + "file_picker": { + "url": file_picker_response["url"], + "llm_instructions": optional_file_picker_instructions_template.format( + url=file_picker_response["url"] + ), + }, + } + + +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/drive.file"], + ), + requires_metadata=[ToolMetadataKey.CLIENT_ID, ToolMetadataKey.COORDINATOR_URL], +) +async def search_and_retrieve_documents( + context: ToolContext, + return_format: Annotated[ + DocumentFormat, + "The format of the document to return. Defaults to Markdown.", + ] = DocumentFormat.MARKDOWN, + document_contains: Annotated[ + list[str] | None, + "Keywords or phrases that must be in the document title or body. Provide a list of " + "keywords or phrases if needed.", + ] = None, + document_not_contains: Annotated[ + list[str] | None, + "Keywords or phrases that must NOT be in the document title or body. Provide a list of " + "keywords or phrases if needed.", + ] = None, + search_only_in_shared_drive_id: Annotated[ + str | None, + "The ID of the shared drive to restrict the search to. If provided, the search will only " + "return documents from this drive. Defaults to None, which searches across all drives.", + ] = None, + include_shared_drives: Annotated[ + bool, + "Whether to include documents from shared drives. Defaults to False (searches only in " + "the user's 'My Drive').", + ] = False, + include_organization_domain_documents: Annotated[ + bool, + "Whether to include documents from the organization's domain. This is applicable to admin " + "users who have permissions to view organization-wide documents in a Google Workspace " + "account. Defaults to False.", + ] = False, + order_by: Annotated[ + list[OrderBy] | None, + "Sort order. Defaults to listing the most recently modified documents first", + ] = None, + limit: Annotated[int, "The number of documents to list"] = 50, + pagination_token: Annotated[ + str | None, "The pagination token to continue a previous request" + ] = None, +) -> Annotated[ + dict, + "A dictionary containing 'documents_count' (number of documents returned) and 'documents' " + "(a list of documents with their content).", +]: + """ + Searches for documents in the user's Google Drive and returns a list of documents (with text + content) matching the search criteria. Excludes documents that are in the trash. + + Note: use this tool only when the user prompt requires the documents' content. If the user only + needs a list of documents, use the `search_documents` tool instead. + """ + response = await search_documents( + context=context, + document_contains=document_contains, + document_not_contains=document_not_contains, + search_only_in_shared_drive_id=search_only_in_shared_drive_id, + include_shared_drives=include_shared_drives, + include_organization_domain_documents=include_organization_domain_documents, + order_by=order_by, + limit=limit, + pagination_token=pagination_token, + ) + + documents = [] + + for item in response["documents"]: + document = await get_document_by_id(context, document_id=item["id"]) + + if return_format == DocumentFormat.MARKDOWN: + document = convert_document_to_markdown(document) + elif return_format == DocumentFormat.HTML: + document = convert_document_to_html(document) + + documents.append(document) + + file_picker_response = generate_google_file_picker_url( + context, + ) + + return { + "documents_count": len(documents), + "documents": documents, + "file_picker": { + "url": file_picker_response["url"], + "llm_instructions": optional_file_picker_instructions_template.format( + url=file_picker_response["url"] + ), + }, + } diff --git a/toolkits/google_docs/arcade_google_docs/tools/update.py b/toolkits/google_docs/arcade_google_docs/tools/update.py new file mode 100644 index 00000000..1c3d1714 --- /dev/null +++ b/toolkits/google_docs/arcade_google_docs/tools/update.py @@ -0,0 +1,60 @@ +from typing import Annotated + +from arcade_tdk import ToolContext, ToolMetadataKey, tool +from arcade_tdk.auth import Google + +from arcade_google_docs.decorators import with_filepicker_fallback +from arcade_google_docs.tools.get import get_document_by_id +from arcade_google_docs.utils import build_docs_service + + +# Uses https://developers.google.com/docs/api/reference/rest/v1/documents/batchUpdate +# Example `arcade chat` query: `insert "The END" at the end of document with ID 1234567890` +@tool( + requires_auth=Google( + scopes=[ + "https://www.googleapis.com/auth/drive.file", + ], + ), + requires_metadata=[ToolMetadataKey.CLIENT_ID, ToolMetadataKey.COORDINATOR_URL], +) +@with_filepicker_fallback +async def insert_text_at_end_of_document( + context: ToolContext, + document_id: Annotated[str, "The ID of the document to update."], + text_content: Annotated[str, "The text content to insert into the document"], +) -> Annotated[dict, "The response from the batchUpdate API as a dict."]: + """ + Updates an existing Google Docs document using the batchUpdate API endpoint. + """ + document_or_file_picker_response = await get_document_by_id(context, document_id) + + # If the document was not found, return the file picker response + if "body" not in document_or_file_picker_response: + return document_or_file_picker_response # type: ignore[no-any-return] + + document = document_or_file_picker_response + + end_index = document["body"]["content"][-1]["endIndex"] + + service = build_docs_service(context.get_auth_token_or_empty()) + + requests = [ + { + "insertText": { + "location": { + "index": int(end_index) - 1, + }, + "text": text_content, + } + } + ] + + # Execute the documents().batchUpdate() method + response = ( + service.documents() + .batchUpdate(documentId=document_id, body={"requests": requests}) + .execute() + ) + + return dict(response) diff --git a/toolkits/google_docs/arcade_google_docs/utils.py b/toolkits/google_docs/arcade_google_docs/utils.py new file mode 100644 index 00000000..6780b7ae --- /dev/null +++ b/toolkits/google_docs/arcade_google_docs/utils.py @@ -0,0 +1,119 @@ +import logging +from typing import Any + +from google.oauth2.credentials import Credentials +from googleapiclient.discovery import Resource, build + +from arcade_google_docs.enum import Corpora, OrderBy + +## Set up basic configuration for logging to the console with DEBUG level and a specific format. +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) + +logger = logging.getLogger(__name__) + + +def build_docs_service(auth_token: str | None) -> Resource: # type: ignore[no-any-unimported] + """ + Build a Drive service object. + """ + auth_token = auth_token or "" + return build("docs", "v1", credentials=Credentials(auth_token)) + + +def build_drive_service(auth_token: str | None) -> Resource: # type: ignore[no-any-unimported] + """ + Build a Drive service object. + """ + auth_token = auth_token or "" + return build("drive", "v3", credentials=Credentials(auth_token)) + + +def build_files_list_params( + mime_type: str, + page_size: int, + order_by: list[OrderBy], + pagination_token: str | None, + include_shared_drives: bool, + search_only_in_shared_drive_id: str | None, + include_organization_domain_documents: bool, + document_contains: list[str] | None = None, + document_not_contains: list[str] | None = None, +) -> dict[str, Any]: + query = build_files_list_query( + mime_type=mime_type, + document_contains=document_contains, + document_not_contains=document_not_contains, + ) + + params = { + "q": query, + "pageSize": page_size, + "orderBy": ",".join([item.value for item in order_by]), + "pageToken": pagination_token, + } + + if ( + include_shared_drives + or search_only_in_shared_drive_id + or include_organization_domain_documents + ): + params["includeItemsFromAllDrives"] = "true" + params["supportsAllDrives"] = "true" + + if search_only_in_shared_drive_id: + params["driveId"] = search_only_in_shared_drive_id + params["corpora"] = Corpora.DRIVE.value + + if include_organization_domain_documents: + params["corpora"] = Corpora.DOMAIN.value + + params = remove_none_values(params) + + return params + + +def build_files_list_query( + mime_type: str, + document_contains: list[str] | None = None, + document_not_contains: list[str] | None = None, +) -> str: + query = [f"(mimeType = '{mime_type}' and trashed = false)"] + + if isinstance(document_contains, str): + document_contains = [document_contains] + + if isinstance(document_not_contains, str): + document_not_contains = [document_not_contains] + + if document_contains: + for keyword in document_contains: + name_contains = keyword.replace("'", "\\'") + full_text_contains = keyword.replace("'", "\\'") + keyword_query = ( + f"(name contains '{name_contains}' or fullText contains '{full_text_contains}')" + ) + query.append(keyword_query) + + if document_not_contains: + for keyword in document_not_contains: + name_not_contains = keyword.replace("'", "\\'") + full_text_not_contains = keyword.replace("'", "\\'") + keyword_query = ( + f"(name not contains '{name_not_contains}' and " + f"fullText not contains '{full_text_not_contains}')" + ) + query.append(keyword_query) + + return " and ".join(query) + + +def remove_none_values(params: dict) -> dict: + """ + Remove None values from a dictionary. + :param params: The dictionary to clean + :return: A new dictionary with None values removed + """ + return {k: v for k, v in params.items() if v is not None} diff --git a/toolkits/google_docs/conftest.py b/toolkits/google_docs/conftest.py new file mode 100644 index 00000000..ef47c5b5 --- /dev/null +++ b/toolkits/google_docs/conftest.py @@ -0,0 +1,967 @@ +import pytest + + +@pytest.fixture +def sample_document_and_expected_formats(): + document = { + "title": "The Birth of Machine Experience Engineering", + "documentId": "1234567890", + "body": { + "content": [ + { + "endIndex": 1, + "sectionBreak": { + "sectionStyle": { + "columnSeparatorStyle": "NONE", + "contentDirection": "LEFT_TO_RIGHT", + "sectionType": "CONTINUOUS", + } + }, + }, + { + "startIndex": 1, + "endIndex": 45, + "paragraph": { + "elements": [ + { + "endIndex": 45, + "startIndex": 1, + "textRun": { + "content": "The Birth of Machine Experience Engineering\n", + "textStyle": { + "bold": True, + "fontSize": {"magnitude": 23, "unit": "PT"}, + }, + }, + } + ], + "paragraphStyle": { + "direction": "LEFT_TO_RIGHT", + "headingId": "h.wwd7ec37bh6k", + "keepLinesTogether": False, + "keepWithNext": False, + "namedStyleType": "HEADING_1", + "spaceAbove": {"magnitude": 24, "unit": "PT"}, + }, + }, + }, + { + "startIndex": 45, + "endIndex": 46, + "paragraph": { + "elements": [ + { + "startIndex": 304, + "endIndex": 305, + "inlineObjectElement": { + "inlineObjectId": "kix.2s5wy5oiaf79", + "textStyle": {}, + }, + }, + { + "endIndex": 46, + "startIndex": 45, + "textRun": {"content": "\n", "textStyle": {}}, + }, + ], + "paragraphStyle": { + "direction": "LEFT_TO_RIGHT", + "namedStyleType": "NORMAL_TEXT", + "spaceAbove": {"magnitude": 12, "unit": "PT"}, + "spaceBelow": {"magnitude": 12, "unit": "PT"}, + }, + }, + }, + { + "startIndex": 46, + "endIndex": 297, + "paragraph": { + "elements": [ + { + "startIndex": 46, + "endIndex": 146, + "textRun": { + "content": ( + "LLMs acting on behalf of humans and interacting with real-" + "world systems isn't theoretical anymore - " + ), + "textStyle": {}, + }, + }, + { + "startIndex": 146, + "endIndex": 175, + "textRun": { + "content": "Arcade has made it a reality.", + "textStyle": { + "bold": True, + "italic": True, + }, + }, + }, + { + "startIndex": 175, + "endIndex": 248, + "textRun": { + "content": ( + " With this shift, we're seeing the emergence of a new " + "software practice: " + ), + "textStyle": {}, + }, + }, + { + "startIndex": 248, + "endIndex": 295, + "textRun": { + "content": "Machine Experience Engineering (MX Engineering)", + "textStyle": { + "italic": True, + }, + }, + }, + { + "startIndex": 295, + "endIndex": 297, + "textRun": { + "content": ".\n", + "textStyle": {}, + }, + }, + ], + "paragraphStyle": { + "direction": "LEFT_TO_RIGHT", + "namedStyleType": "NORMAL_TEXT", + "spaceAbove": {"magnitude": 12, "unit": "PT"}, + "spaceBelow": {"magnitude": 12, "unit": "PT"}, + }, + }, + }, + { + "endIndex": 407, + "startIndex": 297, + "table": { + "columns": 3, + "rows": 3, + "tableRows": [ + { + "endIndex": 338, + "startIndex": 297, + "tableCells": [ + { + "content": [ + { + "endIndex": 318, + "paragraph": { + "elements": [ + { + "endIndex": 318, + "startIndex": 309, + "textRun": { + "content": "Column 1\n", + "textStyle": {"bold": True}, + }, + } + ], + "paragraphStyle": { + "alignment": "START", + "avoidWidowAndOrphan": False, + "borderBetween": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderBottom": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderLeft": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderRight": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderTop": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "direction": "LEFT_TO_RIGHT", + "indentEnd": {"unit": "PT"}, + "indentFirstLine": {"unit": "PT"}, + "indentStart": {"unit": "PT"}, + "keepLinesTogether": False, + "keepWithNext": False, + "lineSpacing": 100, + "namedStyleType": "NORMAL_TEXT", + "pageBreakBefore": False, + "shading": {"backgroundColor": {}}, + "spaceAbove": {"unit": "PT"}, + "spaceBelow": {"unit": "PT"}, + "spacingMode": "COLLAPSE_LISTS", + }, + }, + "startIndex": 309, + } + ], + "endIndex": 318, + "startIndex": 308, + "tableCellStyle": { + "backgroundColor": {}, + "columnSpan": 1, + "contentAlignment": "TOP", + "paddingBottom": {"magnitude": 5, "unit": "PT"}, + "paddingLeft": {"magnitude": 5, "unit": "PT"}, + "paddingRight": {"magnitude": 5, "unit": "PT"}, + "paddingTop": {"magnitude": 5, "unit": "PT"}, + "rowSpan": 1, + }, + }, + { + "content": [ + { + "endIndex": 334, + "paragraph": { + "elements": [ + { + "endIndex": 326, + "startIndex": 319, + "textRun": { + "content": "Another", + "textStyle": {"italic": True}, + }, + }, + { + "endIndex": 334, + "startIndex": 326, + "textRun": { + "content": " column\n", + "textStyle": {}, + }, + }, + ], + "paragraphStyle": { + "alignment": "START", + "avoidWidowAndOrphan": False, + "borderBetween": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderBottom": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderLeft": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderRight": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderTop": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "direction": "LEFT_TO_RIGHT", + "indentEnd": {"unit": "PT"}, + "indentFirstLine": {"unit": "PT"}, + "indentStart": {"unit": "PT"}, + "keepLinesTogether": False, + "keepWithNext": False, + "lineSpacing": 100, + "namedStyleType": "NORMAL_TEXT", + "pageBreakBefore": False, + "shading": {"backgroundColor": {}}, + "spaceAbove": {"unit": "PT"}, + "spaceBelow": {"unit": "PT"}, + "spacingMode": "COLLAPSE_LISTS", + }, + }, + "startIndex": 319, + } + ], + "endIndex": 334, + "startIndex": 318, + "tableCellStyle": { + "backgroundColor": {}, + "columnSpan": 1, + "contentAlignment": "TOP", + "paddingBottom": {"magnitude": 5, "unit": "PT"}, + "paddingLeft": {"magnitude": 5, "unit": "PT"}, + "paddingRight": {"magnitude": 5, "unit": "PT"}, + "paddingTop": {"magnitude": 5, "unit": "PT"}, + "rowSpan": 1, + }, + }, + { + "content": [ + { + "endIndex": 348, + "paragraph": { + "elements": [ + { + "endIndex": 348, + "startIndex": 335, + "textRun": { + "content": "Third column\n", + "textStyle": {}, + }, + } + ], + "paragraphStyle": { + "alignment": "START", + "avoidWidowAndOrphan": False, + "borderBetween": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderBottom": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderLeft": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderRight": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderTop": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "direction": "LEFT_TO_RIGHT", + "indentEnd": {"unit": "PT"}, + "indentFirstLine": {"unit": "PT"}, + "indentStart": {"unit": "PT"}, + "keepLinesTogether": False, + "keepWithNext": False, + "lineSpacing": 100, + "namedStyleType": "NORMAL_TEXT", + "pageBreakBefore": False, + "shading": {"backgroundColor": {}}, + "spaceAbove": {"unit": "PT"}, + "spaceBelow": {"unit": "PT"}, + "spacingMode": "COLLAPSE_LISTS", + }, + }, + "startIndex": 335, + } + ], + "endIndex": 348, + "startIndex": 334, + "tableCellStyle": { + "backgroundColor": {}, + "columnSpan": 1, + "contentAlignment": "TOP", + "paddingBottom": {"magnitude": 5, "unit": "PT"}, + "paddingLeft": {"magnitude": 5, "unit": "PT"}, + "paddingRight": {"magnitude": 5, "unit": "PT"}, + "paddingTop": {"magnitude": 5, "unit": "PT"}, + "rowSpan": 1, + }, + }, + ], + "tableRowStyle": {"minRowHeight": {"unit": "PT"}}, + }, + { + "endIndex": 366, + "startIndex": 348, + "tableCells": [ + { + "content": [ + { + "endIndex": 356, + "paragraph": { + "elements": [ + { + "endIndex": 356, + "startIndex": 350, + "textRun": { + "content": "Hello\n", + "textStyle": {}, + }, + } + ], + "paragraphStyle": { + "alignment": "START", + "avoidWidowAndOrphan": False, + "borderBetween": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderBottom": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderLeft": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderRight": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderTop": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "direction": "LEFT_TO_RIGHT", + "indentEnd": {"unit": "PT"}, + "indentFirstLine": {"unit": "PT"}, + "indentStart": {"unit": "PT"}, + "keepLinesTogether": False, + "keepWithNext": False, + "lineSpacing": 100, + "namedStyleType": "NORMAL_TEXT", + "pageBreakBefore": False, + "shading": {"backgroundColor": {}}, + "spaceAbove": {"unit": "PT"}, + "spaceBelow": {"unit": "PT"}, + "spacingMode": "COLLAPSE_LISTS", + }, + }, + "startIndex": 350, + } + ], + "endIndex": 356, + "startIndex": 349, + "tableCellStyle": { + "backgroundColor": {}, + "columnSpan": 1, + "contentAlignment": "TOP", + "paddingBottom": {"magnitude": 5, "unit": "PT"}, + "paddingLeft": {"magnitude": 5, "unit": "PT"}, + "paddingRight": {"magnitude": 5, "unit": "PT"}, + "paddingTop": {"magnitude": 5, "unit": "PT"}, + "rowSpan": 1, + }, + }, + { + "content": [ + { + "endIndex": 364, + "paragraph": { + "elements": [ + { + "endIndex": 364, + "startIndex": 357, + "textRun": { + "content": "world!\n", + "textStyle": {}, + }, + } + ], + "paragraphStyle": { + "alignment": "START", + "avoidWidowAndOrphan": False, + "borderBetween": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderBottom": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderLeft": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderRight": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderTop": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "direction": "LEFT_TO_RIGHT", + "indentEnd": {"unit": "PT"}, + "indentFirstLine": {"unit": "PT"}, + "indentStart": {"unit": "PT"}, + "keepLinesTogether": False, + "keepWithNext": False, + "lineSpacing": 100, + "namedStyleType": "NORMAL_TEXT", + "pageBreakBefore": False, + "shading": {"backgroundColor": {}}, + "spaceAbove": {"unit": "PT"}, + "spaceBelow": {"unit": "PT"}, + "spacingMode": "COLLAPSE_LISTS", + }, + }, + "startIndex": 357, + } + ], + "endIndex": 364, + "startIndex": 356, + "tableCellStyle": { + "backgroundColor": {}, + "columnSpan": 1, + "contentAlignment": "TOP", + "paddingBottom": {"magnitude": 5, "unit": "PT"}, + "paddingLeft": {"magnitude": 5, "unit": "PT"}, + "paddingRight": {"magnitude": 5, "unit": "PT"}, + "paddingTop": {"magnitude": 5, "unit": "PT"}, + "rowSpan": 1, + }, + }, + { + "content": [ + { + "endIndex": 366, + "paragraph": { + "elements": [ + { + "endIndex": 366, + "startIndex": 365, + "textRun": { + "content": "\n", + "textStyle": {}, + }, + } + ], + "paragraphStyle": { + "alignment": "START", + "avoidWidowAndOrphan": False, + "borderBetween": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderBottom": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderLeft": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderRight": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderTop": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "direction": "LEFT_TO_RIGHT", + "indentEnd": {"unit": "PT"}, + "indentFirstLine": {"unit": "PT"}, + "indentStart": {"unit": "PT"}, + "keepLinesTogether": False, + "keepWithNext": False, + "lineSpacing": 100, + "namedStyleType": "NORMAL_TEXT", + "pageBreakBefore": False, + "shading": {"backgroundColor": {}}, + "spaceAbove": {"unit": "PT"}, + "spaceBelow": {"unit": "PT"}, + "spacingMode": "COLLAPSE_LISTS", + }, + }, + "startIndex": 365, + } + ], + "endIndex": 366, + "startIndex": 364, + "tableCellStyle": { + "backgroundColor": {}, + "columnSpan": 1, + "contentAlignment": "TOP", + "paddingBottom": {"magnitude": 5, "unit": "PT"}, + "paddingLeft": {"magnitude": 5, "unit": "PT"}, + "paddingRight": {"magnitude": 5, "unit": "PT"}, + "paddingTop": {"magnitude": 5, "unit": "PT"}, + "rowSpan": 1, + }, + }, + ], + "tableRowStyle": {"minRowHeight": {"unit": "PT"}}, + }, + { + "endIndex": 415, + "startIndex": 366, + "tableCells": [ + { + "content": [ + { + "endIndex": 388, + "paragraph": { + "elements": [ + { + "endIndex": 388, + "startIndex": 368, + "textRun": { + "content": "The quick brown fox\n", + "textStyle": {}, + }, + } + ], + "paragraphStyle": { + "alignment": "START", + "avoidWidowAndOrphan": False, + "borderBetween": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderBottom": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderLeft": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderRight": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderTop": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "direction": "LEFT_TO_RIGHT", + "indentEnd": {"unit": "PT"}, + "indentFirstLine": {"unit": "PT"}, + "indentStart": {"unit": "PT"}, + "keepLinesTogether": False, + "keepWithNext": False, + "lineSpacing": 100, + "namedStyleType": "NORMAL_TEXT", + "pageBreakBefore": False, + "shading": {"backgroundColor": {}}, + "spaceAbove": {"unit": "PT"}, + "spaceBelow": {"unit": "PT"}, + "spacingMode": "COLLAPSE_LISTS", + }, + }, + "startIndex": 368, + } + ], + "endIndex": 388, + "startIndex": 367, + "tableCellStyle": { + "backgroundColor": {}, + "columnSpan": 1, + "contentAlignment": "TOP", + "paddingBottom": {"magnitude": 5, "unit": "PT"}, + "paddingLeft": {"magnitude": 5, "unit": "PT"}, + "paddingRight": {"magnitude": 5, "unit": "PT"}, + "paddingTop": {"magnitude": 5, "unit": "PT"}, + "rowSpan": 1, + }, + }, + { + "content": [ + { + "endIndex": 401, + "paragraph": { + "elements": [ + { + "endIndex": 395, + "startIndex": 389, + "textRun": { + "content": "jumped", + "textStyle": {"italic": True}, + }, + }, + { + "endIndex": 401, + "startIndex": 395, + "textRun": { + "content": " over\n", + "textStyle": {}, + }, + }, + ], + "paragraphStyle": { + "alignment": "START", + "avoidWidowAndOrphan": False, + "borderBetween": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderBottom": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderLeft": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderRight": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderTop": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "direction": "LEFT_TO_RIGHT", + "indentEnd": {"unit": "PT"}, + "indentFirstLine": {"unit": "PT"}, + "indentStart": {"unit": "PT"}, + "keepLinesTogether": False, + "keepWithNext": False, + "lineSpacing": 100, + "namedStyleType": "NORMAL_TEXT", + "pageBreakBefore": False, + "shading": {"backgroundColor": {}}, + "spaceAbove": {"unit": "PT"}, + "spaceBelow": {"unit": "PT"}, + "spacingMode": "COLLAPSE_LISTS", + }, + }, + "startIndex": 389, + } + ], + "endIndex": 401, + "startIndex": 388, + "tableCellStyle": { + "backgroundColor": {}, + "columnSpan": 1, + "contentAlignment": "TOP", + "paddingBottom": {"magnitude": 5, "unit": "PT"}, + "paddingLeft": {"magnitude": 5, "unit": "PT"}, + "paddingRight": {"magnitude": 5, "unit": "PT"}, + "paddingTop": {"magnitude": 5, "unit": "PT"}, + "rowSpan": 1, + }, + }, + { + "content": [ + { + "endIndex": 415, + "paragraph": { + "elements": [ + { + "endIndex": 415, + "startIndex": 402, + "textRun": { + "content": "the lazy dog\n", + "textStyle": {}, + }, + } + ], + "paragraphStyle": { + "alignment": "START", + "avoidWidowAndOrphan": False, + "borderBetween": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderBottom": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderLeft": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderRight": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "borderTop": { + "color": {}, + "dashStyle": "SOLID", + "padding": {"unit": "PT"}, + "width": {"unit": "PT"}, + }, + "direction": "LEFT_TO_RIGHT", + "indentEnd": {"unit": "PT"}, + "indentFirstLine": {"unit": "PT"}, + "indentStart": {"unit": "PT"}, + "keepLinesTogether": False, + "keepWithNext": False, + "lineSpacing": 100, + "namedStyleType": "NORMAL_TEXT", + "pageBreakBefore": False, + "shading": {"backgroundColor": {}}, + "spaceAbove": {"unit": "PT"}, + "spaceBelow": {"unit": "PT"}, + "spacingMode": "COLLAPSE_LISTS", + }, + }, + "startIndex": 402, + } + ], + "endIndex": 415, + "startIndex": 401, + "tableCellStyle": { + "backgroundColor": {}, + "columnSpan": 1, + "contentAlignment": "TOP", + "paddingBottom": {"magnitude": 5, "unit": "PT"}, + "paddingLeft": {"magnitude": 5, "unit": "PT"}, + "paddingRight": {"magnitude": 5, "unit": "PT"}, + "paddingTop": {"magnitude": 5, "unit": "PT"}, + "rowSpan": 1, + }, + }, + ], + "tableRowStyle": {"minRowHeight": {"unit": "PT"}}, + }, + ], + "tableStyle": { + "tableColumnProperties": [ + {"widthType": "EVENLY_DISTRIBUTED"}, + {"widthType": "EVENLY_DISTRIBUTED"}, + {"widthType": "EVENLY_DISTRIBUTED"}, + ] + }, + }, + }, + ] + }, + } + + expected_markdown = ( + "---\ntitle: The Birth of Machine Experience Engineering\ndocumentId: 1234567890\n---\n" + "# **The Birth of Machine Experience Engineering**\n" + "\n" + "LLMs acting on behalf of humans and interacting with real-world systems isn't theoretical " + "anymore - " + "**_Arcade has made it a reality._** With this shift, we're seeing the emergence of a new " + "software practice: " + "_Machine Experience Engineering (MX Engineering)_.\n" + "" + "" + "" + "" + "" + "" + "" + "" + "" + "" + "" + "" + "" + "" + "" + "" + "
Column 1Another columnThird column
Helloworld!
The quick brown foxjumped overthe lazy dog
" + ) + + expected_html = ( + "" + "The Birth of Machine Experience Engineering" + '' + "" + "

The Birth of Machine Experience Engineering

" + "

LLMs acting on behalf of humans and interacting with real-world systems isn't " + "theoretical anymore - " + "Arcade has made it a reality. With this shift, we're seeing the emergence " + "of a new software practice: Machine Experience Engineering (MX Engineering).

" + "" + "" + "" + "" + "" + "" + "" + "" + "" + "" + "" + "" + "" + "" + "" + "" + "
Column 1Another columnThird column
Helloworld!
The quick brown foxjumped overthe lazy dog
" + "" + ) + + return document, expected_markdown, expected_html diff --git a/toolkits/google_docs/evals/eval_google_docs.py b/toolkits/google_docs/evals/eval_google_docs.py new file mode 100644 index 00000000..834eb904 --- /dev/null +++ b/toolkits/google_docs/evals/eval_google_docs.py @@ -0,0 +1,384 @@ +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + SimilarityCritic, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_google_docs +from arcade_google_docs.enum import DocumentFormat, OrderBy +from arcade_google_docs.tools import ( + create_blank_document, + create_document_from_text, + get_document_by_id, + insert_text_at_end_of_document, + search_and_retrieve_documents, + search_documents, +) + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.9, + warn_threshold=0.95, +) + +catalog = ToolCatalog() +catalog.add_module(arcade_google_docs) + + +@tool_eval() +def docs_eval_suite() -> EvalSuite: + """Create an evaluation suite for Google Docs tools.""" + suite = EvalSuite( + name="Google Docs Tools Evaluation", + system_message="You are an AI assistant that can create and manage Google Docs using the provided tools.", + catalog=catalog, + rubric=rubric, + ) + + # A previous tool call to list_documents + additional_messages = [ + {"role": "user", "content": "list my 10 most recently created docs"}, + { + "role": "assistant", + "content": "Please go to this URL and authorize the action: [Link](https://accounts.google.com/)", + }, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_gegK723W2hXsORjBmq1Oexqk", + "type": "function", + "function": { + "name": "Google_ListDocuments", + "arguments": '{"limit":10,"order_by":"createdTime desc"}', + }, + } + ], + }, + { + "role": "tool", + "content": '{"documents":[{"id":"1e0rCoT1Yd14WuuEvd3hSUcN_-VD3df4T3Q08uLm3TWc","kind":"drive#file","mimeType":"application/vnd.google-apps.document","name":"Tst10"},{"id":"1eTSWd-5zQds8K9OWYygwtCFMUyuuMize3bh3HaRsKts","kind":"drive#file","mimeType":"application/vnd.google-apps.document","name":"Tst9"},{"id":"19Dqugn0rVi89K0C__lpg1HbhQOTenccyZOhPgivTHMs","kind":"drive#file","mimeType":"application/vnd.google-apps.document","name":"Tst8"},{"id":"1RCibzx14eqP3vS9yI4nD13OKf8Vee56RiszS53OkR7I","kind":"drive#file","mimeType":"application/vnd.google-apps.document","name":"Tst7"},{"id":"1imFb04JQuBn8SiSsRFf6fEuYCyXkbII4KX8fsmnT0jo","kind":"drive#file","mimeType":"application/vnd.google-apps.document","name":"Tst6"},{"id":"1ZC3oypdfLWFgBd-emeSykJf9tZOae6USsFboygRCr-w","kind":"drive#file","mimeType":"application/vnd.google-apps.document","name":"Tst5"},{"id":"1-gFGNWmwLxEiKa6NNixLNq3X-phXRMORVZfVTfBg8Sc","kind":"drive#file","mimeType":"application/vnd.google-apps.document","name":"Tst4"},{"id":"1eQ8UBO_PY3Lem4R8OVdIc9ODXt0MrSUAnEu994Qz8P8","kind":"drive#file","mimeType":"application/vnd.google-apps.document","name":"Tst3"},{"id":"1TOxB0MLry-JzntDWDT1LFywTLdr3XDWPT5L5UsHMs5c","kind":"drive#file","mimeType":"application/vnd.google-apps.document","name":"Tst2"},{"id":"1a1UQ7C90s8kGfnO8k6wfAZz_Cy5nGN2MkCoRB5y2j3w","kind":"drive#file","mimeType":"application/vnd.google-apps.document","name":"Tst1"}],"documents_count":10}', + "tool_call_id": "call_gegK723W2hXsORjBmq1Oexqk", + "name": "Google_ListDocuments", + }, + { + "role": "assistant", + "content": "Here are your 10 most recently created Google Docs:\n\n1. [Tst10](https://docs.google.com/document/d/1e0rCoT1Yd14WuuEvd3hSUcN_-VD3df4T3Q08uLm3TWc)\n2. [Tst9](https://docs.google.com/document/d/1eTSWd-5zQds8K9OWYygwtCFMUyuuMize3bh3HaRsKts)\n3. [Tst8](https://docs.google.com/document/d/19Dqugn0rVi89K0C__lpg1HbhQOTenccyZOhPgivTHMs)\n4. [Tst7](https://docs.google.com/document/d/1RCibzx14eqP3vS9yI4nD13OKf8Vee56RiszS53OkR7I)\n5. [Tst6](https://docs.google.com/document/d/1imFb04JQuBn8SiSsRFf6fEuYCyXkbII4KX8fsmnT0jo)\n6. [Tst5](https://docs.google.com/document/d/1ZC3oypdfLWFgBd-emeSykJf9tZOae6USsFboygRCr-w)\n7. [Tst4](https://docs.google.com/document/d/1-gFGNWmwLxEiKa6NNixLNq3X-phXRMORVZfVTfBg8Sc)\n8. [Tst3](https://docs.google.com/document/d/1eQ8UBO_PY3Lem4R8OVdIc9ODXt0MrSUAnEu994Qz8P8)\n9. [Tst2](https://docs.google.com/document/d/1TOxB0MLry-JzntDWDT1LFywTLdr3XDWPT5L5UsHMs5c)\n10. [Tst1](https://docs.google.com/document/d/1a1UQ7C90s8kGfnO8k6wfAZz_Cy5nGN2MkCoRB5y2j3w)\n\nYou can click the links to open each document.", + }, + ] + + suite.add_case( + name="Get document content", + user_message="Can you read me the contents of Tst9 doc and also Tst10 doc please", + expected_tool_calls=[ + ExpectedToolCall( + func=get_document_by_id, + args={ + "document_id": "1eTSWd-5zQds8K9OWYygwtCFMUyuuMize3bh3HaRsKts", + }, + ), + ExpectedToolCall( + func=get_document_by_id, + args={ + "document_id": "1e0rCoT1Yd14WuuEvd3hSUcN_-VD3df4T3Q08uLm3TWc", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="document_id", weight=0.6), + ], + additional_messages=additional_messages, + ) + + suite.add_case( + name="Insert text at end of document", + user_message="Please add the text 'This is a new paragraph.' to the end of Tst4.", + expected_tool_calls=[ + ExpectedToolCall( + func=insert_text_at_end_of_document, + args={ + "document_id": "1-gFGNWmwLxEiKa6NNixLNq3X-phXRMORVZfVTfBg8Sc", + "text_content": "This is a new paragraph.", + }, + ) + ], + critics=[ + BinaryCritic(critic_field="document_id", weight=0.5), + SimilarityCritic(critic_field="text_content", weight=0.5), + ], + additional_messages=additional_messages, + ) + + suite.add_case( + name="Read the contents of two documents and then insert text at end of a different document.", + user_message="Can you read me the contents of Tst9 doc and also Tst10 doc please. Also, please add the text 'This is a new paragraph.' to the end of Tst4.", + expected_tool_calls=[ + ExpectedToolCall( + func=insert_text_at_end_of_document, + args={ + "document_id": "1-gFGNWmwLxEiKa6NNixLNq3X-phXRMORVZfVTfBg8Sc", + "text_content": "This is a new paragraph.", + }, + ), + ExpectedToolCall( + func=get_document_by_id, + args={ + "document_id": "1eTSWd-5zQds8K9OWYygwtCFMUyuuMize3bh3HaRsKts", + }, + ), + ExpectedToolCall( + func=get_document_by_id, + args={ + "document_id": "1e0rCoT1Yd14WuuEvd3hSUcN_-VD3df4T3Q08uLm3TWc", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="document_id", weight=0.3), + SimilarityCritic(critic_field="text_content", weight=0.3), + ], + additional_messages=additional_messages, + ) + + suite.add_case( + name="Create blank document", + user_message="Create a new Doc titled 'Meeting Notes'.", + expected_tool_calls=[ + ExpectedToolCall( + func=create_blank_document, + args={ + "title": "Meeting Notes", + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="title", weight=1.0), + ], + ) + + suite.add_case( + name="Create document from text", + user_message="Create a new doc called To-Do List with the content 'Buy groceries, Call mom, Finish report'.", + expected_tool_calls=[ + ExpectedToolCall( + func=create_document_from_text, + args={ + "title": "To-Do List", + "text_content": "Buy groceries\nCall mom\nFinish report", + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="title", weight=0.5), + SimilarityCritic(critic_field="text_content", weight=0.5), + ], + ) + + suite.add_case( + name="No tool call case", + user_message="Create a new microsoft word document titled 'My Resume'.", + expected_tool_calls=[], + critics=[], + ) + + return suite + + +@tool_eval() +def search_documents_eval_suite() -> EvalSuite: + """Create an evaluation suite for Google Drive tools.""" + suite = EvalSuite( + name="Google Drive Tools Evaluation", + system_message="You are an AI assistant that can manage Google Drive documents using the provided tools.", + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="Search documents in Google Drive", + user_message="get my 49 most recently created documents, list the ones created most recently first.", + expected_tool_calls=[ + ExpectedToolCall( + func=search_documents, + args={ + "order_by": [OrderBy.CREATED_TIME_DESC.value], + "limit": 49, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="order_by", weight=0.5), + BinaryCritic(critic_field="limit", weight=0.5), + ], + ) + + suite.add_case( + name="Search documents in Google Drive based on document keywords", + user_message="Search the documents that contain the word 'greedy' and the phrase 'hello, world'", + expected_tool_calls=[ + ExpectedToolCall( + func=search_documents, + args={ + "document_contains": ["greedy", "hello, world"], + }, + ) + ], + critics=[ + BinaryCritic(critic_field="document_contains", weight=1.0), + ], + ) + + suite.add_case( + name="Search documents in a specific Google Drive based on document keywords", + user_message="Search the documents that contain the word 'greedy' and the phrase 'hello, world' in the drive with id 'abc123'", + expected_tool_calls=[ + ExpectedToolCall( + func=search_documents, + args={ + "document_contains": ["greedy", "hello, world"], + "search_only_in_shared_drive_id": "abc123", + }, + ) + ], + critics=[ + BinaryCritic(critic_field="search_only_in_shared_drive_id", weight=0.5), + BinaryCritic(critic_field="document_contains", weight=0.5), + ], + ) + + suite.add_case( + name="Search documents in a Google Drive Workspace organization domain based on document keywords", + user_message="Search the documents that contain the phrase 'hello, world' in the organization domain", + expected_tool_calls=[ + ExpectedToolCall( + func=search_documents, + args={ + "document_contains": ["hello, world"], + "include_organization_domain_documents": True, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="include_organization_domain_documents", weight=0.5), + BinaryCritic(critic_field="document_contains", weight=0.5), + ], + ) + + suite.add_case( + name="Search documents in shared drives", + user_message="Search the 5 documents from all drives corpora that nobody has touched in forever, excluding shared drives.", + expected_tool_calls=[ + ExpectedToolCall( + func=search_documents, + args={ + "limit": 5, + "include_shared_drives": False, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="include_shared_drives", weight=0.5), + BinaryCritic(critic_field="limit", weight=0.5), + ], + ) + + suite.add_case( + name="No tool call case", + user_message="List my 10 most recently modified documents that are stored in my Microsoft OneDrive.", + expected_tool_calls=[], + critics=[], + ) + + return suite + + +@tool_eval() +def search_and_retrieve_documents_eval_suite() -> EvalSuite: + """Create an evaluation suite for Google Drive search and retrieve tools.""" + suite = EvalSuite( + name="Google Drive Tools Evaluation", + system_message="You are an AI assistant that can manage Google Drive documents using the provided tools.", + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="Search and retrieve (write summary)", + user_message="Write a summary of the documents in my Google Drive about 'MX Engineering'", + expected_tool_calls=[ + ExpectedToolCall( + func=search_and_retrieve_documents, + args={ + "document_contains": ["MX Engineering"], + "return_format": DocumentFormat.MARKDOWN, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="document_contains", weight=0.5), + BinaryCritic(critic_field="return_format", weight=0.5), + ], + ) + + suite.add_case( + name="Search and retrieve (project proposal)", + user_message="Display the document contents in HTML format from my Google Drive that contain the phrase 'project proposal'.", + expected_tool_calls=[ + ExpectedToolCall( + func=search_and_retrieve_documents, + args={ + "document_contains": ["project proposal"], + "return_format": DocumentFormat.HTML, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="document_contains", weight=0.5), + BinaryCritic(critic_field="return_format", weight=0.5), + ], + ) + + suite.add_case( + name="Search and retrieve (meeting notes)", + user_message="Retrieve documents that contain both 'meeting notes' and 'budget' in JSON format.", + expected_tool_calls=[ + ExpectedToolCall( + func=search_and_retrieve_documents, + args={ + "document_contains": ["meeting notes", "budget"], + "return_format": DocumentFormat.GOOGLE_API_JSON, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="document_contains", weight=0.5), + BinaryCritic(critic_field="return_format", weight=0.5), + ], + ) + + suite.add_case( + name="Search and retrieve (Q1 report)", + user_message="Show me the content of the documents that mention 'Q1 report' but do not include the expression 'Project XYZ'.", + expected_tool_calls=[ + ExpectedToolCall( + func=search_and_retrieve_documents, + args={ + "document_contains": ["Q1 report"], + "document_not_contains": ["Project XYZ"], + "return_format": DocumentFormat.MARKDOWN, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="document_contains", weight=1 / 3), + BinaryCritic(critic_field="document_not_contains", weight=1 / 3), + BinaryCritic(critic_field="return_format", weight=1 / 3), + ], + ) + + return suite diff --git a/toolkits/google_docs/pyproject.toml b/toolkits/google_docs/pyproject.toml new file mode 100644 index 00000000..58e3f128 --- /dev/null +++ b/toolkits/google_docs/pyproject.toml @@ -0,0 +1,62 @@ +[build-system] +requires = [ "hatchling",] +build-backend = "hatchling.build" + +[project] +name = "arcade_google_docs" +version = "2.0.0" +description = "Arcade.dev LLM tools for Google Docs" +requires-python = ">=3.10" +dependencies = [ + "arcade-tdk>=2.0.0,<3.0.0", + "google-api-core>=2.19.1,<3.0.0", + "google-api-python-client>=2.137.0,<3.0.0", + "google-auth>=2.32.0,<3.0.0", + "google-auth-httplib2>=0.2.0,<1.0.0", + "googleapis-common-protos>=1.63.2,<2.0.0", +] +[[project.authors]] +name = "Arcade" +email = "dev@arcade.dev" + + +[project.optional-dependencies] +dev = [ + "arcade-ai[evals]>=2.0.4,<3.0.0", + "arcade-serve>=2.0.0,<3.0.0", + "pytest>=8.3.0,<8.4.0", + "pytest-cov>=4.0.0,<4.1.0", + "pytest-mock>=3.11.1,<3.12.0", + "pytest-asyncio>=0.24.0,<0.25.0", + "mypy>=1.5.1,<1.6.0", + "pre-commit>=3.4.0,<3.5.0", + "tox>=4.11.1,<4.12.0", + "ruff>=0.7.4,<0.8.0", +] + +# Use local path sources for arcade libs when working locally +[tool.uv.sources] +arcade-ai = { path = "../../", editable = true } +arcade-serve = { path = "../../libs/arcade-serve/", editable = true } +arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } + +[tool.mypy] +files = [ "arcade_google_docs/**/*.py",] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes = "True" +ignore_missing_imports = "True" + +[tool.pytest.ini_options] +testpaths = [ "tests",] + +[tool.coverage.report] +skip_empty = true + +[tool.hatch.build.targets.wheel] +packages = [ "arcade_google_docs",] diff --git a/toolkits/google_docs/tests/__init__.py b/toolkits/google_docs/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/toolkits/google_docs/tests/test_doc_to_markdown.py b/toolkits/google_docs/tests/test_doc_to_markdown.py new file mode 100644 index 00000000..482ba4c9 --- /dev/null +++ b/toolkits/google_docs/tests/test_doc_to_markdown.py @@ -0,0 +1,10 @@ +import pytest + +from arcade_google_docs.doc_to_markdown import convert_document_to_markdown + + +@pytest.mark.asyncio +async def test_convert_document_to_markdown(sample_document_and_expected_formats): + (sample_document, expected_markdown, _) = sample_document_and_expected_formats + markdown = convert_document_to_markdown(sample_document) + assert markdown == expected_markdown diff --git a/toolkits/google_docs/tests/test_google_docs.py b/toolkits/google_docs/tests/test_google_docs.py new file mode 100644 index 00000000..894ce183 --- /dev/null +++ b/toolkits/google_docs/tests/test_google_docs.py @@ -0,0 +1,179 @@ +from unittest.mock import AsyncMock, patch + +import pytest +from arcade_tdk.errors import ToolExecutionError +from googleapiclient.errors import HttpError + +from arcade_google_docs.tools import ( + create_blank_document, + create_document_from_text, + get_document_by_id, + insert_text_at_end_of_document, +) +from arcade_google_docs.utils import build_docs_service + + +@pytest.fixture +def mock_context(): + context = AsyncMock() + context.authorization.token = "mock_token" # noqa: S105 + return context + + +@pytest.fixture +def mock_get_service(): + with patch("arcade_google_docs.tools.get." + build_docs_service.__name__) as mock_build_service: + yield mock_build_service.return_value + + +@pytest.fixture +def mock_update_service(): + with patch( + "arcade_google_docs.tools.update." + build_docs_service.__name__ + ) as mock_build_service: + yield mock_build_service.return_value + + +@pytest.fixture +def mock_create_service(): + with patch( + "arcade_google_docs.tools.create." + build_docs_service.__name__ + ) as mock_build_service: + yield mock_build_service.return_value + + +@pytest.mark.asyncio +async def test_get_document_by_id_success(mock_context, mock_get_service): + # Mock the service.documents().get().execute() method + mock_get_service.documents.return_value.get.return_value.execute.return_value = { + "body": {"content": [{"endIndex": 1, "paragraph": {}}]}, + "documentId": "test_document_id", + "title": "Test Document", + } + + result = await get_document_by_id(mock_context, "test_document_id") + + assert result["documentId"] == "test_document_id" + assert result["title"] == "Test Document" + + +@pytest.mark.asyncio +async def test_get_document_by_id_http_error(mock_context, mock_get_service): + # Simulate HttpError + mock_get_service.documents.return_value.get.return_value.execute.side_effect = HttpError( + resp=AsyncMock(status=404), content=b'{"error": {"message": "Not Found"}}' + ) + + with pytest.raises(ToolExecutionError, match="Error in execution of GetDocumentById"): + await get_document_by_id(mock_context, "invalid_document_id") + + +@pytest.mark.asyncio +async def test_insert_text_at_end_of_document_success(mock_context, mock_update_service): + # Mock get_document_by_id to return a document with endIndex + with patch( + "arcade_google_docs.tools.update.get_document_by_id", + return_value={"body": {"content": [{"endIndex": 1, "paragraph": {}}]}}, + ): + # Mock the service.documents().batchUpdate().execute() method + mock_update_service.documents.return_value.batchUpdate.return_value.execute.return_value = { + "documentId": "test_document_id", + "replies": [], + } + + result = await insert_text_at_end_of_document( + mock_context, "test_document_id", "Sample text" + ) + + assert result["documentId"] == "test_document_id" + + +@pytest.mark.asyncio +async def test_insert_text_at_end_of_document_http_error(mock_context, mock_update_service): + with patch( + "arcade_google_docs.tools.update.get_document_by_id", + return_value={"body": {"content": [{"endIndex": 1, "paragraph": {}}]}}, + ): + # Simulate HttpError during batchUpdate + mock_update_service.documents.return_value.batchUpdate.return_value.execute.side_effect = ( + HttpError(resp=AsyncMock(status=400), content=b'{"error": {"message": "Bad Request"}}') + ) + + with pytest.raises( + ToolExecutionError, match="Error in execution of InsertTextAtEndOfDocument" + ): + await insert_text_at_end_of_document(mock_context, "test_document_id", "Sample text") + + +@pytest.mark.asyncio +async def test_create_blank_document_success(mock_context, mock_create_service): + # Mock the service.documents().create().execute() method + mock_create_service.documents.return_value.create.return_value.execute.return_value = { + "documentId": "new_document_id", + "title": "New Document", + } + + result = await create_blank_document(mock_context, "New Document") + + assert result["documentId"] == "new_document_id" + assert result["title"] == "New Document" + assert "documentUrl" in result + + +@pytest.mark.asyncio +async def test_create_blank_document_http_error(mock_context, mock_create_service): + # Simulate HttpError during create + mock_create_service.documents.return_value.create.return_value.execute.side_effect = HttpError( + resp=AsyncMock(status=403), content=b'{"error": {"message": "Forbidden"}}' + ) + + with pytest.raises(ToolExecutionError, match="Error in execution of CreateBlankDocument"): + await create_blank_document(mock_context, "New Document") + + +@pytest.mark.asyncio +async def test_create_document_from_text_success(mock_context, mock_create_service): + with patch( + "arcade_google_docs.tools.create." + create_blank_document.__name__ + ) as mock_create_blank_document: + # Mock create_blank_document + mock_create_blank_document.return_value = { + "documentId": "new_document_id", + "title": "New Document", + } + + # Mock the service.documents().batchUpdate().execute() method + mock_create_service.documents.return_value.batchUpdate.return_value.execute.return_value = { + "documentId": "new_document_id", + "replies": [], + } + + result = await create_document_from_text(mock_context, "New Document", "Hello, World!") + + assert result["documentId"] == "new_document_id" + assert result["title"] == "New Document" + assert "documentUrl" in result + + +@pytest.mark.asyncio +async def test_create_document_from_text_http_error(mock_context, mock_create_service): + with patch( + "arcade_google_docs.tools.create." + create_blank_document.__name__ + ) as mock_create_blank_document: + # Mock create_blank_document + mock_create_blank_document.return_value = { + "documentId": "new_document_id", + "title": "New Document", + } + + # Simulate HttpError during batchUpdate + mock_create_service.documents.return_value.batchUpdate.return_value.execute.side_effect = ( + HttpError( + resp=AsyncMock(status=500), content=b'{"error": {"message": "Internal Error"}}' + ) + ) + + with pytest.raises( + ToolExecutionError, match="Error in execution of CreateDocumentFromText" + ): + await create_document_from_text(mock_context, "New Document", "Hello, World!") diff --git a/toolkits/google_docs/tests/test_search.py b/toolkits/google_docs/tests/test_search.py new file mode 100644 index 00000000..38ad1a81 --- /dev/null +++ b/toolkits/google_docs/tests/test_search.py @@ -0,0 +1,276 @@ +from unittest.mock import AsyncMock, patch + +import pytest +from arcade_tdk.errors import ToolExecutionError +from googleapiclient.errors import HttpError + +from arcade_google_docs.enum import Corpora, DocumentFormat, OrderBy +from arcade_google_docs.templates import optional_file_picker_instructions_template +from arcade_google_docs.tools import ( + search_and_retrieve_documents, + search_documents, +) +from arcade_google_docs.utils import build_drive_service + + +@pytest.fixture +def mock_context(): + context = AsyncMock() + context.authorization.token = "mock_token" # noqa: S105 + context.get_metadata.side_effect = lambda key: { + "client_id": "123456789-abcdefg.apps.googleusercontent.com", + "coordinator_url": "https://coordinator.example.com", + }.get(key.value if hasattr(key, "value") else key) + return context + + +@pytest.fixture +def mock_service(): + with patch( + "arcade_google_docs.tools.search." + build_drive_service.__name__ + ) as mock_build_service: + yield mock_build_service.return_value + + +@pytest.mark.asyncio +async def test_search_documents_success(mock_context, mock_service): + # Mock the service.files().list().execute() method + mock_service.files.return_value.list.return_value.execute.side_effect = [ + { + "files": [ + {"id": "file1", "name": "Document 1"}, + {"id": "file2", "name": "Document 2"}, + ], + "nextPageToken": None, + } + ] + + # Mock the generate_google_file_picker_url function + with patch( + "arcade_google_docs.tools.search.generate_google_file_picker_url" + ) as mock_file_picker: + mock_file_picker.return_value = { + "url": "https://coordinator.example.com/google/drive_picker?config=test_config", + "llm_instructions": optional_file_picker_instructions_template.format( + url="https://coordinator.example.com/google/drive_picker?config=test_config" + ), + } + + result = await search_documents(mock_context, limit=2) + + assert result["documents_count"] == 2 + assert len(result["documents"]) == 2 + assert result["documents"][0]["id"] == "file1" + assert result["documents"][1]["id"] == "file2" + + +@pytest.mark.asyncio +async def test_search_documents_pagination(mock_context, mock_service): + # Simulate multiple pages + mock_service.files.return_value.list.return_value.execute.side_effect = [ + { + "files": [{"id": f"file{i}", "name": f"Document {i}"} for i in range(1, 11)], + "nextPageToken": "token1", + }, + { + "files": [{"id": f"file{i}", "name": f"Document {i}"} for i in range(11, 21)], + "nextPageToken": None, + }, + ] + + # Mock the generate_google_file_picker_url function + with patch( + "arcade_google_docs.tools.search.generate_google_file_picker_url" + ) as mock_file_picker: + mock_file_picker.return_value = { + "url": "https://coordinator.example.com/google/drive_picker?config=test_config", + "llm_instructions": optional_file_picker_instructions_template.format( + url="https://coordinator.example.com/google/drive_picker?config=test_config" + ), + } + + result = await search_documents(mock_context, limit=15) + + assert result["documents_count"] == 15 + assert len(result["documents"]) == 15 + assert result["documents"][0]["id"] == "file1" + assert result["documents"][-1]["id"] == "file15" + + +@pytest.mark.asyncio +async def test_search_documents_http_error(mock_context, mock_service): + # Simulate HttpError + mock_service.files.return_value.list.return_value.execute.side_effect = HttpError( + resp=AsyncMock(status=403), content=b'{"error": {"message": "Forbidden"}}' + ) + + with pytest.raises( + ToolExecutionError, match=f"Error in execution of {search_documents.__tool_name__}" + ): + await search_documents(mock_context) + + +@pytest.mark.asyncio +async def test_search_documents_unexpected_error(mock_context, mock_service): + # Simulate unexpected exception + mock_service.files.return_value.list.return_value.execute.side_effect = Exception( + "Unexpected error" + ) + + with pytest.raises( + ToolExecutionError, match=f"Error in execution of {search_documents.__tool_name__}" + ): + await search_documents(mock_context) + + +@pytest.mark.asyncio +async def test_search_documents_in_organization_domains(mock_context, mock_service): + # Mock the service.files().list().execute() method + mock_service.files.return_value.list.return_value.execute.side_effect = [ + { + "files": [ + {"id": "file1", "name": "Document 1"}, + ], + "nextPageToken": None, + } + ] + + # Mock the generate_google_file_picker_url function + with patch( + "arcade_google_docs.tools.search.generate_google_file_picker_url" + ) as mock_file_picker: + mock_file_picker.return_value = { + "url": "https://coordinator.example.com/google/drive_picker?config=test_config", + "llm_instructions": optional_file_picker_instructions_template.format( + url="https://coordinator.example.com/google/drive_picker?config=test_config" + ), + } + + result = await search_documents( + mock_context, + order_by=OrderBy.MODIFIED_TIME_DESC, + include_shared_drives=False, + include_organization_domain_documents=True, + limit=1, + ) + + assert result["documents_count"] == 1 + mock_service.files.return_value.list.assert_called_with( + q="(mimeType = 'application/vnd.google-apps.document' and trashed = false)", + corpora=Corpora.DOMAIN.value, + pageSize=1, + orderBy=OrderBy.MODIFIED_TIME_DESC.value, + includeItemsFromAllDrives="true", + supportsAllDrives="true", + ) + + +@pytest.mark.asyncio +@patch("arcade_google_docs.tools.search.search_documents") +@patch("arcade_google_docs.tools.search.get_document_by_id") +async def test_search_and_retrieve_documents_in_markdown_format( + mock_get_document_by_id, + mock_search_documents, + mock_context, + sample_document_and_expected_formats, +): + (sample_document, expected_markdown, _) = sample_document_and_expected_formats + mock_search_documents.return_value = { + "documents_count": 1, + "documents": [{"id": sample_document["documentId"], "title": sample_document["title"]}], + } + mock_get_document_by_id.return_value = sample_document + + # Mock the generate_google_file_picker_url function + with patch( + "arcade_google_docs.tools.search.generate_google_file_picker_url" + ) as mock_file_picker: + mock_file_picker.return_value = { + "url": "https://coordinator.example.com/google/drive_picker?config=test_config", + "llm_instructions": optional_file_picker_instructions_template.format( + url="https://coordinator.example.com/google/drive_picker?config=test_config" + ), + } + + result = await search_and_retrieve_documents( + mock_context, + document_contains=[sample_document["title"]], + return_format=DocumentFormat.MARKDOWN, + ) + + assert result["documents_count"] == 1 + assert result["documents"][0] == expected_markdown + + +@pytest.mark.asyncio +@patch("arcade_google_docs.tools.search.search_documents") +@patch("arcade_google_docs.tools.search.get_document_by_id") +async def test_search_and_retrieve_documents_in_html_format( + mock_get_document_by_id, + mock_search_documents, + mock_context, + sample_document_and_expected_formats, +): + (sample_document, _, expected_html) = sample_document_and_expected_formats + mock_search_documents.return_value = { + "documents_count": 1, + "documents": [{"id": sample_document["documentId"], "title": sample_document["title"]}], + } + mock_get_document_by_id.return_value = sample_document + + # Mock the generate_google_file_picker_url function + with patch( + "arcade_google_docs.tools.search.generate_google_file_picker_url" + ) as mock_file_picker: + mock_file_picker.return_value = { + "url": "https://coordinator.example.com/google/drive_picker?config=test_config", + "llm_instructions": optional_file_picker_instructions_template.format( + url="https://coordinator.example.com/google/drive_picker?config=test_config" + ), + } + + result = await search_and_retrieve_documents( + mock_context, + document_contains=[sample_document["title"]], + return_format=DocumentFormat.HTML, + ) + + assert result["documents_count"] == 1 + assert result["documents"][0] == expected_html + + +@pytest.mark.asyncio +@patch("arcade_google_docs.tools.search.search_documents") +@patch("arcade_google_docs.tools.search.get_document_by_id") +async def test_search_and_retrieve_documents_in_google_json_format( + mock_get_document_by_id, + mock_search_documents, + mock_context, + sample_document_and_expected_formats, +): + (sample_document, _, _) = sample_document_and_expected_formats + mock_search_documents.return_value = { + "documents_count": 1, + "documents": [{"id": sample_document["documentId"], "title": sample_document["title"]}], + } + mock_get_document_by_id.return_value = sample_document + + # Mock the generate_google_file_picker_url function + with patch( + "arcade_google_docs.tools.search.generate_google_file_picker_url" + ) as mock_file_picker: + mock_file_picker.return_value = { + "url": "https://coordinator.example.com/google/drive_picker?config=test_config", + "llm_instructions": optional_file_picker_instructions_template.format( + url="https://coordinator.example.com/google/drive_picker?config=test_config" + ), + } + + result = await search_and_retrieve_documents( + mock_context, + document_contains=[sample_document["title"]], + return_format=DocumentFormat.GOOGLE_API_JSON, + ) + + assert result["documents_count"] == 1 + assert result["documents"][0] == sample_document diff --git a/toolkits/google_drive/.pre-commit-config.yaml b/toolkits/google_drive/.pre-commit-config.yaml new file mode 100644 index 00000000..28d8a695 --- /dev/null +++ b/toolkits/google_drive/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +files: ^.*/google_drive/.* +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.4.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/toolkits/google_drive/.ruff.toml b/toolkits/google_drive/.ruff.toml new file mode 100644 index 00000000..f1aed90f --- /dev/null +++ b/toolkits/google_drive/.ruff.toml @@ -0,0 +1,46 @@ +target-version = "py310" +line-length = 100 +fix = true + +[lint] +select = [ + # flake8-2020 + "YTT", + # flake8-bandit + "S", + # flake8-bugbear + "B", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-simplify + "SIM", + # isort + "I", + # mccabe + "C90", + # pycodestyle + "E", "W", + # pyflakes + "F", + # pygrep-hooks + "PGH", + # pyupgrade + "UP", + # ruff + "RUF", + # tryceratops + "TRY", +] + +[lint.per-file-ignores] +"*" = ["TRY003", "B904"] +"**/tests/*" = ["S101", "E501"] +"**/evals/*" = ["S101", "E501"] + +[format] +preview = true +skip-magic-trailing-comma = false diff --git a/toolkits/google_drive/Makefile b/toolkits/google_drive/Makefile new file mode 100644 index 00000000..0a8969be --- /dev/null +++ b/toolkits/google_drive/Makefile @@ -0,0 +1,55 @@ +.PHONY: help + +help: + @echo "🛠️ github Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the uv environment and install all packages with dependencies + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras --no-sources + @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: install-local +install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras + @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + uv build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + @uv run --no-sources coverage report + @echo "Generating coverage report" + @uv run --no-sources coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file by a patch version + @echo "🚀 Bumping version in pyproject.toml" + uv version --no-sources --bump patch + +.PHONY: check +check: ## Run code quality tools. + @if [ -f .pre-commit-config.yaml ]; then\ + echo "🚀 Linting code: Running pre-commit";\ + uv run --no-sources pre-commit run -a;\ + fi + @echo "🚀 Static type checking: Running mypy" + @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/google_drive/arcade_google_drive/__init__.py b/toolkits/google_drive/arcade_google_drive/__init__.py new file mode 100644 index 00000000..eff625d6 --- /dev/null +++ b/toolkits/google_drive/arcade_google_drive/__init__.py @@ -0,0 +1,3 @@ +from arcade_google_drive.tools import generate_google_file_picker_url, get_file_tree_structure + +__all__ = ["generate_google_file_picker_url", "get_file_tree_structure"] diff --git a/toolkits/google_drive/arcade_google_drive/enums.py b/toolkits/google_drive/arcade_google_drive/enums.py new file mode 100644 index 00000000..54b1a22b --- /dev/null +++ b/toolkits/google_drive/arcade_google_drive/enums.py @@ -0,0 +1,116 @@ +from enum import Enum + + +class Corpora(str, Enum): + """ + Bodies of items (files/documents) to which the query applies. + Prefer 'user' or 'drive' to 'allDrives' for efficiency. + By default, corpora is set to 'user'. + """ + + USER = "user" + DOMAIN = "domain" + DRIVE = "drive" + ALL_DRIVES = "allDrives" + + +class OrderBy(str, Enum): + """ + Sort keys for ordering files in Google Drive. + Each key has both ascending and descending options. + """ + + CREATED_TIME = ( + # When the file was created (ascending) + "createdTime" + ) + CREATED_TIME_DESC = ( + # When the file was created (descending) + "createdTime desc" + ) + FOLDER = ( + # The folder ID, sorted using alphabetical ordering (ascending) + "folder" + ) + FOLDER_DESC = ( + # The folder ID, sorted using alphabetical ordering (descending) + "folder desc" + ) + MODIFIED_BY_ME_TIME = ( + # The last time the file was modified by the user (ascending) + "modifiedByMeTime" + ) + MODIFIED_BY_ME_TIME_DESC = ( + # The last time the file was modified by the user (descending) + "modifiedByMeTime desc" + ) + MODIFIED_TIME = ( + # The last time the file was modified by anyone (ascending) + "modifiedTime" + ) + MODIFIED_TIME_DESC = ( + # The last time the file was modified by anyone (descending) + "modifiedTime desc" + ) + NAME = ( + # The name of the file, sorted using alphabetical ordering (e.g., 1, 12, 2, 22) (ascending) + "name" + ) + NAME_DESC = ( + # The name of the file, sorted using alphabetical ordering (e.g., 1, 12, 2, 22) (descending) + "name desc" + ) + NAME_NATURAL = ( + # The name of the file, sorted using natural sort ordering (e.g., 1, 2, 12, 22) (ascending) + "name_natural" + ) + NAME_NATURAL_DESC = ( + # The name of the file, sorted using natural sort ordering (e.g., 1, 2, 12, 22) (descending) + "name_natural desc" + ) + QUOTA_BYTES_USED = ( + # The number of storage quota bytes used by the file (ascending) + "quotaBytesUsed" + ) + QUOTA_BYTES_USED_DESC = ( + # The number of storage quota bytes used by the file (descending) + "quotaBytesUsed desc" + ) + RECENCY = ( + # The most recent timestamp from the file's date-time fields (ascending) + "recency" + ) + RECENCY_DESC = ( + # The most recent timestamp from the file's date-time fields (descending) + "recency desc" + ) + SHARED_WITH_ME_TIME = ( + # When the file was shared with the user, if applicable (ascending) + "sharedWithMeTime" + ) + SHARED_WITH_ME_TIME_DESC = ( + # When the file was shared with the user, if applicable (descending) + "sharedWithMeTime desc" + ) + STARRED = ( + # Whether the user has starred the file (ascending) + "starred" + ) + STARRED_DESC = ( + # Whether the user has starred the file (descending) + "starred desc" + ) + VIEWED_BY_ME_TIME = ( + # The last time the file was viewed by the user (ascending) + "viewedByMeTime" + ) + VIEWED_BY_ME_TIME_DESC = ( + # The last time the file was viewed by the user (descending) + "viewedByMeTime desc" + ) + + +class DocumentFormat(str, Enum): + MARKDOWN = "markdown" + HTML = "html" + GOOGLE_API_JSON = "google_api_json" diff --git a/toolkits/google_drive/arcade_google_drive/templates.py b/toolkits/google_drive/arcade_google_drive/templates.py new file mode 100644 index 00000000..55de8614 --- /dev/null +++ b/toolkits/google_drive/arcade_google_drive/templates.py @@ -0,0 +1,5 @@ +optional_file_picker_instructions_template = ( + "Ensure the user knows that they have the option to select and grant access permissions to " + "additional files and folders via the Google Drive File Picker. " + "The user can pick additional files and folders via the following link: {url}" +) diff --git a/toolkits/google_drive/arcade_google_drive/tools/__init__.py b/toolkits/google_drive/arcade_google_drive/tools/__init__.py new file mode 100644 index 00000000..bc30b32e --- /dev/null +++ b/toolkits/google_drive/arcade_google_drive/tools/__init__.py @@ -0,0 +1,6 @@ +from arcade_google_drive.tools.drive import generate_google_file_picker_url, get_file_tree_structure + +__all__ = [ + "generate_google_file_picker_url", + "get_file_tree_structure", +] diff --git a/toolkits/google_drive/arcade_google_drive/tools/drive.py b/toolkits/google_drive/arcade_google_drive/tools/drive.py new file mode 100644 index 00000000..03a3fd80 --- /dev/null +++ b/toolkits/google_drive/arcade_google_drive/tools/drive.py @@ -0,0 +1,167 @@ +import base64 +import json +from typing import Annotated + +from arcade_tdk import ToolContext, ToolMetadataKey, tool +from arcade_tdk.auth import Google +from arcade_tdk.errors import ToolExecutionError +from googleapiclient.errors import HttpError + +from arcade_google_drive.enums import OrderBy +from arcade_google_drive.templates import optional_file_picker_instructions_template +from arcade_google_drive.utils import ( + build_drive_service, + build_file_tree, + build_file_tree_request_params, +) + + +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/drive.file"], + ), + requires_metadata=[ToolMetadataKey.CLIENT_ID, ToolMetadataKey.COORDINATOR_URL], +) +async def get_file_tree_structure( + context: ToolContext, + include_shared_drives: Annotated[ + bool, "Whether to include shared drives in the file tree structure. Defaults to False." + ] = False, + restrict_to_shared_drive_id: Annotated[ + str | None, + "If provided, only include files from this shared drive in the file tree structure. " + "Defaults to None, which will include files and folders from all drives.", + ] = None, + include_organization_domain_documents: Annotated[ + bool, + "Whether to include documents from the organization's domain. This is applicable to admin " + "users who have permissions to view organization-wide documents in a Google Workspace " + "account. Defaults to False.", + ] = False, + order_by: Annotated[ + list[OrderBy] | None, + "Sort order. Defaults to listing the most recently modified documents first", + ] = None, + limit: Annotated[ + int | None, + "The number of files and folders to list. Defaults to None, " + "which will list all files and folders.", + ] = None, +) -> Annotated[ + dict, + "A dictionary containing the file/folder tree structure in the user's Google Drive", +]: + """ + Get the file/folder tree structure of the user's Google Drive. + """ + service = build_drive_service(context.get_auth_token_or_empty()) + + keep_paginating = True + page_token = None + files = {} + file_tree: dict[str, list[dict]] = {"My Drive": []} + + params = build_file_tree_request_params( + order_by, + page_token, + limit, + include_shared_drives, + restrict_to_shared_drive_id, + include_organization_domain_documents, + ) + + while keep_paginating: + # Get a list of files + results = service.files().list(**params).execute() + + # Update page token + page_token = results.get("nextPageToken") + params["pageToken"] = page_token + keep_paginating = page_token is not None + + for file in results.get("files", []): + files[file["id"]] = file + + if not files: + return {"drives": []} + + file_tree = build_file_tree(files) + + drives = [] + + for drive_id, files in file_tree.items(): # type: ignore[assignment] + if drive_id == "My Drive": + drive = {"name": "My Drive", "children": files} + else: + try: + drive_details = service.drives().get(driveId=drive_id).execute() + drive_name = drive_details.get("name", "Shared Drive (name unavailable)") + except HttpError as e: + drive_name = ( + f"Shared Drive (name unavailable: 'HttpError {e.status_code}: {e.reason}')" + ) + + drive = {"name": drive_name, "id": drive_id, "children": files} + + drives.append(drive) + + file_picker_response = generate_google_file_picker_url( + context, + ) + + return { + "drives": drives, + "file_picker": { + "url": file_picker_response["url"], + "llm_instructions": optional_file_picker_instructions_template.format( + url=file_picker_response["url"] + ), + }, + } + + +@tool( + requires_auth=Google(), + requires_metadata=[ToolMetadataKey.CLIENT_ID, ToolMetadataKey.COORDINATOR_URL], +) +def generate_google_file_picker_url( + context: ToolContext, +) -> Annotated[dict, "Google File Picker URL for user file selection and permission granting"]: + """Generate a Google File Picker URL for user-driven file selection and authorization. + + This tool generates a URL that directs the end-user to a Google File Picker interface where + where they can select or upload Google Drive files. Users can grant permission to access their + Drive files, providing a secure and authorized way to interact with their files. + + This is particularly useful when prior tools (e.g., those accessing or modifying + Google Docs, Google Sheets, etc.) encountered failures due to file non-existence + (Requested entity was not found) or permission errors. Once the user completes the file + picker flow, the prior tool can be retried. + """ + client_id = context.get_metadata(ToolMetadataKey.CLIENT_ID) + client_id_parts = client_id.split("-") + if not client_id_parts: + raise ToolExecutionError( + message="Invalid Google Client ID", + developer_message=f"Google Client ID '{client_id}' is not valid", + ) + app_id = client_id_parts[0] + cloud_coordinator_url = context.get_metadata(ToolMetadataKey.COORDINATOR_URL).strip("/") + + config = { + "auth": { + "client_id": client_id, + "app_id": app_id, + }, + } + config_json = json.dumps(config) + config_base64 = base64.urlsafe_b64encode(config_json.encode("utf-8")).decode("utf-8") + url = f"{cloud_coordinator_url}/google/drive_picker?config={config_base64}" + + return { + "url": url, + "llm_instructions": ( + "Instruct the user to click the following link to open the Google Drive File Picker. " + "This will allow them to select files and grant access permissions: {url}" + ), + } diff --git a/toolkits/google_drive/arcade_google_drive/utils.py b/toolkits/google_drive/arcade_google_drive/utils.py new file mode 100644 index 00000000..2f3d3f7d --- /dev/null +++ b/toolkits/google_drive/arcade_google_drive/utils.py @@ -0,0 +1,114 @@ +import logging +from typing import Any + +from google.oauth2.credentials import Credentials +from googleapiclient.discovery import Resource, build + +from arcade_google_drive.enums import Corpora, OrderBy + +## Set up basic configuration for logging to the console with DEBUG level and a specific format. +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) + +logger = logging.getLogger(__name__) + + +def build_drive_service(auth_token: str | None) -> Resource: # type: ignore[no-any-unimported] + """ + Build a Drive service object. + """ + auth_token = auth_token or "" + return build("drive", "v3", credentials=Credentials(auth_token)) + + +def build_file_tree_request_params( + order_by: list[OrderBy] | None, + page_token: str | None, + limit: int | None, + include_shared_drives: bool, + restrict_to_shared_drive_id: str | None, + include_organization_domain_documents: bool, +) -> dict[str, Any]: + if order_by is None: + order_by = [OrderBy.MODIFIED_TIME_DESC] + elif isinstance(order_by, OrderBy): + order_by = [order_by] + + params = { + "q": "trashed = false", + "corpora": Corpora.USER.value, + "pageToken": page_token, + "fields": ( + "files(id, name, parents, mimeType, driveId, size, createdTime, modifiedTime, owners)" + ), + "orderBy": ",".join([item.value for item in order_by]), + } + + if limit: + params["pageSize"] = str(limit) + + if ( + include_shared_drives + or restrict_to_shared_drive_id + or include_organization_domain_documents + ): + params["includeItemsFromAllDrives"] = "true" + params["supportsAllDrives"] = "true" + + if restrict_to_shared_drive_id: + params["driveId"] = restrict_to_shared_drive_id + params["corpora"] = Corpora.DRIVE.value + + if include_organization_domain_documents: + params["corpora"] = Corpora.DOMAIN.value + + return params + + +def build_file_tree(files: dict[str, Any]) -> dict[str, Any]: + file_tree: dict[str, Any] = {} + + for file in files.values(): + owners = file.get("owners", []) + if owners: + owners = [ + {"name": owner.get("displayName", ""), "email": owner.get("emailAddress", "")} + for owner in owners + ] + file["owners"] = owners + + if "size" in file: + file["size"] = {"value": int(file["size"]), "unit": "bytes"} + + # Although "parents" is a list, a file can only have one parent + try: + parent_id = file["parents"][0] + del file["parents"] + except (KeyError, IndexError): + parent_id = None + + # Determine the file's Drive ID + if "driveId" in file: + drive_id = file["driveId"] + del file["driveId"] + # If a shared drive id is not present, the file is in "My Drive" + else: + drive_id = "My Drive" + + if drive_id not in file_tree: + file_tree[drive_id] = [] + + # Root files will have the Drive's id as the parent. If the parent id is not in the files + # list, the file must be at drive's root + if parent_id not in files: + file_tree[drive_id].append(file) + + # Associate the file with its parent + else: + if "children" not in files[parent_id]: + files[parent_id]["children"] = [] + files[parent_id]["children"].append(file) + + return file_tree diff --git a/toolkits/google_drive/conftest.py b/toolkits/google_drive/conftest.py new file mode 100644 index 00000000..c579763d --- /dev/null +++ b/toolkits/google_drive/conftest.py @@ -0,0 +1,197 @@ +import pytest + + +@pytest.fixture +def sample_drive_file_tree_request_responses() -> tuple[dict, list]: + files_list = { + "files": [ + # Shared Drive 1 files and folders + { + "id": "19WVyQndQsc0AxxfdrIt5CvDQd6r-BvpqnB8bWZoL7Xk", + "name": "shared-1-folder-1-doc-1", + "mimeType": "application/vnd.google-apps.document", + "parents": ["1dCOCdPxhTqiB3j3bWrIWM692ZbL8dyjt"], + "createdTime": "2025-02-26T00:28:20.571Z", + "modifiedTime": "2025-02-26T00:28:30.773Z", + "driveId": "0AFqcR6obkydtUk9PVA", + "size": "1024", + }, + { + "id": "1dCOCdPxhTqiB3j3bWrIWM692ZbL8dyjt", + "name": "shared-1-folder-1", + "mimeType": "application/vnd.google-apps.folder", + "parents": ["0AFqcR6obkydtUk9PVA"], + "createdTime": "2025-02-26T00:27:45.526Z", + "modifiedTime": "2025-02-26T00:27:45.526Z", + "driveId": "0AFqcR6obkydtUk9PVA", + }, + { + "id": "1didt_h-tDjuJ-dmYtHUSyOCPci30K_kSszvg0G3tKBM", + "name": "shared-1-doc-1", + "mimeType": "application/vnd.google-apps.document", + "parents": ["0AFqcR6obkydtUk9PVA"], + "createdTime": "2025-02-26T00:27:19.287Z", + "modifiedTime": "2025-02-26T00:27:26.079Z", + "driveId": "0AFqcR6obkydtUk9PVA", + "size": "1024", + }, + # My Drive files and folders + { + "id": "1vB6sv0MD0hYSraYvWU_fcci3GN_-Jf4g-LfyXdG8ZMo", + "name": "The Birth of MX Engineering", + "mimeType": "application/vnd.google-apps.document", + "parents": ["0AIbBwO2hjeHqUk9PVA"], + "createdTime": "2025-01-24T06:34:22.305Z", + "modifiedTime": "2025-02-25T21:54:30.632Z", + "owners": [ + { + "kind": "drive#user", + "displayName": "one_new_tool_everyday", + "photoLink": "https://lh3.googleusercontent.com/a-/photo.png", + "me": True, + "permissionId": "00356981722324419750", + "emailAddress": "one_new_tool_everyday@arcade.dev", + } + ], + "size": "6634", + }, + { + "id": "1wv2dmYo0skJTI59ZIcwH9vm-wt7psMwXTvihuEGeHeI", + "name": "test document 1.1.1", + "mimeType": "application/vnd.google-apps.document", + "parents": ["1J92V9yvVWm_uNHq3CCY4wyG1H9B6iiwO"], + "createdTime": "2025-02-25T17:59:03.325Z", + "modifiedTime": "2025-02-25T17:59:11.445Z", + "owners": [ + { + "kind": "drive#user", + "displayName": "one_new_tool_everyday", + "photoLink": "https://lh3.googleusercontent.com/a-/photo.png", + "me": True, + "permissionId": "00356981722324419750", + "emailAddress": "one_new_tool_everyday@arcade.dev", + } + ], + "size": "1024", + }, + { + "id": "1J92V9yvVWm_uNHq3CCY4wyG1H9B6iiwO", + "name": "test folder 1.1", + "mimeType": "application/vnd.google-apps.folder", + "parents": ["1gqioaHG53jPVeJN5gBpHoO-GWtwiJcLo"], + "createdTime": "2025-02-25T17:58:58.987Z", + "modifiedTime": "2025-02-25T17:58:58.987Z", + "owners": [ + { + "kind": "drive#user", + "displayName": "one_new_tool_everyday", + "photoLink": "https://lh3.googleusercontent.com/a-/photo.png", + "me": True, + "permissionId": "00356981722324419750", + "emailAddress": "one_new_tool_everyday@arcade.dev", + } + ], + }, + { + "id": "1DSmL7d07kjT6b6L-t4JIT06ElUbZ1q0K6_gEpn_UGZ8", + "name": "test document 1.2", + "mimeType": "application/vnd.google-apps.document", + "parents": ["1gqioaHG53jPVeJN5gBpHoO-GWtwiJcLo"], + "createdTime": "2025-02-25T17:58:38.628Z", + "modifiedTime": "2025-02-25T17:58:46.713Z", + "owners": [ + { + "kind": "drive#user", + "displayName": "one_new_tool_everyday", + "photoLink": "https://lh3.googleusercontent.com/a-/photo.png", + "me": True, + "permissionId": "00356981722324419750", + "emailAddress": "one_new_tool_everyday@arcade.dev", + } + ], + "size": "1024", + }, + { + "id": "1Fcxz7HsyO2Zyc-5DTD3zBQnaVrZwD29BP9KD9rPnYfE", + "name": "test document 1.1", + "mimeType": "application/vnd.google-apps.document", + "parents": ["1gqioaHG53jPVeJN5gBpHoO-GWtwiJcLo"], + "createdTime": "2025-02-25T17:57:53.850Z", + "modifiedTime": "2025-02-25T17:58:28.745Z", + "owners": [ + { + "kind": "drive#user", + "displayName": "one_new_tool_everyday", + "photoLink": "https://lh3.googleusercontent.com/a-/photo.png", + "me": True, + "permissionId": "00356981722324419750", + "emailAddress": "one_new_tool_everyday@arcade.dev", + } + ], + "size": "1024", + }, + { + "id": "1gqioaHG53jPVeJN5gBpHoO-GWtwiJcLo", + "name": "test folder 1", + "mimeType": "application/vnd.google-apps.folder", + "parents": ["0AIbBwO2hjeHqUk9PVA"], + "createdTime": "2025-02-25T17:57:46.036Z", + "modifiedTime": "2025-02-25T17:57:46.036Z", + "owners": [ + { + "kind": "drive#user", + "displayName": "one_new_tool_everyday", + "photoLink": "https://lh3.googleusercontent.com/a-/photo.png", + "me": True, + "permissionId": "00356981722324419750", + "emailAddress": "one_new_tool_everyday@arcade.dev", + } + ], + }, + { + "id": "16PUe97yGQeOjQgrgd54iCoxzid4SEvu_J33P_ELd5r8", + "name": "Hello world presentation", + "mimeType": "application/vnd.google-apps.presentation", + "createdTime": "2025-02-18T20:48:52.786Z", + "modifiedTime": "2025-02-19T23:31:20.483Z", + "owners": [ + { + "kind": "drive#user", + "displayName": "john.doe", + "photoLink": "https://lh3.googleusercontent.com/a-/photo.png", + "me": False, + "permissionId": "06420661154928749996", + "emailAddress": "john.doe@arcade.dev", + } + ], + "size": "15774558", + }, + { + "id": "1nG7lSvIyK05N9METPczVJa4iGgE7uoo-A6zpqjpUsDY", + "name": "Shared doc 1", + "mimeType": "application/vnd.google-apps.document", + "createdTime": "2025-02-19T18:51:44.622Z", + "modifiedTime": "2025-02-19T19:30:39.773Z", + "owners": [ + { + "kind": "drive#user", + "displayName": "theboss", + "photoLink": "https://lh3.googleusercontent.com/a-/photo.png", + "me": False, + "permissionId": "11571864250637401873", + "emailAddress": "theboss@arcade.dev", + } + ], + "size": "2700", + }, + ], + } + + drives_get = [ + { + "id": "0AFqcR6obkydtUk9PVA", + "name": "Shared Drive 1", + } + ] + + return files_list, drives_get diff --git a/toolkits/google_drive/evals/eval_google_drive.py b/toolkits/google_drive/evals/eval_google_drive.py new file mode 100644 index 00000000..d49fdff5 --- /dev/null +++ b/toolkits/google_drive/evals/eval_google_drive.py @@ -0,0 +1,131 @@ +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_google_drive +from arcade_google_drive.tools import ( + get_file_tree_structure, +) + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.9, + warn_threshold=0.95, +) + +catalog = ToolCatalog() +catalog.add_module(arcade_google_drive) + + +@tool_eval() +def get_file_tree_structure_eval_suite() -> EvalSuite: + """Create an evaluation suite for Google Drive tools.""" + suite = EvalSuite( + name="Google Drive Tools Evaluation", + system_message="You are an AI assistant that can manage Google Drive documents using the provided tools.", + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="get my google drive's file tree structure including shared drives", + user_message="get my google drive's file tree structure including shared drives", + expected_tool_calls=[ + ExpectedToolCall( + func=get_file_tree_structure, + args={ + "restrict_to_shared_drive_id": None, + "include_shared_drives": True, + "include_organization_domain_documents": False, + "order_by": None, + "limit": None, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="include_shared_drives", weight=0.5), + BinaryCritic(critic_field="restrict_to_shared_drive_id", weight=0.5 / 4), + BinaryCritic(critic_field="include_organization_domain_documents", weight=0.5 / 4), + BinaryCritic(critic_field="order_by", weight=0.5 / 4), + BinaryCritic(critic_field="limit", weight=0.5 / 4), + ], + ) + + suite.add_case( + name="get my google drive's file tree structure without shared drives", + user_message="get my google drive's file tree structure without shared drives", + expected_tool_calls=[ + ExpectedToolCall( + func=get_file_tree_structure, + args={ + "restrict_to_shared_drive_id": None, + "include_shared_drives": False, + "include_organization_domain_documents": False, + "order_by": None, + "limit": None, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="include_shared_drives", weight=0.5), + BinaryCritic(critic_field="restrict_to_shared_drive_id", weight=0.5 / 4), + BinaryCritic(critic_field="include_organization_domain_documents", weight=0.5 / 4), + BinaryCritic(critic_field="order_by", weight=0.5 / 4), + BinaryCritic(critic_field="limit", weight=0.5 / 4), + ], + ) + + suite.add_case( + name="what are the files in the folder 'hello world' in my google drive?", + user_message="what are the files in the folder 'hello world' in my google drive?", + expected_tool_calls=[ + ExpectedToolCall( + func=get_file_tree_structure, + args={ + "restrict_to_shared_drive_id": None, + "include_shared_drives": False, + "include_organization_domain_documents": False, + "order_by": None, + "limit": None, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="include_shared_drives", weight=0.5), + BinaryCritic(critic_field="restrict_to_shared_drive_id", weight=0.5 / 4), + BinaryCritic(critic_field="include_organization_domain_documents", weight=0.5 / 4), + BinaryCritic(critic_field="order_by", weight=0.5 / 4), + BinaryCritic(critic_field="limit", weight=0.5 / 4), + ], + ) + + suite.add_case( + name="how many files are there in all my google drives, including shared ones?", + user_message="how many files are there in all my google drives, including shared ones?", + expected_tool_calls=[ + ExpectedToolCall( + func=get_file_tree_structure, + args={ + "restrict_to_shared_drive_id": None, + "include_shared_drives": True, + "include_organization_domain_documents": False, + "order_by": None, + "limit": None, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="include_shared_drives", weight=0.5), + BinaryCritic(critic_field="restrict_to_shared_drive_id", weight=0.5 / 4), + BinaryCritic(critic_field="include_organization_domain_documents", weight=0.5 / 4), + BinaryCritic(critic_field="order_by", weight=0.5 / 4), + BinaryCritic(critic_field="limit", weight=0.5 / 4), + ], + ) + + return suite diff --git a/toolkits/google_drive/evals/eval_tools_understand_filepicker.py b/toolkits/google_drive/evals/eval_tools_understand_filepicker.py new file mode 100644 index 00000000..be8c4536 --- /dev/null +++ b/toolkits/google_drive/evals/eval_tools_understand_filepicker.py @@ -0,0 +1,70 @@ +from arcade_evals import ( + EvalRubric, + EvalSuite, + ExpectedToolCall, + tool_eval, +) +from arcade_tdk import ToolCatalog + +from arcade_google_drive.tools import ( + get_file_tree_structure, +) + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.9, + warn_threshold=0.95, +) + +catalog = ToolCatalog() +catalog.add_tool(get_file_tree_structure, "GoogleDrive") + +get_file_tree_structure_history = [ + {"role": "system", "content": "Today is 2025-07-03, Thursday."}, + {"role": "user", "content": "get my file tree structure"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_EnqRPmIx3zrquDA7PuZ5NtK6", + "type": "function", + "function": {"name": "GoogleDrive_GetFileTreeStructure", "arguments": "{}"}, + } + ], + }, + { + "role": "tool", + "content": '{"drives":[]}', + "tool_call_id": "call_EnqRPmIx3zrquDA7PuZ5NtK6", + "name": "GoogleDrive_GetFileTreeStructure", + }, + { + "role": "assistant", + "content": "I could not find any files in your Google Drive. To select and grant access permissions to additional files and folders, use [this Google Drive File Picker link](https://cloud.bosslevel.dev/api/v1/google/drive_picker?config=eyXRoIjogM3NsdnFrMmtqODlldNuLmF0=).", + }, +] + + +@tool_eval() +def tools_understand_filepicker_eval_suite() -> EvalSuite: + """Create an evaluation suite for Google Drive tools.""" + suite = EvalSuite( + name="Google Drive Tools Evaluation", + system_message="You are an AI assistant that can manage Google Drive using the provided tools.", + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="Ensure LLM understands that after using file picker, it should call the tool again", + user_message="ok i followed your suggestion and went to that url. how has it changed?", + expected_tool_calls=[ + ExpectedToolCall( + func=get_file_tree_structure, + args={}, + ) + ], + additional_messages=get_file_tree_structure_history, + ) + return suite diff --git a/toolkits/google_drive/pyproject.toml b/toolkits/google_drive/pyproject.toml new file mode 100644 index 00000000..616c2d77 --- /dev/null +++ b/toolkits/google_drive/pyproject.toml @@ -0,0 +1,62 @@ +[build-system] +requires = [ "hatchling",] +build-backend = "hatchling.build" + +[project] +name = "arcade_google_drive" +version = "2.0.0" +description = "Arcade.dev LLM tools for Google Drive" +requires-python = ">=3.10" +dependencies = [ + "arcade-tdk>=2.0.0,<3.0.0", + "google-api-core>=2.19.1,<3.0.0", + "google-api-python-client>=2.137.0,<3.0.0", + "google-auth>=2.32.0,<3.0.0", + "google-auth-httplib2>=0.2.0,<1.0.0", + "googleapis-common-protos>=1.63.2,<2.0.0", +] +[[project.authors]] +name = "Arcade" +email = "dev@arcade.dev" + + +[project.optional-dependencies] +dev = [ + "arcade-ai[evals]>=2.0.4,<3.0.0", + "arcade-serve>=2.0.0,<3.0.0", + "pytest>=8.3.0,<8.4.0", + "pytest-cov>=4.0.0,<4.1.0", + "pytest-mock>=3.11.1,<3.12.0", + "pytest-asyncio>=0.24.0,<0.25.0", + "mypy>=1.5.1,<1.6.0", + "pre-commit>=3.4.0,<3.5.0", + "tox>=4.11.1,<4.12.0", + "ruff>=0.7.4,<0.8.0", +] + +# Use local path sources for arcade libs when working locally +[tool.uv.sources] +arcade-ai = { path = "../../", editable = true } +arcade-serve = { path = "../../libs/arcade-serve/", editable = true } +arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } + +[tool.mypy] +files = [ "arcade_google_drive/**/*.py",] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes = "True" +ignore_missing_imports = "True" + +[tool.pytest.ini_options] +testpaths = [ "tests",] + +[tool.coverage.report] +skip_empty = true + +[tool.hatch.build.targets.wheel] +packages = [ "arcade_google_drive",] diff --git a/toolkits/google_drive/tests/__init__.py b/toolkits/google_drive/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/toolkits/google_drive/tests/test_drive.py b/toolkits/google_drive/tests/test_drive.py new file mode 100644 index 00000000..fae9aed6 --- /dev/null +++ b/toolkits/google_drive/tests/test_drive.py @@ -0,0 +1,238 @@ +from unittest.mock import AsyncMock, patch + +import pytest + +from arcade_google_drive.templates import optional_file_picker_instructions_template +from arcade_google_drive.tools import ( + get_file_tree_structure, +) +from arcade_google_drive.utils import build_drive_service + + +@pytest.fixture +def mock_context(): + context = AsyncMock() + context.authorization.token = "mock_token" # noqa: S105 + context.get_metadata.side_effect = lambda key: { + "client_id": "123456789-abcdefg.apps.googleusercontent.com", + "coordinator_url": "https://coordinator.example.com", + }.get(key.value if hasattr(key, "value") else key) + return context + + +@pytest.fixture +def mock_service(): + with patch( + "arcade_google_drive.tools.drive." + build_drive_service.__name__ + ) as mock_build_service: + yield mock_build_service.return_value + + +@pytest.mark.asyncio +async def test_get_file_tree_structure( + mock_context, mock_service, sample_drive_file_tree_request_responses +): + files_list_sample, drives_get_sample = sample_drive_file_tree_request_responses + + mock_service.files.return_value.list.return_value.execute.side_effect = [files_list_sample] + mock_service.drives.return_value.get.return_value.execute.side_effect = drives_get_sample + + # Mock the generate_google_file_picker_url function + with patch( + "arcade_google_drive.tools.drive.generate_google_file_picker_url" + ) as mock_file_picker: + mock_file_picker.return_value = { + "url": "https://coordinator.example.com/google/drive_picker?config=test_config", + "llm_instructions": optional_file_picker_instructions_template.format( + url="https://coordinator.example.com/google/drive_picker?config=test_config" + ), + } + + result = await get_file_tree_structure(mock_context, include_shared_drives=True) + + expected_file_tree = { + "drives": [ + { + "id": "0AFqcR6obkydtUk9PVA", + "name": "Shared Drive 1", + "children": [ + { + "createdTime": "2025-02-26T00:27:45.526Z", + "id": "1dCOCdPxhTqiB3j3bWrIWM692ZbL8dyjt", + "mimeType": "application/vnd.google-apps.folder", + "modifiedTime": "2025-02-26T00:27:45.526Z", + "name": "shared-1-folder-1", + "children": [ + { + "createdTime": "2025-02-26T00:28:20.571Z", + "id": "19WVyQndQsc0AxxfdrIt5CvDQd6r-BvpqnB8bWZoL7Xk", + "mimeType": "application/vnd.google-apps.document", + "modifiedTime": "2025-02-26T00:28:30.773Z", + "name": "shared-1-folder-1-doc-1", + "size": { + "unit": "bytes", + "value": 1024, + }, + } + ], + }, + { + "createdTime": "2025-02-26T00:27:19.287Z", + "id": "1didt_h-tDjuJ-dmYtHUSyOCPci30K_kSszvg0G3tKBM", + "mimeType": "application/vnd.google-apps.document", + "modifiedTime": "2025-02-26T00:27:26.079Z", + "name": "shared-1-doc-1", + "size": { + "unit": "bytes", + "value": 1024, + }, + }, + ], + }, + { + "name": "My Drive", + "children": [ + { + "createdTime": "2025-01-24T06:34:22.305Z", + "id": "1vB6sv0MD0hYSraYvWU_fcci3GN_-Jf4g-LfyXdG8ZMo", + "mimeType": "application/vnd.google-apps.document", + "modifiedTime": "2025-02-25T21:54:30.632Z", + "name": "The Birth of MX Engineering", + "owners": [ + { + "email": "one_new_tool_everyday@arcade.dev", + "name": "one_new_tool_everyday", + } + ], + "size": { + "unit": "bytes", + "value": 6634, + }, + }, + { + "createdTime": "2025-02-25T17:57:46.036Z", + "id": "1gqioaHG53jPVeJN5gBpHoO-GWtwiJcLo", + "mimeType": "application/vnd.google-apps.folder", + "modifiedTime": "2025-02-25T17:57:46.036Z", + "name": "test folder 1", + "owners": [ + { + "email": "one_new_tool_everyday@arcade.dev", + "name": "one_new_tool_everyday", + } + ], + "children": [ + { + "id": "1J92V9yvVWm_uNHq3CCY4wyG1H9B6iiwO", + "name": "test folder 1.1", + "mimeType": "application/vnd.google-apps.folder", + "createdTime": "2025-02-25T17:58:58.987Z", + "modifiedTime": "2025-02-25T17:58:58.987Z", + "owners": [ + { + "email": "one_new_tool_everyday@arcade.dev", + "name": "one_new_tool_everyday", + } + ], + "children": [ + { + "id": "1wv2dmYo0skJTI59ZIcwH9vm-wt7psMwXTvihuEGeHeI", + "name": "test document 1.1.1", + "mimeType": "application/vnd.google-apps.document", + "createdTime": "2025-02-25T17:59:03.325Z", + "modifiedTime": "2025-02-25T17:59:11.445Z", + "owners": [ + { + "email": "one_new_tool_everyday@arcade.dev", + "name": "one_new_tool_everyday", + } + ], + "size": { + "unit": "bytes", + "value": 1024, + }, + }, + ], + }, + { + "id": "1DSmL7d07kjT6b6L-t4JIT06ElUbZ1q0K6_gEpn_UGZ8", + "name": "test document 1.2", + "mimeType": "application/vnd.google-apps.document", + "createdTime": "2025-02-25T17:58:38.628Z", + "modifiedTime": "2025-02-25T17:58:46.713Z", + "owners": [ + { + "email": "one_new_tool_everyday@arcade.dev", + "name": "one_new_tool_everyday", + } + ], + "size": { + "unit": "bytes", + "value": 1024, + }, + }, + { + "id": "1Fcxz7HsyO2Zyc-5DTD3zBQnaVrZwD29BP9KD9rPnYfE", + "name": "test document 1.1", + "mimeType": "application/vnd.google-apps.document", + "createdTime": "2025-02-25T17:57:53.850Z", + "modifiedTime": "2025-02-25T17:58:28.745Z", + "owners": [ + { + "email": "one_new_tool_everyday@arcade.dev", + "name": "one_new_tool_everyday", + } + ], + "size": { + "unit": "bytes", + "value": 1024, + }, + }, + ], + }, + { + "createdTime": "2025-02-18T20:48:52.786Z", + "id": "16PUe97yGQeOjQgrgd54iCoxzid4SEvu_J33P_ELd5r8", + "mimeType": "application/vnd.google-apps.presentation", + "modifiedTime": "2025-02-19T23:31:20.483Z", + "name": "Hello world presentation", + "owners": [ + { + "email": "john.doe@arcade.dev", + "name": "john.doe", + } + ], + "size": { + "unit": "bytes", + "value": 15774558, + }, + }, + { + "id": "1nG7lSvIyK05N9METPczVJa4iGgE7uoo-A6zpqjpUsDY", + "name": "Shared doc 1", + "mimeType": "application/vnd.google-apps.document", + "createdTime": "2025-02-19T18:51:44.622Z", + "modifiedTime": "2025-02-19T19:30:39.773Z", + "owners": [ + { + "name": "theboss", + "email": "theboss@arcade.dev", + } + ], + "size": { + "unit": "bytes", + "value": 2700, + }, + }, + ], + }, + ], + "file_picker": { + "url": "https://coordinator.example.com/google/drive_picker?config=test_config", + "llm_instructions": optional_file_picker_instructions_template.format( + url="https://coordinator.example.com/google/drive_picker?config=test_config" + ), + }, + } + + assert result == expected_file_tree diff --git a/toolkits/google_finance/.pre-commit-config.yaml b/toolkits/google_finance/.pre-commit-config.yaml new file mode 100644 index 00000000..d83c2356 --- /dev/null +++ b/toolkits/google_finance/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +files: ^.*/google_finance/.* +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.4.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/toolkits/google_finance/.ruff.toml b/toolkits/google_finance/.ruff.toml new file mode 100644 index 00000000..19364180 --- /dev/null +++ b/toolkits/google_finance/.ruff.toml @@ -0,0 +1,47 @@ +target-version = "py310" +line-length = 100 +fix = true + +[lint] +select = [ + # flake8-2020 + "YTT", + # flake8-bandit + "S", + # flake8-bugbear + "B", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-simplify + "SIM", + # isort + "I", + # mccabe + "C90", + # pycodestyle + "E", "W", + # pyflakes + "F", + # pygrep-hooks + "PGH", + # pyupgrade + "UP", + # ruff + "RUF", + # tryceratops + "TRY", +] + +[lint.per-file-ignores] +"*" = ["TRY003", "B904"] +"**/tests/*" = ["S101", "E501"] +"**/evals/*" = ["S101", "E501"] + + +[format] +preview = true +skip-magic-trailing-comma = false diff --git a/toolkits/google_finance/LICENSE b/toolkits/google_finance/LICENSE new file mode 100644 index 00000000..dfbb8b76 --- /dev/null +++ b/toolkits/google_finance/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025, Arcade AI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/toolkits/google_finance/Makefile b/toolkits/google_finance/Makefile new file mode 100644 index 00000000..0a8969be --- /dev/null +++ b/toolkits/google_finance/Makefile @@ -0,0 +1,55 @@ +.PHONY: help + +help: + @echo "🛠️ github Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the uv environment and install all packages with dependencies + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras --no-sources + @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: install-local +install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras + @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + uv build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + @uv run --no-sources coverage report + @echo "Generating coverage report" + @uv run --no-sources coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file by a patch version + @echo "🚀 Bumping version in pyproject.toml" + uv version --no-sources --bump patch + +.PHONY: check +check: ## Run code quality tools. + @if [ -f .pre-commit-config.yaml ]; then\ + echo "🚀 Linting code: Running pre-commit";\ + uv run --no-sources pre-commit run -a;\ + fi + @echo "🚀 Static type checking: Running mypy" + @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/google_finance/arcade_google_finance/__init__.py b/toolkits/google_finance/arcade_google_finance/__init__.py new file mode 100644 index 00000000..69090cec --- /dev/null +++ b/toolkits/google_finance/arcade_google_finance/__init__.py @@ -0,0 +1,3 @@ +from arcade_google_finance.tools import get_stock_historical_data, get_stock_summary + +__all__ = ["get_stock_historical_data", "get_stock_summary"] diff --git a/toolkits/google_finance/arcade_google_finance/enums.py b/toolkits/google_finance/arcade_google_finance/enums.py new file mode 100644 index 00000000..d64df398 --- /dev/null +++ b/toolkits/google_finance/arcade_google_finance/enums.py @@ -0,0 +1,12 @@ +from enum import Enum + + +class GoogleFinanceWindow(Enum): + ONE_DAY = "1D" + FIVE_DAYS = "5D" + ONE_MONTH = "1M" + SIX_MONTHS = "6M" + YEAR_TO_DATE = "YTD" + ONE_YEAR = "1Y" + FIVE_YEARS = "5Y" + MAX = "MAX" diff --git a/toolkits/google_finance/arcade_google_finance/tools/__init__.py b/toolkits/google_finance/arcade_google_finance/tools/__init__.py new file mode 100644 index 00000000..14446019 --- /dev/null +++ b/toolkits/google_finance/arcade_google_finance/tools/__init__.py @@ -0,0 +1,3 @@ +from arcade_google_finance.tools.google_finance import get_stock_historical_data, get_stock_summary + +__all__ = ["get_stock_historical_data", "get_stock_summary"] diff --git a/toolkits/google_finance/arcade_google_finance/tools/google_finance.py b/toolkits/google_finance/arcade_google_finance/tools/google_finance.py new file mode 100644 index 00000000..9a505065 --- /dev/null +++ b/toolkits/google_finance/arcade_google_finance/tools/google_finance.py @@ -0,0 +1,86 @@ +from typing import Annotated, Any + +from arcade_tdk import ToolContext, tool + +from arcade_google_finance.enums import GoogleFinanceWindow +from arcade_google_finance.utils import call_serpapi, prepare_params + + +@tool(requires_secrets=["SERP_API_KEY"]) +async def get_stock_summary( + context: ToolContext, + ticker_symbol: Annotated[ + str, + "The stock ticker to get summary for. For example, 'GOOG' is the ticker symbol for Google", + ], + exchange_identifier: Annotated[ + str, + "The exchange identifier. This part indicates the market where the " + "stock is traded. For example, 'NASDAQ', 'NYSE', 'TSE', 'LSE', etc.", + ], +) -> Annotated[dict[str, Any], "Summary of the stock's recent performance"]: + """Retrieve the summary information for a given stock ticker using the Google Finance API. + + Gets the stock's current price as well as price movement from the most recent trading day. + """ + # Prepare the request + query = ( + f"{ticker_symbol.upper()}:{exchange_identifier.upper()}" + if exchange_identifier + else ticker_symbol.upper() + ) + params = prepare_params("google_finance", q=query) + + # Execute the request + results = call_serpapi(context, params) + + # Parse the results + summary: dict = results.get("summary", {}) + + return summary + + +@tool(requires_secrets=["SERP_API_KEY"]) +async def get_stock_historical_data( + context: ToolContext, + ticker_symbol: Annotated[ + str, + "The stock ticker to get summary for. For example, 'GOOG' is the ticker symbol for Google", + ], + exchange_identifier: Annotated[ + str, + "The exchange identifier. This part indicates the market where the " + "stock is traded. For example, 'NASDAQ', 'NYSE', 'TSE', 'LSE', etc.", + ], + window: Annotated[ + GoogleFinanceWindow, "Time window for the graph data. Defaults to 1 month" + ] = GoogleFinanceWindow.ONE_MONTH, +) -> Annotated[ + dict[str, Any], + "A stock's price and volume data at a specific time interval over a specified time window", +]: + """Fetch historical stock price data over a specified time window + + Returns a stock's price and volume data over a specified time window + """ + # Prepare the request + query = ( + f"{ticker_symbol.upper()}:{exchange_identifier.upper()}" + if exchange_identifier + else ticker_symbol.upper() + ) + params = prepare_params("google_finance", q=query, window=window.value) + + # Execute the request + results = call_serpapi(context, params) + + # Parse the results + data = { + "summary": results.get("summary", {}), + "graph": results.get("graph", []), + } + key_events = results.get("key_events") + if key_events: + data["key_events"] = key_events + + return data diff --git a/toolkits/google_finance/arcade_google_finance/utils.py b/toolkits/google_finance/arcade_google_finance/utils.py new file mode 100644 index 00000000..00c0dcba --- /dev/null +++ b/toolkits/google_finance/arcade_google_finance/utils.py @@ -0,0 +1,48 @@ +import re +from typing import Any, cast + +from arcade_tdk import ToolContext +from arcade_tdk.errors import ToolExecutionError +from serpapi import Client as SerpClient + + +def prepare_params(engine: str, **kwargs: Any) -> dict[str, Any]: + """ + Prepares a parameters dictionary for the SerpAPI call. + + Parameters: + engine: The engine name (e.g., "google", "google_finance"). + kwargs: Any additional parameters to include. + + Returns: + A dictionary containing the base parameters plus any extras, + excluding any parameters whose value is None. + """ + params = {"engine": engine} + params.update({k: v for k, v in kwargs.items() if v is not None}) + return params + + +def call_serpapi(context: ToolContext, params: dict) -> dict: + """ + Execute a search query using the SerpAPI client and return the results as a dictionary. + + Args: + context: The tool context containing required secrets. + params: A dictionary of parameters for the SerpAPI search. + + Returns: + The search results as a dictionary. + """ + api_key = context.get_secret("SERP_API_KEY") + client = SerpClient(api_key=api_key) + try: + search = client.search(params) + return cast(dict[str, Any], search.as_dict()) + except Exception as e: + # SerpAPI error messages sometimes contain the API key, so we need to sanitize it + sanitized_e = re.sub(r"(api_key=)[^ &]+", r"\1***", str(e)) + raise ToolExecutionError( + message="Failed to fetch search results", + developer_message=sanitized_e, + ) diff --git a/toolkits/google_finance/pyproject.toml b/toolkits/google_finance/pyproject.toml new file mode 100644 index 00000000..0f0f97c1 --- /dev/null +++ b/toolkits/google_finance/pyproject.toml @@ -0,0 +1,56 @@ +[build-system] +requires = [ "hatchling",] +build-backend = "hatchling.build" + +[project] +name = "arcade_google_finance" +version = "2.0.0" +description = "Arcade.dev LLM tools for getting financial data via Google Finance" +requires-python = ">=3.10" +dependencies = [ "arcade-tdk>=2.0.0,<3.0.0", "serpapi>=0.1.5,<1.0.0",] +[[project.authors]] +name = "Arcade" +email = "dev@arcade.dev" + + +[project.optional-dependencies] +dev = [ + "arcade-ai[evals]>=2.0.0,<3.0.0", + "arcade-serve>=2.0.0,<3.0.0", + "pytest>=8.3.0,<9.0.0", + "pytest-cov>=4.0.0,<5.0.0", + "pytest-mock>=3.11.1,<4.0.0", + "pytest-asyncio>=0.24.0,<1.0.0", + "mypy>=1.5.1,<2.0.0", + "pre-commit>=3.4.0,<4.0.0", + "tox>=4.11.1,<5.0.0", + "ruff>=0.7.4,<1.0.0", +] + +# Use local path sources for arcade libs when working locally +[tool.uv.sources] +arcade-ai = { path = "../../", editable = true } +arcade-serve = { path = "../../libs/arcade-serve/", editable = true } +arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } + + +[tool.mypy] +files = [ "arcade_google_finance/**/*.py",] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes = "True" +ignore_missing_imports = "True" + +[tool.pytest.ini_options] +testpaths = [ "tests",] + +[tool.coverage.report] +skip_empty = true + +[tool.hatch.build.targets.wheel] +packages = [ "arcade_google_finance",] diff --git a/toolkits/google_flights/.pre-commit-config.yaml b/toolkits/google_flights/.pre-commit-config.yaml new file mode 100644 index 00000000..0e99e3d4 --- /dev/null +++ b/toolkits/google_flights/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +files: ^.*/google_flights/.* +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.4.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/toolkits/google_flights/.ruff.toml b/toolkits/google_flights/.ruff.toml new file mode 100644 index 00000000..19364180 --- /dev/null +++ b/toolkits/google_flights/.ruff.toml @@ -0,0 +1,47 @@ +target-version = "py310" +line-length = 100 +fix = true + +[lint] +select = [ + # flake8-2020 + "YTT", + # flake8-bandit + "S", + # flake8-bugbear + "B", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-simplify + "SIM", + # isort + "I", + # mccabe + "C90", + # pycodestyle + "E", "W", + # pyflakes + "F", + # pygrep-hooks + "PGH", + # pyupgrade + "UP", + # ruff + "RUF", + # tryceratops + "TRY", +] + +[lint.per-file-ignores] +"*" = ["TRY003", "B904"] +"**/tests/*" = ["S101", "E501"] +"**/evals/*" = ["S101", "E501"] + + +[format] +preview = true +skip-magic-trailing-comma = false diff --git a/toolkits/google_flights/LICENSE b/toolkits/google_flights/LICENSE new file mode 100644 index 00000000..dfbb8b76 --- /dev/null +++ b/toolkits/google_flights/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025, Arcade AI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/toolkits/google_flights/Makefile b/toolkits/google_flights/Makefile new file mode 100644 index 00000000..0a8969be --- /dev/null +++ b/toolkits/google_flights/Makefile @@ -0,0 +1,55 @@ +.PHONY: help + +help: + @echo "🛠️ github Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the uv environment and install all packages with dependencies + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras --no-sources + @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: install-local +install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras + @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + uv build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + @uv run --no-sources coverage report + @echo "Generating coverage report" + @uv run --no-sources coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file by a patch version + @echo "🚀 Bumping version in pyproject.toml" + uv version --no-sources --bump patch + +.PHONY: check +check: ## Run code quality tools. + @if [ -f .pre-commit-config.yaml ]; then\ + echo "🚀 Linting code: Running pre-commit";\ + uv run --no-sources pre-commit run -a;\ + fi + @echo "🚀 Static type checking: Running mypy" + @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/google_flights/arcade_google_flights/__init__.py b/toolkits/google_flights/arcade_google_flights/__init__.py new file mode 100644 index 00000000..e4a99dac --- /dev/null +++ b/toolkits/google_flights/arcade_google_flights/__init__.py @@ -0,0 +1,3 @@ +from arcade_google_flights.tools import search_one_way_flights + +__all__ = ["search_one_way_flights"] diff --git a/toolkits/google_flights/arcade_google_flights/enums.py b/toolkits/google_flights/arcade_google_flights/enums.py new file mode 100644 index 00000000..d7a514b4 --- /dev/null +++ b/toolkits/google_flights/arcade_google_flights/enums.py @@ -0,0 +1,53 @@ +from enum import Enum + + +class GoogleFlightsTravelClass(Enum): + ECONOMY = "ECONOMY" + PREMIUM_ECONOMY = "PREMIUM_ECONOMY" + BUSINESS = "BUSINESS" + FIRST = "FIRST" + + def to_api_value(self) -> int: + _map = { + "ECONOMY": 1, + "PREMIUM_ECONOMY": 2, + "BUSINESS": 3, + "FIRST": 4, + } + return _map[self.value] + + +class GoogleFlightsMaxStops(Enum): + ANY = "ANY" + NONSTOP = "NONSTOP" + ONE = "ONE" + TWO = "TWO" + + def to_api_value(self) -> int: + _map = { + "ANY": 0, + "NONSTOP": 1, + "ONE": 2, + "TWO": 3, + } + return _map[self.value] + + +class GoogleFlightsSortBy(Enum): + TOP_FLIGHTS = "TOP_FLIGHTS" + PRICE = "PRICE" + DEPARTURE_TIME = "DEPARTURE_TIME" + ARRIVAL_TIME = "ARRIVAL_TIME" + DURATION = "DURATION" + EMISSIONS = "EMISSIONS" + + def to_api_value(self) -> int: + _map = { + "TOP_FLIGHTS": 1, + "PRICE": 2, + "DEPARTURE_TIME": 3, + "ARRIVAL_TIME": 4, + "DURATION": 5, + "EMISSIONS": 6, + } + return _map[self.value] diff --git a/toolkits/google_flights/arcade_google_flights/tools/__init__.py b/toolkits/google_flights/arcade_google_flights/tools/__init__.py new file mode 100644 index 00000000..de99e742 --- /dev/null +++ b/toolkits/google_flights/arcade_google_flights/tools/__init__.py @@ -0,0 +1,5 @@ +from arcade_google_flights.tools.google_flights import ( + search_one_way_flights, +) + +__all__ = ["search_one_way_flights"] diff --git a/toolkits/google_flights/arcade_google_flights/tools/google_flights.py b/toolkits/google_flights/arcade_google_flights/tools/google_flights.py new file mode 100644 index 00000000..1609a004 --- /dev/null +++ b/toolkits/google_flights/arcade_google_flights/tools/google_flights.py @@ -0,0 +1,61 @@ +from typing import Annotated, Any + +from arcade_tdk import ToolContext, tool + +from arcade_google_flights.enums import ( + GoogleFlightsMaxStops, + GoogleFlightsSortBy, + GoogleFlightsTravelClass, +) +from arcade_google_flights.utils import call_serpapi, parse_flight_results, prepare_params + + +@tool(requires_secrets=["SERP_API_KEY"]) +async def search_one_way_flights( + context: ToolContext, + departure_airport_code: Annotated[ + str, "The departure airport code. An uppercase 3-letter code" + ], + arrival_airport_code: Annotated[str, "The arrival airport code. An uppercase 3-letter code"], + outbound_date: Annotated[str, "Flight departure date in YYYY-MM-DD format"], + currency_code: Annotated[ + str | None, "Currency of the returned prices. Defaults to 'USD'" + ] = "USD", + travel_class: Annotated[ + GoogleFlightsTravelClass, + "Travel class of the flight. Defaults to 'ECONOMY'", + ] = GoogleFlightsTravelClass.ECONOMY, + num_adults: Annotated[int | None, "Number of adult passengers. Defaults to 1"] = 1, + num_children: Annotated[int | None, "Number of child passengers. Defaults to 0"] = 0, + max_stops: Annotated[ + GoogleFlightsMaxStops, + "Maximum number of stops (layovers) for the flight. Defaults to any number of stops", + ] = GoogleFlightsMaxStops.ANY, + sort_by: Annotated[ + GoogleFlightsSortBy, + "The sorting order of the results. Defaults to TOP_FLIGHTS.", + ] = GoogleFlightsSortBy.TOP_FLIGHTS, +) -> Annotated[dict[str, Any], "Flight search results from the Google Flights API"]: + """Retrieve flight search results for a one-way flight using Google Flights""" + params = prepare_params( + "google_flights", + departure_id=departure_airport_code, + arrival_id=arrival_airport_code, + outbound_date=outbound_date, + currency=currency_code, + travel_class=travel_class.to_api_value(), + adults=num_adults, + children=num_children, + stops=max_stops.to_api_value(), + sort_by=sort_by.to_api_value(), + type=2, # indicates one-way + deep_search=True, # Same search depth as the Google Flights page in the browser + ) + + # Execute the request + results = call_serpapi(context, params) + + # Parse the results + flights = parse_flight_results(results) + + return flights diff --git a/toolkits/google_flights/arcade_google_flights/utils.py b/toolkits/google_flights/arcade_google_flights/utils.py new file mode 100644 index 00000000..f314248b --- /dev/null +++ b/toolkits/google_flights/arcade_google_flights/utils.py @@ -0,0 +1,68 @@ +import re +from typing import Any, cast + +from arcade_tdk import ToolContext +from arcade_tdk.errors import ToolExecutionError +from serpapi import Client as SerpClient + + +def prepare_params(engine: str, **kwargs: Any) -> dict[str, Any]: + """ + Prepares a parameters dictionary for the SerpAPI call. + + Parameters: + engine: The engine name (e.g., "google", "google_finance"). + kwargs: Any additional parameters to include. + + Returns: + A dictionary containing the base parameters plus any extras, + excluding any parameters whose value is None. + """ + params = {"engine": engine} + params.update({k: v for k, v in kwargs.items() if v is not None}) + return params + + +def call_serpapi(context: ToolContext, params: dict) -> dict: + """ + Execute a search query using the SerpAPI client and return the results as a dictionary. + + Args: + context: The tool context containing required secrets. + params: A dictionary of parameters for the SerpAPI search. + + Returns: + The search results as a dictionary. + """ + api_key = context.get_secret("SERP_API_KEY") + client = SerpClient(api_key=api_key) + try: + search = client.search(params) + return cast(dict[str, Any], search.as_dict()) + except Exception as e: + # SerpAPI error messages sometimes contain the API key, so we need to sanitize it + sanitized_e = re.sub(r"(api_key=)[^ &]+", r"\1***", str(e)) + raise ToolExecutionError( + message="Failed to fetch search results", + developer_message=sanitized_e, + ) + + +def parse_flight_results(results: dict[str, Any]) -> dict[str, Any]: + """Parse the flight results from the Google Flights API + + Note: Best flights is not always returned from the API. + """ + flight_data = {} + flights = [] + + if "best_flights" in results: + flights.extend(results["best_flights"]) + if "other_flights" in results: + flights.extend(results["other_flights"]) + if "price_insights" in results: + flight_data["price_insights"] = results["price_insights"] + + flight_data["flights"] = flights + + return flight_data diff --git a/toolkits/google_flights/pyproject.toml b/toolkits/google_flights/pyproject.toml new file mode 100644 index 00000000..c99a7199 --- /dev/null +++ b/toolkits/google_flights/pyproject.toml @@ -0,0 +1,56 @@ +[build-system] +requires = [ "hatchling",] +build-backend = "hatchling.build" + +[project] +name = "arcade_google_flights" +version = "2.0.0" +description = "Arcade.dev LLM tools for getting flights via Google Flights" +requires-python = ">=3.10" +dependencies = [ "arcade-tdk>=2.0.0,<3.0.0", "serpapi>=0.1.5,<1.0.0",] +[[project.authors]] +name = "Arcade" +email = "dev@arcade.dev" + + +[project.optional-dependencies] +dev = [ + "arcade-ai[evals]>=2.0.0,<3.0.0", + "arcade-serve>=2.0.0,<3.0.0", + "pytest>=8.3.0,<9.0.0", + "pytest-cov>=4.0.0,<5.0.0", + "pytest-mock>=3.11.1,<4.0.0", + "pytest-asyncio>=0.24.0,<1.0.0", + "mypy>=1.5.1,<2.0.0", + "pre-commit>=3.4.0,<4.0.0", + "tox>=4.11.1,<5.0.0", + "ruff>=0.7.4,<1.0.0", +] + +# Use local path sources for arcade libs when working locally +[tool.uv.sources] +arcade-ai = { path = "../../", editable = true } +arcade-serve = { path = "../../libs/arcade-serve/", editable = true } +arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } + + +[tool.mypy] +files = [ "arcade_google_flights/**/*.py",] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes = "True" +ignore_missing_imports = "True" + +[tool.pytest.ini_options] +testpaths = [ "tests",] + +[tool.coverage.report] +skip_empty = true + +[tool.hatch.build.targets.wheel] +packages = [ "arcade_google_flights",] diff --git a/toolkits/google_hotels/.pre-commit-config.yaml b/toolkits/google_hotels/.pre-commit-config.yaml new file mode 100644 index 00000000..95883156 --- /dev/null +++ b/toolkits/google_hotels/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +files: ^.*/google_hotels/.* +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.4.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/toolkits/google_hotels/.ruff.toml b/toolkits/google_hotels/.ruff.toml new file mode 100644 index 00000000..19364180 --- /dev/null +++ b/toolkits/google_hotels/.ruff.toml @@ -0,0 +1,47 @@ +target-version = "py310" +line-length = 100 +fix = true + +[lint] +select = [ + # flake8-2020 + "YTT", + # flake8-bandit + "S", + # flake8-bugbear + "B", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-simplify + "SIM", + # isort + "I", + # mccabe + "C90", + # pycodestyle + "E", "W", + # pyflakes + "F", + # pygrep-hooks + "PGH", + # pyupgrade + "UP", + # ruff + "RUF", + # tryceratops + "TRY", +] + +[lint.per-file-ignores] +"*" = ["TRY003", "B904"] +"**/tests/*" = ["S101", "E501"] +"**/evals/*" = ["S101", "E501"] + + +[format] +preview = true +skip-magic-trailing-comma = false diff --git a/toolkits/google_hotels/LICENSE b/toolkits/google_hotels/LICENSE new file mode 100644 index 00000000..dfbb8b76 --- /dev/null +++ b/toolkits/google_hotels/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025, Arcade AI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/toolkits/google_hotels/Makefile b/toolkits/google_hotels/Makefile new file mode 100644 index 00000000..0a8969be --- /dev/null +++ b/toolkits/google_hotels/Makefile @@ -0,0 +1,55 @@ +.PHONY: help + +help: + @echo "🛠️ github Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the uv environment and install all packages with dependencies + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras --no-sources + @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: install-local +install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras + @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + uv build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + @uv run --no-sources coverage report + @echo "Generating coverage report" + @uv run --no-sources coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file by a patch version + @echo "🚀 Bumping version in pyproject.toml" + uv version --no-sources --bump patch + +.PHONY: check +check: ## Run code quality tools. + @if [ -f .pre-commit-config.yaml ]; then\ + echo "🚀 Linting code: Running pre-commit";\ + uv run --no-sources pre-commit run -a;\ + fi + @echo "🚀 Static type checking: Running mypy" + @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/google_hotels/arcade_google_hotels/__init__.py b/toolkits/google_hotels/arcade_google_hotels/__init__.py new file mode 100644 index 00000000..fecdfddb --- /dev/null +++ b/toolkits/google_hotels/arcade_google_hotels/__init__.py @@ -0,0 +1,3 @@ +from arcade_google_hotels.tools import search_hotels + +__all__ = ["search_hotels"] diff --git a/toolkits/google_hotels/arcade_google_hotels/enums.py b/toolkits/google_hotels/arcade_google_hotels/enums.py new file mode 100644 index 00000000..971bd04f --- /dev/null +++ b/toolkits/google_hotels/arcade_google_hotels/enums.py @@ -0,0 +1,17 @@ +from enum import Enum + + +class GoogleHotelsSortBy(Enum): + RELEVANCE = "RELEVANCE" + LOWEST_PRICE = "LOWEST_PRICE" + HIGHEST_RATING = "HIGHEST_RATING" + MOST_REVIEWED = "MOST_REVIEWED" + + def to_api_value(self) -> int | None: + _map = { + "RELEVANCE": None, + "LOWEST_PRICE": 3, + "HIGHEST_RATING": 8, + "MOST_REVIEWED": 13, + } + return _map[self.value] diff --git a/toolkits/google_hotels/arcade_google_hotels/tools/__init__.py b/toolkits/google_hotels/arcade_google_hotels/tools/__init__.py new file mode 100644 index 00000000..dc3bcb98 --- /dev/null +++ b/toolkits/google_hotels/arcade_google_hotels/tools/__init__.py @@ -0,0 +1,3 @@ +from arcade_google_hotels.tools.google_hotels import search_hotels + +__all__ = ["search_hotels"] diff --git a/toolkits/google_hotels/arcade_google_hotels/tools/google_hotels.py b/toolkits/google_hotels/arcade_google_hotels/tools/google_hotels.py new file mode 100644 index 00000000..8cef7661 --- /dev/null +++ b/toolkits/google_hotels/arcade_google_hotels/tools/google_hotels.py @@ -0,0 +1,58 @@ +from typing import Annotated, Any + +from arcade_tdk import ToolContext, tool + +from arcade_google_hotels.enums import GoogleHotelsSortBy +from arcade_google_hotels.utils import call_serpapi, prepare_params + + +@tool(requires_secrets=["SERP_API_KEY"]) +async def search_hotels( + context: ToolContext, + location: Annotated[str, "Location to search for hotels, e.g., a city name, a state, etc."], + check_in_date: Annotated[str, "Check-in date in YYYY-MM-DD format"], + check_out_date: Annotated[str, "Check-out date in YYYY-MM-DD format"], + query: Annotated[ + str | None, "Anything that would be used in a regular Google Hotels search" + ] = None, + currency: Annotated[str | None, "Currency code for prices. Defaults to 'USD'"] = "USD", + min_price: Annotated[int | None, "Minimum price per night. Defaults to no minimum"] = None, + max_price: Annotated[int | None, "Maximum price per night. Defaults to no maximum"] = None, + num_adults: Annotated[int | None, "Number of adults per room. Defaults to 2"] = 2, + num_children: Annotated[int | None, "Number of children per room. Defaults to 0"] = 0, + sort_by: Annotated[ + GoogleHotelsSortBy, "The sorting order of the results. Defaults to RELEVANCE" + ] = GoogleHotelsSortBy.RELEVANCE, + num_results: Annotated[ + int | None, "Maximum number of results to return. Defaults to 5. Max 20" + ] = 5, +) -> Annotated[dict[str, Any], "Hotel search results from the Google Hotels API"]: + """Retrieve hotel search results using the Google Hotels API.""" + # Prepare the request + params = prepare_params( + "google_hotels", + q=f"{query}, {location}" if query else location, + check_in_date=check_in_date, + check_out_date=check_out_date, + currency=currency, + min_price=min_price, + max_price=max_price, + adults=num_adults, + children=num_children, + sort_by=sort_by.to_api_value(), + ) + + # Execute the request + results = call_serpapi(context, params) + + # Parse the results + properties = results.get("properties", [])[:num_results] + + # Remove unwanted fields from each property + for hotel in properties: + hotel.pop("images", None) + hotel.pop("extracted_hotel_class", None) + hotel.pop("reviews_breakdown", None) + hotel.pop("serpapi_property_details_link", None) + + return {"properties": properties} diff --git a/toolkits/google_hotels/arcade_google_hotels/utils.py b/toolkits/google_hotels/arcade_google_hotels/utils.py new file mode 100644 index 00000000..00c0dcba --- /dev/null +++ b/toolkits/google_hotels/arcade_google_hotels/utils.py @@ -0,0 +1,48 @@ +import re +from typing import Any, cast + +from arcade_tdk import ToolContext +from arcade_tdk.errors import ToolExecutionError +from serpapi import Client as SerpClient + + +def prepare_params(engine: str, **kwargs: Any) -> dict[str, Any]: + """ + Prepares a parameters dictionary for the SerpAPI call. + + Parameters: + engine: The engine name (e.g., "google", "google_finance"). + kwargs: Any additional parameters to include. + + Returns: + A dictionary containing the base parameters plus any extras, + excluding any parameters whose value is None. + """ + params = {"engine": engine} + params.update({k: v for k, v in kwargs.items() if v is not None}) + return params + + +def call_serpapi(context: ToolContext, params: dict) -> dict: + """ + Execute a search query using the SerpAPI client and return the results as a dictionary. + + Args: + context: The tool context containing required secrets. + params: A dictionary of parameters for the SerpAPI search. + + Returns: + The search results as a dictionary. + """ + api_key = context.get_secret("SERP_API_KEY") + client = SerpClient(api_key=api_key) + try: + search = client.search(params) + return cast(dict[str, Any], search.as_dict()) + except Exception as e: + # SerpAPI error messages sometimes contain the API key, so we need to sanitize it + sanitized_e = re.sub(r"(api_key=)[^ &]+", r"\1***", str(e)) + raise ToolExecutionError( + message="Failed to fetch search results", + developer_message=sanitized_e, + ) diff --git a/toolkits/google_hotels/pyproject.toml b/toolkits/google_hotels/pyproject.toml new file mode 100644 index 00000000..20ebf5e4 --- /dev/null +++ b/toolkits/google_hotels/pyproject.toml @@ -0,0 +1,56 @@ +[build-system] +requires = [ "hatchling",] +build-backend = "hatchling.build" + +[project] +name = "arcade_google_hotels" +version = "2.0.0" +description = "Arcade.dev LLM tools for getting Hotel information via Google Hotels" +requires-python = ">=3.10" +dependencies = [ "arcade-tdk>=2.0.0,<3.0.0", "serpapi>=0.1.5,<1.0.0",] +[[project.authors]] +name = "Arcade" +email = "dev@arcade.dev" + + +[project.optional-dependencies] +dev = [ + "arcade-ai[evals]>=2.0.0,<3.0.0", + "arcade-serve>=2.0.0,<3.0.0", + "pytest>=8.3.0,<9.0.0", + "pytest-cov>=4.0.0,<5.0.0", + "pytest-mock>=3.11.1,<4.0.0", + "pytest-asyncio>=0.24.0,<1.0.0", + "mypy>=1.5.1,<2.0.0", + "pre-commit>=3.4.0,<4.0.0", + "tox>=4.11.1,<5.0.0", + "ruff>=0.7.4,<1.0.0", +] + +# Use local path sources for arcade libs when working locally +[tool.uv.sources] +arcade-ai = { path = "../../", editable = true } +arcade-serve = { path = "../../libs/arcade-serve/", editable = true } +arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } + + +[tool.mypy] +files = [ "arcade_google_hotels/**/*.py",] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes = "True" +ignore_missing_imports = "True" + +[tool.pytest.ini_options] +testpaths = [ "tests",] + +[tool.coverage.report] +skip_empty = true + +[tool.hatch.build.targets.wheel] +packages = [ "arcade_google_hotels",] diff --git a/toolkits/google_jobs/.pre-commit-config.yaml b/toolkits/google_jobs/.pre-commit-config.yaml new file mode 100644 index 00000000..e347e6ee --- /dev/null +++ b/toolkits/google_jobs/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +files: ^.*/google_jobs/.* +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.4.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/toolkits/google_jobs/.ruff.toml b/toolkits/google_jobs/.ruff.toml new file mode 100644 index 00000000..19364180 --- /dev/null +++ b/toolkits/google_jobs/.ruff.toml @@ -0,0 +1,47 @@ +target-version = "py310" +line-length = 100 +fix = true + +[lint] +select = [ + # flake8-2020 + "YTT", + # flake8-bandit + "S", + # flake8-bugbear + "B", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-simplify + "SIM", + # isort + "I", + # mccabe + "C90", + # pycodestyle + "E", "W", + # pyflakes + "F", + # pygrep-hooks + "PGH", + # pyupgrade + "UP", + # ruff + "RUF", + # tryceratops + "TRY", +] + +[lint.per-file-ignores] +"*" = ["TRY003", "B904"] +"**/tests/*" = ["S101", "E501"] +"**/evals/*" = ["S101", "E501"] + + +[format] +preview = true +skip-magic-trailing-comma = false diff --git a/toolkits/google_jobs/LICENSE b/toolkits/google_jobs/LICENSE new file mode 100644 index 00000000..dfbb8b76 --- /dev/null +++ b/toolkits/google_jobs/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025, Arcade AI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/toolkits/google_jobs/Makefile b/toolkits/google_jobs/Makefile new file mode 100644 index 00000000..0a8969be --- /dev/null +++ b/toolkits/google_jobs/Makefile @@ -0,0 +1,55 @@ +.PHONY: help + +help: + @echo "🛠️ github Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the uv environment and install all packages with dependencies + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras --no-sources + @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: install-local +install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras + @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + uv build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + @uv run --no-sources coverage report + @echo "Generating coverage report" + @uv run --no-sources coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file by a patch version + @echo "🚀 Bumping version in pyproject.toml" + uv version --no-sources --bump patch + +.PHONY: check +check: ## Run code quality tools. + @if [ -f .pre-commit-config.yaml ]; then\ + echo "🚀 Linting code: Running pre-commit";\ + uv run --no-sources pre-commit run -a;\ + fi + @echo "🚀 Static type checking: Running mypy" + @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/google_jobs/arcade_google_jobs/__init__.py b/toolkits/google_jobs/arcade_google_jobs/__init__.py new file mode 100644 index 00000000..82988ada --- /dev/null +++ b/toolkits/google_jobs/arcade_google_jobs/__init__.py @@ -0,0 +1,3 @@ +from arcade_google_jobs.tools import search_jobs + +__all__ = ["search_jobs"] diff --git a/toolkits/google_jobs/arcade_google_jobs/constants.py b/toolkits/google_jobs/arcade_google_jobs/constants.py new file mode 100644 index 00000000..6101a22a --- /dev/null +++ b/toolkits/google_jobs/arcade_google_jobs/constants.py @@ -0,0 +1,5 @@ +import os + +DEFAULT_GOOGLE_LANGUAGE = os.getenv("ARCADE_GOOGLE_LANGUAGE", "en") + +DEFAULT_GOOGLE_JOBS_LANGUAGE = os.getenv("ARCADE_GOOGLE_JOBS_LANGUAGE", DEFAULT_GOOGLE_LANGUAGE) diff --git a/toolkits/google_jobs/arcade_google_jobs/enums.py b/toolkits/google_jobs/arcade_google_jobs/enums.py new file mode 100644 index 00000000..e69de29b diff --git a/toolkits/google_jobs/arcade_google_jobs/exceptions.py b/toolkits/google_jobs/arcade_google_jobs/exceptions.py new file mode 100644 index 00000000..1deb0eed --- /dev/null +++ b/toolkits/google_jobs/arcade_google_jobs/exceptions.py @@ -0,0 +1,17 @@ +import json + +from arcade_tdk.errors import RetryableToolError + +from arcade_google_jobs.google_data import LANGUAGE_CODES + + +class GoogleRetryableError(RetryableToolError): + pass + + +class LanguageNotFoundError(GoogleRetryableError): + def __init__(self, language: str | None) -> None: + valid_languages = json.dumps(LANGUAGE_CODES, default=str) + message = f"Language not found: '{language}'." + additional_message = f"Valid languages are: {valid_languages}" + super().__init__(message, additional_prompt_content=additional_message) diff --git a/toolkits/google_jobs/arcade_google_jobs/google_data.py b/toolkits/google_jobs/arcade_google_jobs/google_data.py new file mode 100644 index 00000000..7ed7dff6 --- /dev/null +++ b/toolkits/google_jobs/arcade_google_jobs/google_data.py @@ -0,0 +1,33 @@ +LANGUAGE_CODES = { + "ar": "Arabic", + "bn": "Bengali", + "da": "Danish", + "de": "German", + "el": "Greek", + "en": "English", + "es": "Spanish", + "fi": "Finnish", + "fr": "French", + "hi": "Hindi", + "hu": "Hungarian", + "id": "Indonesian", + "it": "Italian", + "ja": "Japanese", + "ko": "Korean", + "nl": "Dutch", + "ms": "Malay", + "no": "Norwegian", + "pcm": "Nigerian Pidgin", + "pl": "Polish", + "pt": "Portuguese", + "pt-br": "Portuguese (Brazil)", + "pt-pt": "Portuguese (Portugal)", + "ru": "Russian", + "sv": "Swedish", + "tl": "Filipino", + "tr": "Turkish", + "uk": "Ukrainian", + "zh": "Chinese", + "zh-cn": "Chinese (Simplified)", + "zh-tw": "Chinese (Traditional)", +} diff --git a/toolkits/google_jobs/arcade_google_jobs/tools/__init__.py b/toolkits/google_jobs/arcade_google_jobs/tools/__init__.py new file mode 100644 index 00000000..e0acd323 --- /dev/null +++ b/toolkits/google_jobs/arcade_google_jobs/tools/__init__.py @@ -0,0 +1,3 @@ +from arcade_google_jobs.tools.google_jobs import search_jobs + +__all__ = ["search_jobs"] diff --git a/toolkits/google_jobs/arcade_google_jobs/tools/google_jobs.py b/toolkits/google_jobs/arcade_google_jobs/tools/google_jobs.py new file mode 100644 index 00000000..4cde1b15 --- /dev/null +++ b/toolkits/google_jobs/arcade_google_jobs/tools/google_jobs.py @@ -0,0 +1,65 @@ +from typing import Annotated + +from arcade_tdk import ToolContext, tool + +from arcade_google_jobs.constants import DEFAULT_GOOGLE_JOBS_LANGUAGE +from arcade_google_jobs.exceptions import LanguageNotFoundError +from arcade_google_jobs.google_data import LANGUAGE_CODES +from arcade_google_jobs.utils import call_serpapi, prepare_params + + +@tool(requires_secrets=["SERP_API_KEY"]) +async def search_jobs( + context: ToolContext, + query: Annotated[ + str, + "Search query. Provide a job title, company name, and/or any keywords in general " + "representing what kind of jobs the user is looking for. E.g. 'software engineer' " + "or 'data analyst at Apple'.", + ], + location: Annotated[ + str | None, + "Location to search for jobs. E.g. 'United States' or 'New York, NY'. Defaults to None.", + ] = None, + language: Annotated[ + str, + "2-character language code to use in the Google Jobs search. " + f"E.g. 'en' for English. Defaults to '{DEFAULT_GOOGLE_JOBS_LANGUAGE}'.", + ] = DEFAULT_GOOGLE_JOBS_LANGUAGE, + limit: Annotated[ + int, + "Maximum number of results to retrieve. Defaults to 10 (max supported by the API).", + ] = 10, + next_page_token: Annotated[ + str | None, + "Next page token to paginate results. Defaults to None (start from the first page).", + ] = None, +) -> Annotated[dict, "Google Jobs results"]: + """Search Google Jobs using SerpAPI.""" + if language not in LANGUAGE_CODES: + raise LanguageNotFoundError(language) + + params = prepare_params( + engine="google_jobs", + q=query, + hl=language, + ) + + if location: + params["location"] = location + + if next_page_token: + params["next_page_token"] = next_page_token + + results = call_serpapi(context, params) + jobs_results = results.get("jobs_results", []) + + try: + next_page_token = results["serpapi_pagination"]["next_page_token"] + except KeyError: + next_page_token = None + + return { + "jobs": jobs_results[:limit], + "next_page_token": next_page_token, + } diff --git a/toolkits/google_jobs/arcade_google_jobs/utils.py b/toolkits/google_jobs/arcade_google_jobs/utils.py new file mode 100644 index 00000000..00c0dcba --- /dev/null +++ b/toolkits/google_jobs/arcade_google_jobs/utils.py @@ -0,0 +1,48 @@ +import re +from typing import Any, cast + +from arcade_tdk import ToolContext +from arcade_tdk.errors import ToolExecutionError +from serpapi import Client as SerpClient + + +def prepare_params(engine: str, **kwargs: Any) -> dict[str, Any]: + """ + Prepares a parameters dictionary for the SerpAPI call. + + Parameters: + engine: The engine name (e.g., "google", "google_finance"). + kwargs: Any additional parameters to include. + + Returns: + A dictionary containing the base parameters plus any extras, + excluding any parameters whose value is None. + """ + params = {"engine": engine} + params.update({k: v for k, v in kwargs.items() if v is not None}) + return params + + +def call_serpapi(context: ToolContext, params: dict) -> dict: + """ + Execute a search query using the SerpAPI client and return the results as a dictionary. + + Args: + context: The tool context containing required secrets. + params: A dictionary of parameters for the SerpAPI search. + + Returns: + The search results as a dictionary. + """ + api_key = context.get_secret("SERP_API_KEY") + client = SerpClient(api_key=api_key) + try: + search = client.search(params) + return cast(dict[str, Any], search.as_dict()) + except Exception as e: + # SerpAPI error messages sometimes contain the API key, so we need to sanitize it + sanitized_e = re.sub(r"(api_key=)[^ &]+", r"\1***", str(e)) + raise ToolExecutionError( + message="Failed to fetch search results", + developer_message=sanitized_e, + ) diff --git a/toolkits/google_jobs/evals/eval_google_jobs.py b/toolkits/google_jobs/evals/eval_google_jobs.py new file mode 100644 index 00000000..d14a3302 --- /dev/null +++ b/toolkits/google_jobs/evals/eval_google_jobs.py @@ -0,0 +1,157 @@ +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + NoneCritic, + SimilarityCritic, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_google_jobs +from arcade_google_jobs.constants import DEFAULT_GOOGLE_JOBS_LANGUAGE +from arcade_google_jobs.tools import search_jobs + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.8, + warn_threshold=0.9, +) + +catalog = ToolCatalog() +# Register the Google Jobs tool +catalog.add_module(arcade_google_jobs) + + +@tool_eval() +def google_jobs_eval_suite() -> EvalSuite: + """Create an evaluation suite for the Google Jobs tool.""" + suite = EvalSuite( + name="Google Jobs Tool Evaluation", + system_message="You are an AI assistant that can perform job searches using the provided tools.", + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="Search for 'backend engineer' jobs", + user_message="Search for 'backend engineer' jobs", + expected_tool_calls=[ + ExpectedToolCall( + func=search_jobs, + args={ + "query": "backend engineer", + "location": None, + "language": DEFAULT_GOOGLE_JOBS_LANGUAGE, + "limit": 10, + "next_page_token": None, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="query", weight=0.5), + NoneCritic(critic_field="location", weight=0.1), + BinaryCritic(critic_field="language", weight=0.1), + BinaryCritic(critic_field="limit", weight=0.1), + NoneCritic(critic_field="next_page_token", weight=0.1), + ], + ) + + suite.add_case( + name="Search for 'senior backend engineer' jobs that are part-time", + user_message="Search for senior backend engineer jobs that are part-time", + expected_tool_calls=[ + ExpectedToolCall( + func=search_jobs, + args={ + "query": "part-time senior backend engineer", + "location": None, + "language": DEFAULT_GOOGLE_JOBS_LANGUAGE, + "limit": 10, + "next_page_token": None, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="query", weight=0.5), + NoneCritic(critic_field="location", weight=0.1), + BinaryCritic(critic_field="language", weight=0.1), + BinaryCritic(critic_field="limit", weight=0.1), + NoneCritic(critic_field="next_page_token", weight=0.1), + ], + ) + + suite.add_case( + name="Search for 'backend engineer' jobs in San Francisco", + user_message="Search for 'backend engineer' jobs in San Francisco", + expected_tool_calls=[ + ExpectedToolCall( + func=search_jobs, + args={ + "query": "backend engineer", + "location": "San Francisco", + "language": DEFAULT_GOOGLE_JOBS_LANGUAGE, + "limit": 10, + "next_page_token": None, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="query", weight=0.35), + SimilarityCritic(critic_field="location", weight=0.35), + BinaryCritic(critic_field="language", weight=0.1), + BinaryCritic(critic_field="limit", weight=0.1), + NoneCritic(critic_field="next_page_token", weight=0.1), + ], + ) + + suite.add_case( + name="Get the first 3 jobs for 'backend engineer' in San Francisco", + user_message="Get the first 3 jobs for 'backend engineer' in San Francisco", + expected_tool_calls=[ + ExpectedToolCall( + func=search_jobs, + args={ + "query": "backend engineer", + "location": "San Francisco", + "language": DEFAULT_GOOGLE_JOBS_LANGUAGE, + "limit": 3, + "next_page_token": None, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="query", weight=0.25), + SimilarityCritic(critic_field="location", weight=0.25), + BinaryCritic(critic_field="language", weight=0.125), + BinaryCritic(critic_field="limit", weight=0.25), + NoneCritic(critic_field="next_page_token", weight=0.125), + ], + ) + + suite.add_case( + name="Search for 'engenheiro de software' jobs in Brazil, return results in Portuguese", + user_message="Search for 'engenheiro de software' jobs in Brazil, return results in Portuguese", + expected_tool_calls=[ + ExpectedToolCall( + func=search_jobs, + args={ + "query": "engenheiro de software", + "location": "Brazil", + "language": "pt", + "limit": 10, + "next_page_token": None, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="query", weight=0.25), + SimilarityCritic(critic_field="location", weight=0.125), + BinaryCritic(critic_field="language", weight=0.25), + BinaryCritic(critic_field="limit", weight=0.125), + NoneCritic(critic_field="next_page_token", weight=0.125), + ], + ) + + return suite diff --git a/toolkits/google_jobs/pyproject.toml b/toolkits/google_jobs/pyproject.toml new file mode 100644 index 00000000..ca5152ae --- /dev/null +++ b/toolkits/google_jobs/pyproject.toml @@ -0,0 +1,56 @@ +[build-system] +requires = [ "hatchling",] +build-backend = "hatchling.build" + +[project] +name = "arcade_google_jobs" +version = "2.0.0" +description = "Arcade.dev LLM tools for getting job postings via Google Jobs" +requires-python = ">=3.10" +dependencies = [ "arcade-tdk>=2.0.0,<3.0.0", "serpapi>=0.1.5,<1.0.0",] +[[project.authors]] +name = "Arcade" +email = "dev@arcade.dev" + + +[project.optional-dependencies] +dev = [ + "arcade-ai[evals]>=2.0.0,<3.0.0", + "arcade-serve>=2.0.0,<3.0.0", + "pytest>=8.3.0,<9.0.0", + "pytest-cov>=4.0.0,<5.0.0", + "pytest-mock>=3.11.1,<4.0.0", + "pytest-asyncio>=0.24.0,<1.0.0", + "mypy>=1.5.1,<2.0.0", + "pre-commit>=3.4.0,<4.0.0", + "tox>=4.11.1,<5.0.0", + "ruff>=0.7.4,<1.0.0", +] + +# Use local path sources for arcade libs when working locally +[tool.uv.sources] +arcade-ai = { path = "../../", editable = true } +arcade-serve = { path = "../../libs/arcade-serve/", editable = true } +arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } + + +[tool.mypy] +files = [ "arcade_google_jobs/**/*.py",] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes = "True" +ignore_missing_imports = "True" + +[tool.pytest.ini_options] +testpaths = [ "tests",] + +[tool.coverage.report] +skip_empty = true + +[tool.hatch.build.targets.wheel] +packages = [ "arcade_google_jobs",] diff --git a/toolkits/google_jobs/tests/__init__.py b/toolkits/google_jobs/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/toolkits/google_jobs/tests/test_google_jobs.py b/toolkits/google_jobs/tests/test_google_jobs.py new file mode 100644 index 00000000..ad0848dd --- /dev/null +++ b/toolkits/google_jobs/tests/test_google_jobs.py @@ -0,0 +1,90 @@ +from unittest.mock import patch + +import pytest +from arcade_tdk import ToolContext, ToolSecretItem + +from arcade_google_jobs.exceptions import LanguageNotFoundError +from arcade_google_jobs.tools.google_jobs import search_jobs + + +@pytest.fixture +def mock_context(): + return ToolContext(secrets=[ToolSecretItem(key="serp_api_key", value="fake_api_key")]) + + +@pytest.mark.asyncio +@patch("arcade_google_jobs.utils.SerpClient") +async def test_search_jobs_success(mock_serp_client, mock_context): + mock_serp_client_instance = mock_serp_client.return_value + mock_serp_client_instance.search().as_dict.return_value = { + "jobs_results": [ + {"title": "Job 1", "link": "http://example.com/1"}, + {"title": "Job 2", "link": "http://example.com/2"}, + ] + } + + result = await search_jobs(mock_context, "engenheiro de software", "Brazil", "pt", 10, None) + assert result == { + "jobs": [ + {"title": "Job 1", "link": "http://example.com/1"}, + {"title": "Job 2", "link": "http://example.com/2"}, + ], + "next_page_token": None, + } + + +@pytest.mark.asyncio +@patch("arcade_google_jobs.utils.SerpClient") +async def test_search_jobs_success_with_custom_language_and_location( + mock_serp_client, mock_context +): + mock_serp_client_instance = mock_serp_client.return_value + mock_serp_client_instance.search().as_dict.return_value = { + "jobs_results": [ + {"title": "Job 1", "link": "http://example.com/1"}, + {"title": "Job 2", "link": "http://example.com/2"}, + ] + } + + result = await search_jobs( + context=mock_context, + query="engenheiro de software", + location="Brazil", + language="pt", + limit=10, + next_page_token=None, + ) + + mock_serp_client_instance.search.assert_called_with({ + "engine": "google_jobs", + "q": "engenheiro de software", + "hl": "pt", + "location": "Brazil", + }) + + assert result == { + "jobs": [ + {"title": "Job 1", "link": "http://example.com/1"}, + {"title": "Job 2", "link": "http://example.com/2"}, + ], + "next_page_token": None, + } + + +@pytest.mark.asyncio +@patch("arcade_google_jobs.utils.SerpClient") +async def test_search_jobs_language_not_found_error(mock_serp_client, mock_context): + mock_serp_client_instance = mock_serp_client.return_value + mock_serp_client_instance.search().as_dict.return_value = { + "jobs_results": [ + {"title": "Job 1", "link": "http://example.com/1"}, + {"title": "Job 2", "link": "http://example.com/2"}, + ] + } + + with pytest.raises(LanguageNotFoundError): + await search_jobs( + context=mock_context, + query="backend engineer", + language="invalid_language", + ) diff --git a/toolkits/google_maps/.pre-commit-config.yaml b/toolkits/google_maps/.pre-commit-config.yaml new file mode 100644 index 00000000..7005e27d --- /dev/null +++ b/toolkits/google_maps/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +files: ^.*/google_maps/.* +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.4.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/toolkits/google_maps/.ruff.toml b/toolkits/google_maps/.ruff.toml new file mode 100644 index 00000000..19364180 --- /dev/null +++ b/toolkits/google_maps/.ruff.toml @@ -0,0 +1,47 @@ +target-version = "py310" +line-length = 100 +fix = true + +[lint] +select = [ + # flake8-2020 + "YTT", + # flake8-bandit + "S", + # flake8-bugbear + "B", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-simplify + "SIM", + # isort + "I", + # mccabe + "C90", + # pycodestyle + "E", "W", + # pyflakes + "F", + # pygrep-hooks + "PGH", + # pyupgrade + "UP", + # ruff + "RUF", + # tryceratops + "TRY", +] + +[lint.per-file-ignores] +"*" = ["TRY003", "B904"] +"**/tests/*" = ["S101", "E501"] +"**/evals/*" = ["S101", "E501"] + + +[format] +preview = true +skip-magic-trailing-comma = false diff --git a/toolkits/google_maps/LICENSE b/toolkits/google_maps/LICENSE new file mode 100644 index 00000000..dfbb8b76 --- /dev/null +++ b/toolkits/google_maps/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025, Arcade AI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/toolkits/google_maps/Makefile b/toolkits/google_maps/Makefile new file mode 100644 index 00000000..0a8969be --- /dev/null +++ b/toolkits/google_maps/Makefile @@ -0,0 +1,55 @@ +.PHONY: help + +help: + @echo "🛠️ github Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the uv environment and install all packages with dependencies + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras --no-sources + @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: install-local +install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras + @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + uv build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + @uv run --no-sources coverage report + @echo "Generating coverage report" + @uv run --no-sources coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file by a patch version + @echo "🚀 Bumping version in pyproject.toml" + uv version --no-sources --bump patch + +.PHONY: check +check: ## Run code quality tools. + @if [ -f .pre-commit-config.yaml ]; then\ + echo "🚀 Linting code: Running pre-commit";\ + uv run --no-sources pre-commit run -a;\ + fi + @echo "🚀 Static type checking: Running mypy" + @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/google_maps/arcade_google_maps/__init__.py b/toolkits/google_maps/arcade_google_maps/__init__.py new file mode 100644 index 00000000..f116016b --- /dev/null +++ b/toolkits/google_maps/arcade_google_maps/__init__.py @@ -0,0 +1,6 @@ +from arcade_google_maps.tools import ( + get_directions_between_addresses, + get_directions_between_coordinates, +) + +__all__ = ["get_directions_between_addresses", "get_directions_between_coordinates"] diff --git a/toolkits/google_maps/arcade_google_maps/constants.py b/toolkits/google_maps/arcade_google_maps/constants.py new file mode 100644 index 00000000..e5877df2 --- /dev/null +++ b/toolkits/google_maps/arcade_google_maps/constants.py @@ -0,0 +1,14 @@ +import os + +from arcade_google_maps.enums import GoogleMapsDistanceUnit, GoogleMapsTravelMode + +DEFAULT_GOOGLE_LANGUAGE = os.getenv("ARCADE_GOOGLE_LANGUAGE", "en") + +DEFAULT_GOOGLE_MAPS_LANGUAGE = os.getenv("ARCADE_GOOGLE_MAPS_LANGUAGE", DEFAULT_GOOGLE_LANGUAGE) +DEFAULT_GOOGLE_MAPS_COUNTRY = os.getenv("ARCADE_GOOGLE_MAPS_COUNTRY") +DEFAULT_GOOGLE_MAPS_DISTANCE_UNIT = GoogleMapsDistanceUnit( + os.getenv("ARCADE_GOOGLE_MAPS_DISTANCE_UNIT", GoogleMapsDistanceUnit.KM.value) +) +DEFAULT_GOOGLE_MAPS_TRAVEL_MODE = GoogleMapsTravelMode( + os.getenv("ARCADE_GOOGLE_MAPS_TRAVEL_MODE", GoogleMapsTravelMode.BEST.value) +) diff --git a/toolkits/google_maps/arcade_google_maps/enums.py b/toolkits/google_maps/arcade_google_maps/enums.py new file mode 100644 index 00000000..98e22613 --- /dev/null +++ b/toolkits/google_maps/arcade_google_maps/enums.py @@ -0,0 +1,35 @@ +from enum import Enum + + +class GoogleMapsTravelMode(Enum): + BEST = "best" + DRIVING = "driving" + MOTORCYCLE = "motorcycle" + PUBLIC_TRANSPORTATION = "public_transportation" + WALKING = "walking" + BICYCLE = "bicycle" + FLIGHT = "flight" + + def to_api_value(self) -> int: + _map = { + str(self.BEST): 6, + str(self.DRIVING): 0, + str(self.MOTORCYCLE): 9, + str(self.PUBLIC_TRANSPORTATION): 3, + str(self.WALKING): 2, + str(self.BICYCLE): 1, + str(self.FLIGHT): 4, + } + return _map[str(self)] + + +class GoogleMapsDistanceUnit(Enum): + KM = "km" + MILES = "mi" + + def to_api_value(self) -> int: + _map = { + str(self.KM): 0, + str(self.MILES): 1, + } + return _map[str(self)] diff --git a/toolkits/google_maps/arcade_google_maps/exceptions.py b/toolkits/google_maps/arcade_google_maps/exceptions.py new file mode 100644 index 00000000..450b70c6 --- /dev/null +++ b/toolkits/google_maps/arcade_google_maps/exceptions.py @@ -0,0 +1,25 @@ +import json + +from arcade_tdk.errors import RetryableToolError + +from arcade_google_maps.google_data import COUNTRY_CODES, LANGUAGE_CODES + + +class GoogleRetryableError(RetryableToolError): + pass + + +class CountryNotFoundError(GoogleRetryableError): + def __init__(self, country: str | None) -> None: + valid_countries = json.dumps(COUNTRY_CODES, default=str) + message = f"Country not found: '{country}'." + additional_message = f"Valid countries are: {valid_countries}" + super().__init__(message, additional_prompt_content=additional_message) + + +class LanguageNotFoundError(GoogleRetryableError): + def __init__(self, language: str | None) -> None: + valid_languages = json.dumps(LANGUAGE_CODES, default=str) + message = f"Language not found: '{language}'." + additional_message = f"Valid languages are: {valid_languages}" + super().__init__(message, additional_prompt_content=additional_message) diff --git a/toolkits/google_maps/arcade_google_maps/google_data.py b/toolkits/google_maps/arcade_google_maps/google_data.py new file mode 100644 index 00000000..789e3183 --- /dev/null +++ b/toolkits/google_maps/arcade_google_maps/google_data.py @@ -0,0 +1,281 @@ +COUNTRY_CODES = { + "af": "Afghanistan", + "al": "Albania", + "dz": "Algeria", + "as": "American Samoa", + "ad": "Andorra", + "ao": "Angola", + "ai": "Anguilla", + "aq": "Antarctica", + "ag": "Antigua and Barbuda", + "ar": "Argentina", + "am": "Armenia", + "aw": "Aruba", + "au": "Australia", + "at": "Austria", + "az": "Azerbaijan", + "bs": "Bahamas", + "bh": "Bahrain", + "bd": "Bangladesh", + "bb": "Barbados", + "by": "Belarus", + "be": "Belgium", + "bz": "Belize", + "bj": "Benin", + "bm": "Bermuda", + "bt": "Bhutan", + "bo": "Bolivia", + "ba": "Bosnia and Herzegovina", + "bw": "Botswana", + "bv": "Bouvet Island", + "br": "Brazil", + "io": "British Indian Ocean Territory", + "bn": "Brunei Darussalam", + "bg": "Bulgaria", + "bf": "Burkina Faso", + "bi": "Burundi", + "kh": "Cambodia", + "cm": "Cameroon", + "ca": "Canada", + "cv": "Cape Verde", + "ky": "Cayman Islands", + "cf": "Central African Republic", + "td": "Chad", + "cl": "Chile", + "cn": "China", + "cx": "Christmas Island", + "cc": "Cocos (Keeling) Islands", + "co": "Colombia", + "km": "Comoros", + "cg": "Congo", + "cd": "Congo, the Democratic Republic of the", + "ck": "Cook Islands", + "cr": "Costa Rica", + "ci": "Cote D'ivoire", + "hr": "Croatia", + "cu": "Cuba", + "cy": "Cyprus", + "cz": "Czech Republic", + "dk": "Denmark", + "dj": "Djibouti", + "dm": "Dominica", + "do": "Dominican Republic", + "ec": "Ecuador", + "eg": "Egypt", + "sv": "El Salvador", + "gq": "Equatorial Guinea", + "er": "Eritrea", + "ee": "Estonia", + "et": "Ethiopia", + "fk": "Falkland Islands (Malvinas)", + "fo": "Faroe Islands", + "fj": "Fiji", + "fi": "Finland", + "fr": "France", + "gf": "French Guiana", + "pf": "French Polynesia", + "tf": "French Southern Territories", + "ga": "Gabon", + "gm": "Gambia", + "ge": "Georgia", + "de": "Germany", + "gh": "Ghana", + "gi": "Gibraltar", + "gr": "Greece", + "gl": "Greenland", + "gd": "Grenada", + "gp": "Guadeloupe", + "gu": "Guam", + "gt": "Guatemala", + "gg": "Guernsey", + "gn": "Guinea", + "gw": "Guinea-Bissau", + "gy": "Guyana", + "ht": "Haiti", + "hm": "Heard Island and Mcdonald Islands", + "va": "Holy See (Vatican City State)", + "hn": "Honduras", + "hk": "Hong Kong", + "hu": "Hungary", + "is": "Iceland", + "in": "India", + "id": "Indonesia", + "ir": "Iran, Islamic Republic of", + "iq": "Iraq", + "ie": "Ireland", + "im": "Isle of Man", + "il": "Israel", + "it": "Italy", + "je": "Jersey", + "jm": "Jamaica", + "jp": "Japan", + "jo": "Jordan", + "kz": "Kazakhstan", + "ke": "Kenya", + "ki": "Kiribati", + "kp": "Korea, Democratic People's Republic of", + "kr": "Korea, Republic of", + "kw": "Kuwait", + "kg": "Kyrgyzstan", + "la": "Lao People's Democratic Republic", + "lv": "Latvia", + "lb": "Lebanon", + "ls": "Lesotho", + "lr": "Liberia", + "ly": "Libyan Arab Jamahiriya", + "li": "Liechtenstein", + "lt": "Lithuania", + "lu": "Luxembourg", + "mo": "Macao", + "mk": "Macedonia, the Former Yugosalv Republic of", + "mg": "Madagascar", + "mw": "Malawi", + "my": "Malaysia", + "mv": "Maldives", + "ml": "Mali", + "mt": "Malta", + "mh": "Marshall Islands", + "mq": "Martinique", + "mr": "Mauritania", + "mu": "Mauritius", + "yt": "Mayotte", + "mx": "Mexico", + "fm": "Micronesia, Federated States of", + "md": "Moldova, Republic of", + "mc": "Monaco", + "mn": "Mongolia", + "me": "Montenegro", + "ms": "Montserrat", + "ma": "Morocco", + "mz": "Mozambique", + "mm": "Myanmar", + "na": "Namibia", + "nr": "Nauru", + "np": "Nepal", + "nl": "Netherlands", + "an": "Netherlands Antilles", + "nc": "New Caledonia", + "nz": "New Zealand", + "ni": "Nicaragua", + "ne": "Niger", + "ng": "Nigeria", + "nu": "Niue", + "nf": "Norfolk Island", + "mp": "Northern Mariana Islands", + "no": "Norway", + "om": "Oman", + "pk": "Pakistan", + "pw": "Palau", + "ps": "Palestinian Territory, Occupied", + "pa": "Panama", + "pg": "Papua New Guinea", + "py": "Paraguay", + "pe": "Peru", + "ph": "Philippines", + "pn": "Pitcairn", + "pl": "Poland", + "pt": "Portugal", + "pr": "Puerto Rico", + "qa": "Qatar", + "re": "Reunion", + "ro": "Romania", + "ru": "Russian Federation", + "rw": "Rwanda", + "sh": "Saint Helena", + "kn": "Saint Kitts and Nevis", + "lc": "Saint Lucia", + "pm": "Saint Pierre and Miquelon", + "vc": "Saint Vincent and the Grenadines", + "ws": "Samoa", + "sm": "San Marino", + "st": "Sao Tome and Principe", + "sa": "Saudi Arabia", + "sn": "Senegal", + "rs": "Serbia", + "sc": "Seychelles", + "sl": "Sierra Leone", + "sg": "Singapore", + "sk": "Slovakia", + "si": "Slovenia", + "sb": "Solomon Islands", + "so": "Somalia", + "za": "South Africa", + "gs": "South Georgia and the South Sandwich Islands", + "es": "Spain", + "lk": "Sri Lanka", + "sd": "Sudan", + "sr": "Suriname", + "sj": "Svalbard and Jan Mayen", + "sz": "Swaziland", + "se": "Sweden", + "ch": "Switzerland", + "sy": "Syrian Arab Republic", + "tw": "Taiwan, Province of China", + "tj": "Tajikistan", + "tz": "Tanzania, United Republic of", + "th": "Thailand", + "tl": "Timor-Leste", + "tg": "Togo", + "tk": "Tokelau", + "to": "Tonga", + "tt": "Trinidad and Tobago", + "tn": "Tunisia", + "tr": "Turkiye", + "tm": "Turkmenistan", + "tc": "Turks and Caicos Islands", + "tv": "Tuvalu", + "ug": "Uganda", + "ua": "Ukraine", + "ae": "United Arab Emirates", + "uk": "United Kingdom", + "gb": "United Kingdom", + "us": "United States", + "um": "United States Minor Outlying Islands", + "uy": "Uruguay", + "uz": "Uzbekistan", + "vu": "Vanuatu", + "ve": "Venezuela", + "vn": "Viet Nam", + "vg": "Virgin Islands, British", + "vi": "Virgin Islands, U.S.", + "wf": "Wallis and Futuna", + "eh": "Western Sahara", + "ye": "Yemen", + "zm": "Zambia", + "zw": "Zimbabwe", +} + + +LANGUAGE_CODES = { + "ar": "Arabic", + "bn": "Bengali", + "da": "Danish", + "de": "German", + "el": "Greek", + "en": "English", + "es": "Spanish", + "fi": "Finnish", + "fr": "French", + "hi": "Hindi", + "hu": "Hungarian", + "id": "Indonesian", + "it": "Italian", + "ja": "Japanese", + "ko": "Korean", + "nl": "Dutch", + "ms": "Malay", + "no": "Norwegian", + "pcm": "Nigerian Pidgin", + "pl": "Polish", + "pt": "Portuguese", + "pt-br": "Portuguese (Brazil)", + "pt-pt": "Portuguese (Portugal)", + "ru": "Russian", + "sv": "Swedish", + "tl": "Filipino", + "tr": "Turkish", + "uk": "Ukrainian", + "zh": "Chinese", + "zh-cn": "Chinese (Simplified)", + "zh-tw": "Chinese (Traditional)", +} diff --git a/toolkits/google_maps/arcade_google_maps/tools/__init__.py b/toolkits/google_maps/arcade_google_maps/tools/__init__.py new file mode 100644 index 00000000..0c6f6dd3 --- /dev/null +++ b/toolkits/google_maps/arcade_google_maps/tools/__init__.py @@ -0,0 +1,9 @@ +from arcade_google_maps.tools.google_maps import ( + get_directions_between_addresses, + get_directions_between_coordinates, +) + +__all__ = [ + "get_directions_between_addresses", + "get_directions_between_coordinates", +] diff --git a/toolkits/google_maps/arcade_google_maps/tools/google_maps.py b/toolkits/google_maps/arcade_google_maps/tools/google_maps.py new file mode 100644 index 00000000..b4458f65 --- /dev/null +++ b/toolkits/google_maps/arcade_google_maps/tools/google_maps.py @@ -0,0 +1,100 @@ +from typing import Annotated + +from arcade_tdk import ToolContext, tool + +from arcade_google_maps.constants import ( + DEFAULT_GOOGLE_MAPS_COUNTRY, + DEFAULT_GOOGLE_MAPS_DISTANCE_UNIT, + DEFAULT_GOOGLE_MAPS_LANGUAGE, + DEFAULT_GOOGLE_MAPS_TRAVEL_MODE, +) +from arcade_google_maps.enums import GoogleMapsDistanceUnit, GoogleMapsTravelMode +from arcade_google_maps.utils import get_google_maps_directions + + +@tool(requires_secrets=["SERP_API_KEY"]) +async def get_directions_between_addresses( + context: ToolContext, + origin_address: Annotated[ + str, "The origin address. Example: '123 Main St, New York, NY 10001'" + ], + destination_address: Annotated[ + str, "The destination address. Example: '456 Main St, New York, NY 10001'" + ], + language: Annotated[ + str, + "2-character language code to use in the Google Maps search. " + f"Defaults to '{DEFAULT_GOOGLE_MAPS_LANGUAGE}'.", + ] = DEFAULT_GOOGLE_MAPS_LANGUAGE, + country: Annotated[ + str | None, + "2-character country code to use in the Google Maps search. " + f"Defaults to '{DEFAULT_GOOGLE_MAPS_COUNTRY}'.", + ] = DEFAULT_GOOGLE_MAPS_COUNTRY, + distance_unit: Annotated[ + GoogleMapsDistanceUnit, + f"Distance unit to use in the Google Maps search. Defaults to " + f"'{DEFAULT_GOOGLE_MAPS_DISTANCE_UNIT}'.", + ] = DEFAULT_GOOGLE_MAPS_DISTANCE_UNIT, + travel_mode: Annotated[ + GoogleMapsTravelMode, + f"Travel mode to use in the Google Maps search. Defaults to " + f"'{DEFAULT_GOOGLE_MAPS_TRAVEL_MODE}'.", + ] = DEFAULT_GOOGLE_MAPS_TRAVEL_MODE, +) -> Annotated[dict, "The directions from Google Maps"]: + """Get directions from Google Maps.""" + return { + "directions": get_google_maps_directions( + context=context, + origin_address=origin_address, + destination_address=destination_address, + language=language, + country=country, + distance_unit=distance_unit, + travel_mode=travel_mode, + ), + } + + +@tool(requires_secrets=["SERP_API_KEY"]) +async def get_directions_between_coordinates( + context: ToolContext, + origin_latitude: Annotated[str, "The origin latitude. E.g. '40.7128'"], + origin_longitude: Annotated[str, "The origin longitude. E.g. '-74.0060'"], + destination_latitude: Annotated[str, "The destination latitude. E.g. '40.7128'"], + destination_longitude: Annotated[str, "The destination longitude. E.g. '-74.0060'"], + language: Annotated[ + str, + "2-letter language code to use in the Google Maps search. " + f"Defaults to '{DEFAULT_GOOGLE_MAPS_LANGUAGE}'.", + ] = DEFAULT_GOOGLE_MAPS_LANGUAGE, + country: Annotated[ + str | None, + f"2-letter country code to use in the Google Maps search. Defaults to " + f"'{DEFAULT_GOOGLE_MAPS_COUNTRY}'.", + ] = DEFAULT_GOOGLE_MAPS_COUNTRY, + distance_unit: Annotated[ + GoogleMapsDistanceUnit, + f"Distance unit to use in the Google Maps search. Defaults to " + f"'{DEFAULT_GOOGLE_MAPS_DISTANCE_UNIT}'.", + ] = DEFAULT_GOOGLE_MAPS_DISTANCE_UNIT, + travel_mode: Annotated[ + GoogleMapsTravelMode, + f"Travel mode to use in the Google Maps search. Defaults to " + f"'{DEFAULT_GOOGLE_MAPS_TRAVEL_MODE}'.", + ] = DEFAULT_GOOGLE_MAPS_TRAVEL_MODE, +) -> Annotated[dict, "The directions from Google Maps"]: + """Get directions from Google Maps.""" + return { + "directions": get_google_maps_directions( + context=context, + origin_latitude=origin_latitude, + origin_longitude=origin_longitude, + destination_latitude=destination_latitude, + destination_longitude=destination_longitude, + language=language, + country=country, + distance_unit=distance_unit, + travel_mode=travel_mode, + ), + } diff --git a/toolkits/google_maps/arcade_google_maps/utils.py b/toolkits/google_maps/arcade_google_maps/utils.py new file mode 100644 index 00000000..a27b781e --- /dev/null +++ b/toolkits/google_maps/arcade_google_maps/utils.py @@ -0,0 +1,175 @@ +import contextlib +import re +from datetime import datetime +from typing import Any, cast +from zoneinfo import ZoneInfo + +from arcade_tdk import ToolContext +from arcade_tdk.errors import ToolExecutionError +from serpapi import Client as SerpClient + +from arcade_google_maps.constants import ( + DEFAULT_GOOGLE_MAPS_COUNTRY, + DEFAULT_GOOGLE_MAPS_DISTANCE_UNIT, + DEFAULT_GOOGLE_MAPS_LANGUAGE, + DEFAULT_GOOGLE_MAPS_TRAVEL_MODE, +) +from arcade_google_maps.enums import GoogleMapsDistanceUnit, GoogleMapsTravelMode +from arcade_google_maps.exceptions import CountryNotFoundError, LanguageNotFoundError +from arcade_google_maps.google_data import COUNTRY_CODES, LANGUAGE_CODES + + +def prepare_params(engine: str, **kwargs: Any) -> dict[str, Any]: + """ + Prepares a parameters dictionary for the SerpAPI call. + + Parameters: + engine: The engine name (e.g., "google", "google_finance"). + kwargs: Any additional parameters to include. + + Returns: + A dictionary containing the base parameters plus any extras, + excluding any parameters whose value is None. + """ + params = {"engine": engine} + params.update({k: v for k, v in kwargs.items() if v is not None}) + return params + + +def call_serpapi(context: ToolContext, params: dict) -> dict: + """ + Execute a search query using the SerpAPI client and return the results as a dictionary. + + Args: + context: The tool context containing required secrets. + params: A dictionary of parameters for the SerpAPI search. + + Returns: + The search results as a dictionary. + """ + api_key = context.get_secret("SERP_API_KEY") + client = SerpClient(api_key=api_key) + try: + search = client.search(params) + return cast(dict[str, Any], search.as_dict()) + except Exception as e: + # SerpAPI error messages sometimes contain the API key, so we need to sanitize it + sanitized_e = re.sub(r"(api_key=)[^ &]+", r"\1***", str(e)) + raise ToolExecutionError( + message="Failed to fetch search results", + developer_message=sanitized_e, + ) + + +def get_google_maps_directions( + context: ToolContext, + origin_address: str | None = None, + destination_address: str | None = None, + origin_latitude: str | None = None, + origin_longitude: str | None = None, + destination_latitude: str | None = None, + destination_longitude: str | None = None, + language: str | None = DEFAULT_GOOGLE_MAPS_LANGUAGE, + country: str | None = DEFAULT_GOOGLE_MAPS_COUNTRY, + distance_unit: GoogleMapsDistanceUnit = DEFAULT_GOOGLE_MAPS_DISTANCE_UNIT, + travel_mode: GoogleMapsTravelMode = DEFAULT_GOOGLE_MAPS_TRAVEL_MODE, +) -> list[dict[str, Any]]: + """Get directions from Google Maps. + + Provide either all(origin_address, destination_address) or + all(origin_latitude, origin_longitude, destination_latitude, destination_longitude). + + Args: + context: Tool context containing required Serp API Key secret. + origin_address: Origin address. + destination_address: Destination address. + origin_latitude: Origin latitude. + origin_longitude: Origin longitude. + destination_latitude: Destination latitude. + destination_longitude: Destination longitude. + language: Language to use in the Google Maps search. Defaults to 'en' (English). + country: 2-letter country code to use in the Google Maps search. Defaults to None + (no country is specified). + distance_unit: Distance unit to use in the Google Maps search. Defaults to 'km' + (kilometers). + travel_mode: Travel mode to use in the Google Maps search. Defaults to 'best' + (best mode). + + Returns: + The directions from Google Maps. + """ + if isinstance(language, str): + language = language.lower() + + if language not in LANGUAGE_CODES: + raise LanguageNotFoundError(language) + + params = prepare_params( + engine="google_maps_directions", + hl=language, + distance_unit=distance_unit.to_api_value(), + travel_mode=travel_mode.to_api_value(), + ) + + if any([ + origin_latitude, + origin_longitude, + destination_latitude, + destination_longitude, + ]) and any([origin_address, destination_address]): + raise ValueError("Either coordinates or addresses must be provided, not both") + + elif all([origin_latitude, origin_longitude, destination_latitude, destination_longitude]): + params["start_coords"] = f"{origin_latitude},{origin_longitude}" + params["end_coords"] = f"{destination_latitude},{destination_longitude}" + + elif all([origin_address, destination_address]): + params["start_addr"] = str(origin_address) + params["end_addr"] = str(destination_address) + + else: + raise ValueError("Either coordinates or addresses must be provided") + + if country: + country = country.lower() + if country not in COUNTRY_CODES: + raise CountryNotFoundError(country) + params["gl"] = country + + results = call_serpapi(context, params) + + directions = cast(list[dict[str, Any]], results.get("directions", [])) + + for direction in directions: + clean_google_maps_direction(direction) + + if "arrive_around" in direction: + direction["arrive_around"] = enrich_google_maps_arrive_around( + direction["arrive_around"] + ) + + return directions + + +def clean_google_maps_direction(direction: dict[str, Any]) -> None: + for trip in direction.get("trips", []): + with contextlib.suppress(KeyError): + del trip["start_stop"]["data_id"] + del trip["end_stop"]["data_id"] + + for detail in trip.get("details", []): + with contextlib.suppress(KeyError): + del detail["geo_photo"] + del detail["gps_coordinates"] + + for stop in trip.get("stops", []): + with contextlib.suppress(KeyError): + del stop["data_id"] + + +def enrich_google_maps_arrive_around(timestamp: int | None) -> dict[str, Any]: + if not timestamp: + return {} + + dt = datetime.fromtimestamp(timestamp, tz=ZoneInfo("UTC")).isoformat() + return {"datetime": dt, "timestamp": timestamp} diff --git a/toolkits/google_maps/evals/eval_google_maps_directions.py b/toolkits/google_maps/evals/eval_google_maps_directions.py new file mode 100644 index 00000000..ee3f8cb8 --- /dev/null +++ b/toolkits/google_maps/evals/eval_google_maps_directions.py @@ -0,0 +1,226 @@ +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + SimilarityCritic, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_google_maps +from arcade_google_maps.constants import ( + DEFAULT_GOOGLE_MAPS_COUNTRY, + DEFAULT_GOOGLE_MAPS_DISTANCE_UNIT, + DEFAULT_GOOGLE_MAPS_LANGUAGE, + DEFAULT_GOOGLE_MAPS_TRAVEL_MODE, +) +from arcade_google_maps.enums import GoogleMapsDistanceUnit, GoogleMapsTravelMode +from arcade_google_maps.tools.google_maps import ( + get_directions_between_addresses, + get_directions_between_coordinates, +) + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.8, + warn_threshold=0.9, +) + +catalog = ToolCatalog() +# Register the Google Search tool +catalog.add_module(arcade_google_maps) + + +@tool_eval() +def google_maps_directions_by_addresses_eval_suite() -> EvalSuite: + """Create an evaluation suite for the Google Maps Directions tools.""" + suite = EvalSuite( + name="Google Maps Directions Tool Evaluation", + system_message="You are an AI assistant that can get directions from Google Maps using the provided tools.", + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="Get directions between two addresses", + user_message="Get directions from Google Maps between the following addresses: 1600 Amphitheatre Parkway, Mountain View, CA 94043 and 1 Infinite Loop, Cupertino, CA 95014.", + expected_tool_calls=[ + ExpectedToolCall( + func=get_directions_between_addresses, + args={ + "origin_address": "1600 Amphitheatre Parkway, Mountain View, CA 94043", + "destination_address": "1 Infinite Loop, Cupertino, CA 95014", + "language": DEFAULT_GOOGLE_MAPS_LANGUAGE, + "country": DEFAULT_GOOGLE_MAPS_COUNTRY, + "distance_unit": DEFAULT_GOOGLE_MAPS_DISTANCE_UNIT, + "travel_mode": DEFAULT_GOOGLE_MAPS_TRAVEL_MODE, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="origin_address", weight=0.3), + SimilarityCritic(critic_field="destination_address", weight=0.3), + BinaryCritic(critic_field="language", weight=0.1), + BinaryCritic(critic_field="country", weight=0.1), + BinaryCritic(critic_field="distance_unit", weight=0.1), + BinaryCritic(critic_field="travel_mode", weight=0.1), + ], + ) + + suite.add_case( + name="Get directions between two addresses with custom distance unit and travel mode", + user_message="Get walking directions from Google Maps between the following addresses in miles: 1600 Amphitheatre Parkway, Mountain View, CA 94043 and 1 Infinite Loop, Cupertino, CA 95014.", + expected_tool_calls=[ + ExpectedToolCall( + func=get_directions_between_addresses, + args={ + "origin_address": "1600 Amphitheatre Parkway, Mountain View, CA 94043", + "destination_address": "1 Infinite Loop, Cupertino, CA 95014", + "language": DEFAULT_GOOGLE_MAPS_LANGUAGE, + "country": DEFAULT_GOOGLE_MAPS_COUNTRY, + "distance_unit": GoogleMapsDistanceUnit.MILES.value, + "travel_mode": GoogleMapsTravelMode.WALKING.value, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="origin_address", weight=0.3), + SimilarityCritic(critic_field="destination_address", weight=0.3), + BinaryCritic(critic_field="language", weight=0.1), + BinaryCritic(critic_field="country", weight=0.1), + BinaryCritic(critic_field="distance_unit", weight=0.1), + BinaryCritic(critic_field="travel_mode", weight=0.1), + ], + ) + + suite.add_case( + name="Get directions between two addresses in a given country and language", + user_message="Get directions from Google Maps in Portuguese between the following addresses in Brazil: Rua do Amendoim, 1, Belo Horizonte, MG and Av. do Descobrimento, 515, Porto Seguro, BA.", + expected_tool_calls=[ + ExpectedToolCall( + func=get_directions_between_addresses, + args={ + "origin_address": "Rua do Amendoim, 1, Belo Horizonte, MG", + "destination_address": "Av. do Descobrimento, 515, Porto Seguro, BA", + "language": "pt", + "country": "br", + "distance_unit": DEFAULT_GOOGLE_MAPS_DISTANCE_UNIT, + "travel_mode": DEFAULT_GOOGLE_MAPS_TRAVEL_MODE, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="origin_address", weight=0.3), + SimilarityCritic(critic_field="destination_address", weight=0.3), + BinaryCritic(critic_field="language", weight=0.1), + BinaryCritic(critic_field="country", weight=0.1), + BinaryCritic(critic_field="distance_unit", weight=0.1), + BinaryCritic(critic_field="travel_mode", weight=0.1), + ], + ) + + return suite + + +@tool_eval() +def google_maps_directions_by_coordinates_eval_suite() -> EvalSuite: + """Create an evaluation suite for the Google Maps Directions tools.""" + suite = EvalSuite( + name="Google Maps Directions Tool Evaluation", + system_message="You are an AI assistant that can get directions from Google Maps using the provided tools.", + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="Get directions between two coordinates", + user_message="Get directions from Google Maps between the following coordinates: 37.422740,-122.084961 and 37.331820,-122.031180.", + expected_tool_calls=[ + ExpectedToolCall( + func=get_directions_between_coordinates, + args={ + "origin_latitude": "37.422740", + "origin_longitude": "-122.084961", + "destination_latitude": "37.331820", + "destination_longitude": "-122.031180", + "language": DEFAULT_GOOGLE_MAPS_LANGUAGE, + "country": DEFAULT_GOOGLE_MAPS_COUNTRY, + "distance_unit": DEFAULT_GOOGLE_MAPS_DISTANCE_UNIT, + "travel_mode": DEFAULT_GOOGLE_MAPS_TRAVEL_MODE, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="origin_latitude", weight=0.15), + SimilarityCritic(critic_field="origin_longitude", weight=0.15), + SimilarityCritic(critic_field="destination_latitude", weight=0.15), + SimilarityCritic(critic_field="destination_longitude", weight=0.15), + BinaryCritic(critic_field="language", weight=0.1), + BinaryCritic(critic_field="country", weight=0.1), + BinaryCritic(critic_field="distance_unit", weight=0.1), + BinaryCritic(critic_field="travel_mode", weight=0.1), + ], + ) + + suite.add_case( + name="Get directions between two coordinates with custom distance unit and travel mode", + user_message="Get walking directions from Google Maps between the following coordinates in miles: 37.422740,-122.084961 and 37.331820,-122.031180.", + expected_tool_calls=[ + ExpectedToolCall( + func=get_directions_between_coordinates, + args={ + "origin_latitude": "37.422740", + "origin_longitude": "-122.084961", + "destination_latitude": "37.331820", + "destination_longitude": "-122.031180", + "language": DEFAULT_GOOGLE_MAPS_LANGUAGE, + "country": DEFAULT_GOOGLE_MAPS_COUNTRY, + "distance_unit": GoogleMapsDistanceUnit.MILES.value, + "travel_mode": GoogleMapsTravelMode.WALKING.value, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="origin_latitude", weight=0.15), + SimilarityCritic(critic_field="origin_longitude", weight=0.15), + SimilarityCritic(critic_field="destination_latitude", weight=0.15), + SimilarityCritic(critic_field="destination_longitude", weight=0.15), + BinaryCritic(critic_field="language", weight=0.1), + BinaryCritic(critic_field="country", weight=0.1), + BinaryCritic(critic_field="distance_unit", weight=0.1), + BinaryCritic(critic_field="travel_mode", weight=0.1), + ], + ) + + suite.add_case( + name="Get directions between two coordinates in a given country and language", + user_message="Get directions from Google Maps in Portuguese between the following coordinates in Brazil: 37.422740,-122.084961 and 37.331820,-122.031180.", + expected_tool_calls=[ + ExpectedToolCall( + func=get_directions_between_coordinates, + args={ + "origin_latitude": "37.422740", + "origin_longitude": "-122.084961", + "destination_latitude": "37.331820", + "destination_longitude": "-122.031180", + "language": "pt", + "country": "br", + "distance_unit": DEFAULT_GOOGLE_MAPS_DISTANCE_UNIT, + "travel_mode": DEFAULT_GOOGLE_MAPS_TRAVEL_MODE, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="origin_latitude", weight=0.15), + SimilarityCritic(critic_field="origin_longitude", weight=0.15), + SimilarityCritic(critic_field="destination_latitude", weight=0.15), + SimilarityCritic(critic_field="destination_longitude", weight=0.15), + BinaryCritic(critic_field="language", weight=0.1), + BinaryCritic(critic_field="country", weight=0.1), + BinaryCritic(critic_field="distance_unit", weight=0.1), + BinaryCritic(critic_field="travel_mode", weight=0.1), + ], + ) + + return suite diff --git a/toolkits/google_maps/pyproject.toml b/toolkits/google_maps/pyproject.toml new file mode 100644 index 00000000..7adb3620 --- /dev/null +++ b/toolkits/google_maps/pyproject.toml @@ -0,0 +1,56 @@ +[build-system] +requires = [ "hatchling",] +build-backend = "hatchling.build" + +[project] +name = "arcade_google_maps" +version = "2.0.0" +description = "Arcade.dev LLM tools for getting directions via Google Maps" +requires-python = ">=3.10" +dependencies = [ "arcade-tdk>=2.0.0,<3.0.0", "serpapi>=0.1.5,<1.0.0",] +[[project.authors]] +name = "Arcade" +email = "dev@arcade.dev" + + +[project.optional-dependencies] +dev = [ + "arcade-ai[evals]>=2.0.0,<3.0.0", + "arcade-serve>=2.0.0,<3.0.0", + "pytest>=8.3.0,<9.0.0", + "pytest-cov>=4.0.0,<5.0.0", + "pytest-mock>=3.11.1,<4.0.0", + "pytest-asyncio>=0.24.0,<1.0.0", + "mypy>=1.5.1,<2.0.0", + "pre-commit>=3.4.0,<4.0.0", + "tox>=4.11.1,<5.0.0", + "ruff>=0.7.4,<1.0.0", +] + +# Use local path sources for arcade libs when working locally +[tool.uv.sources] +arcade-ai = { path = "../../", editable = true } +arcade-serve = { path = "../../libs/arcade-serve/", editable = true } +arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } + + +[tool.mypy] +files = [ "arcade_google_maps/**/*.py",] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes = "True" +ignore_missing_imports = "True" + +[tool.pytest.ini_options] +testpaths = [ "tests",] + +[tool.coverage.report] +skip_empty = true + +[tool.hatch.build.targets.wheel] +packages = [ "arcade_google_maps",] diff --git a/toolkits/google_maps/tests/__init__.py b/toolkits/google_maps/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/toolkits/google_maps/tests/test_google_maps_directions.py b/toolkits/google_maps/tests/test_google_maps_directions.py new file mode 100644 index 00000000..46624f69 --- /dev/null +++ b/toolkits/google_maps/tests/test_google_maps_directions.py @@ -0,0 +1,131 @@ +from unittest.mock import patch + +import pytest +from arcade_tdk import ToolContext, ToolSecretItem + +from arcade_google_maps.exceptions import CountryNotFoundError, LanguageNotFoundError +from arcade_google_maps.tools.google_maps import ( + get_directions_between_addresses, + get_directions_between_coordinates, +) + + +@pytest.fixture +def mock_context(): + return ToolContext(secrets=[ToolSecretItem(key="serp_api_key", value="fake_api_key")]) + + +@pytest.mark.asyncio +@patch("arcade_google_maps.utils.SerpClient") +async def test_get_directions_between_coordinates_success(mock_serp_client, mock_context): + mock_serp_client_instance = mock_serp_client.return_value + mock_serp_client_instance.search.return_value.as_dict.return_value = { + "directions": [ + { + "arrive_around": 1741789839, + "distance": "100 miles", + "duration": "1 hour", + } + ] + } + + result = await get_directions_between_coordinates( + context=mock_context, + origin_latitude="1", + origin_longitude="2", + destination_latitude="3", + destination_longitude="4", + ) + + assert result == { + "directions": [ + { + "arrive_around": { + "datetime": "2025-03-12T14:30:39+00:00", + "timestamp": 1741789839, + }, + "distance": "100 miles", + "duration": "1 hour", + } + ] + } + + +@pytest.mark.asyncio +@patch("arcade_google_maps.utils.SerpClient") +async def test_get_directions_between_addresses_success(mock_serp_client, mock_context): + mock_serp_client_instance = mock_serp_client.return_value + mock_serp_client_instance.search.return_value.as_dict.return_value = { + "directions": [ + { + "arrive_around": 1741789839, + "distance": "100 miles", + "duration": "1 hour", + } + ] + } + + result = await get_directions_between_addresses( + context=mock_context, + origin_address="1", + destination_address="2", + ) + + assert result == { + "directions": [ + { + "arrive_around": { + "datetime": "2025-03-12T14:30:39+00:00", + "timestamp": 1741789839, + }, + "distance": "100 miles", + "duration": "1 hour", + } + ] + } + + +@pytest.mark.asyncio +@patch("arcade_google_maps.utils.SerpClient") +async def test_get_directions_between_addresses_country_not_found(mock_serp_client, mock_context): + mock_serp_client_instance = mock_serp_client.return_value + mock_serp_client_instance.search.return_value.as_dict.return_value = { + "directions": [ + { + "arrive_around": 1741789839, + "distance": "100 miles", + "duration": "1 hour", + } + ] + } + + with pytest.raises(CountryNotFoundError): + await get_directions_between_addresses( + context=mock_context, + origin_address="1", + destination_address="2", + country="invalid", + ) + + +@pytest.mark.asyncio +@patch("arcade_google_maps.utils.SerpClient") +async def test_get_directions_between_addresses_language_not_found(mock_serp_client, mock_context): + mock_serp_client_instance = mock_serp_client.return_value + mock_serp_client_instance.search.return_value.as_dict.return_value = { + "directions": [ + { + "arrive_around": 1741789839, + "distance": "100 miles", + "duration": "1 hour", + } + ] + } + + with pytest.raises(LanguageNotFoundError): + await get_directions_between_addresses( + context=mock_context, + origin_address="1", + destination_address="2", + language="invalid", + ) diff --git a/toolkits/google_news/.pre-commit-config.yaml b/toolkits/google_news/.pre-commit-config.yaml new file mode 100644 index 00000000..4b9f271a --- /dev/null +++ b/toolkits/google_news/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +files: ^.*/google_news/.* +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.4.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/toolkits/google_news/.ruff.toml b/toolkits/google_news/.ruff.toml new file mode 100644 index 00000000..19364180 --- /dev/null +++ b/toolkits/google_news/.ruff.toml @@ -0,0 +1,47 @@ +target-version = "py310" +line-length = 100 +fix = true + +[lint] +select = [ + # flake8-2020 + "YTT", + # flake8-bandit + "S", + # flake8-bugbear + "B", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-simplify + "SIM", + # isort + "I", + # mccabe + "C90", + # pycodestyle + "E", "W", + # pyflakes + "F", + # pygrep-hooks + "PGH", + # pyupgrade + "UP", + # ruff + "RUF", + # tryceratops + "TRY", +] + +[lint.per-file-ignores] +"*" = ["TRY003", "B904"] +"**/tests/*" = ["S101", "E501"] +"**/evals/*" = ["S101", "E501"] + + +[format] +preview = true +skip-magic-trailing-comma = false diff --git a/toolkits/google_news/LICENSE b/toolkits/google_news/LICENSE new file mode 100644 index 00000000..dfbb8b76 --- /dev/null +++ b/toolkits/google_news/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025, Arcade AI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/toolkits/google_news/Makefile b/toolkits/google_news/Makefile new file mode 100644 index 00000000..0a8969be --- /dev/null +++ b/toolkits/google_news/Makefile @@ -0,0 +1,55 @@ +.PHONY: help + +help: + @echo "🛠️ github Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the uv environment and install all packages with dependencies + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras --no-sources + @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: install-local +install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras + @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + uv build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + @uv run --no-sources coverage report + @echo "Generating coverage report" + @uv run --no-sources coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file by a patch version + @echo "🚀 Bumping version in pyproject.toml" + uv version --no-sources --bump patch + +.PHONY: check +check: ## Run code quality tools. + @if [ -f .pre-commit-config.yaml ]; then\ + echo "🚀 Linting code: Running pre-commit";\ + uv run --no-sources pre-commit run -a;\ + fi + @echo "🚀 Static type checking: Running mypy" + @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/google_news/arcade_google_news/__init__.py b/toolkits/google_news/arcade_google_news/__init__.py new file mode 100644 index 00000000..e371d62b --- /dev/null +++ b/toolkits/google_news/arcade_google_news/__init__.py @@ -0,0 +1,3 @@ +from arcade_google_news.tools import search_news_stories + +__all__ = ["search_news_stories"] diff --git a/toolkits/google_news/arcade_google_news/constants.py b/toolkits/google_news/arcade_google_news/constants.py new file mode 100644 index 00000000..cb5183ae --- /dev/null +++ b/toolkits/google_news/arcade_google_news/constants.py @@ -0,0 +1,6 @@ +import os + +DEFAULT_GOOGLE_LANGUAGE = os.getenv("ARCADE_GOOGLE_LANGUAGE", "en") + +DEFAULT_GOOGLE_NEWS_LANGUAGE = os.getenv("ARCADE_GOOGLE_NEWS_LANGUAGE", DEFAULT_GOOGLE_LANGUAGE) +DEFAULT_GOOGLE_NEWS_COUNTRY = os.getenv("ARCADE_GOOGLE_NEWS_COUNTRY") diff --git a/toolkits/google_news/arcade_google_news/exceptions.py b/toolkits/google_news/arcade_google_news/exceptions.py new file mode 100644 index 00000000..480065af --- /dev/null +++ b/toolkits/google_news/arcade_google_news/exceptions.py @@ -0,0 +1,25 @@ +import json + +from arcade_tdk.errors import RetryableToolError + +from arcade_google_news.google_data import COUNTRY_CODES, LANGUAGE_CODES + + +class GoogleRetryableError(RetryableToolError): + pass + + +class CountryNotFoundError(GoogleRetryableError): + def __init__(self, country: str | None) -> None: + valid_countries = json.dumps(COUNTRY_CODES, default=str) + message = f"Country not found: '{country}'." + additional_message = f"Valid countries are: {valid_countries}" + super().__init__(message, additional_prompt_content=additional_message) + + +class LanguageNotFoundError(GoogleRetryableError): + def __init__(self, language: str | None) -> None: + valid_languages = json.dumps(LANGUAGE_CODES, default=str) + message = f"Language not found: '{language}'." + additional_message = f"Valid languages are: {valid_languages}" + super().__init__(message, additional_prompt_content=additional_message) diff --git a/toolkits/google_news/arcade_google_news/google_data.py b/toolkits/google_news/arcade_google_news/google_data.py new file mode 100644 index 00000000..789e3183 --- /dev/null +++ b/toolkits/google_news/arcade_google_news/google_data.py @@ -0,0 +1,281 @@ +COUNTRY_CODES = { + "af": "Afghanistan", + "al": "Albania", + "dz": "Algeria", + "as": "American Samoa", + "ad": "Andorra", + "ao": "Angola", + "ai": "Anguilla", + "aq": "Antarctica", + "ag": "Antigua and Barbuda", + "ar": "Argentina", + "am": "Armenia", + "aw": "Aruba", + "au": "Australia", + "at": "Austria", + "az": "Azerbaijan", + "bs": "Bahamas", + "bh": "Bahrain", + "bd": "Bangladesh", + "bb": "Barbados", + "by": "Belarus", + "be": "Belgium", + "bz": "Belize", + "bj": "Benin", + "bm": "Bermuda", + "bt": "Bhutan", + "bo": "Bolivia", + "ba": "Bosnia and Herzegovina", + "bw": "Botswana", + "bv": "Bouvet Island", + "br": "Brazil", + "io": "British Indian Ocean Territory", + "bn": "Brunei Darussalam", + "bg": "Bulgaria", + "bf": "Burkina Faso", + "bi": "Burundi", + "kh": "Cambodia", + "cm": "Cameroon", + "ca": "Canada", + "cv": "Cape Verde", + "ky": "Cayman Islands", + "cf": "Central African Republic", + "td": "Chad", + "cl": "Chile", + "cn": "China", + "cx": "Christmas Island", + "cc": "Cocos (Keeling) Islands", + "co": "Colombia", + "km": "Comoros", + "cg": "Congo", + "cd": "Congo, the Democratic Republic of the", + "ck": "Cook Islands", + "cr": "Costa Rica", + "ci": "Cote D'ivoire", + "hr": "Croatia", + "cu": "Cuba", + "cy": "Cyprus", + "cz": "Czech Republic", + "dk": "Denmark", + "dj": "Djibouti", + "dm": "Dominica", + "do": "Dominican Republic", + "ec": "Ecuador", + "eg": "Egypt", + "sv": "El Salvador", + "gq": "Equatorial Guinea", + "er": "Eritrea", + "ee": "Estonia", + "et": "Ethiopia", + "fk": "Falkland Islands (Malvinas)", + "fo": "Faroe Islands", + "fj": "Fiji", + "fi": "Finland", + "fr": "France", + "gf": "French Guiana", + "pf": "French Polynesia", + "tf": "French Southern Territories", + "ga": "Gabon", + "gm": "Gambia", + "ge": "Georgia", + "de": "Germany", + "gh": "Ghana", + "gi": "Gibraltar", + "gr": "Greece", + "gl": "Greenland", + "gd": "Grenada", + "gp": "Guadeloupe", + "gu": "Guam", + "gt": "Guatemala", + "gg": "Guernsey", + "gn": "Guinea", + "gw": "Guinea-Bissau", + "gy": "Guyana", + "ht": "Haiti", + "hm": "Heard Island and Mcdonald Islands", + "va": "Holy See (Vatican City State)", + "hn": "Honduras", + "hk": "Hong Kong", + "hu": "Hungary", + "is": "Iceland", + "in": "India", + "id": "Indonesia", + "ir": "Iran, Islamic Republic of", + "iq": "Iraq", + "ie": "Ireland", + "im": "Isle of Man", + "il": "Israel", + "it": "Italy", + "je": "Jersey", + "jm": "Jamaica", + "jp": "Japan", + "jo": "Jordan", + "kz": "Kazakhstan", + "ke": "Kenya", + "ki": "Kiribati", + "kp": "Korea, Democratic People's Republic of", + "kr": "Korea, Republic of", + "kw": "Kuwait", + "kg": "Kyrgyzstan", + "la": "Lao People's Democratic Republic", + "lv": "Latvia", + "lb": "Lebanon", + "ls": "Lesotho", + "lr": "Liberia", + "ly": "Libyan Arab Jamahiriya", + "li": "Liechtenstein", + "lt": "Lithuania", + "lu": "Luxembourg", + "mo": "Macao", + "mk": "Macedonia, the Former Yugosalv Republic of", + "mg": "Madagascar", + "mw": "Malawi", + "my": "Malaysia", + "mv": "Maldives", + "ml": "Mali", + "mt": "Malta", + "mh": "Marshall Islands", + "mq": "Martinique", + "mr": "Mauritania", + "mu": "Mauritius", + "yt": "Mayotte", + "mx": "Mexico", + "fm": "Micronesia, Federated States of", + "md": "Moldova, Republic of", + "mc": "Monaco", + "mn": "Mongolia", + "me": "Montenegro", + "ms": "Montserrat", + "ma": "Morocco", + "mz": "Mozambique", + "mm": "Myanmar", + "na": "Namibia", + "nr": "Nauru", + "np": "Nepal", + "nl": "Netherlands", + "an": "Netherlands Antilles", + "nc": "New Caledonia", + "nz": "New Zealand", + "ni": "Nicaragua", + "ne": "Niger", + "ng": "Nigeria", + "nu": "Niue", + "nf": "Norfolk Island", + "mp": "Northern Mariana Islands", + "no": "Norway", + "om": "Oman", + "pk": "Pakistan", + "pw": "Palau", + "ps": "Palestinian Territory, Occupied", + "pa": "Panama", + "pg": "Papua New Guinea", + "py": "Paraguay", + "pe": "Peru", + "ph": "Philippines", + "pn": "Pitcairn", + "pl": "Poland", + "pt": "Portugal", + "pr": "Puerto Rico", + "qa": "Qatar", + "re": "Reunion", + "ro": "Romania", + "ru": "Russian Federation", + "rw": "Rwanda", + "sh": "Saint Helena", + "kn": "Saint Kitts and Nevis", + "lc": "Saint Lucia", + "pm": "Saint Pierre and Miquelon", + "vc": "Saint Vincent and the Grenadines", + "ws": "Samoa", + "sm": "San Marino", + "st": "Sao Tome and Principe", + "sa": "Saudi Arabia", + "sn": "Senegal", + "rs": "Serbia", + "sc": "Seychelles", + "sl": "Sierra Leone", + "sg": "Singapore", + "sk": "Slovakia", + "si": "Slovenia", + "sb": "Solomon Islands", + "so": "Somalia", + "za": "South Africa", + "gs": "South Georgia and the South Sandwich Islands", + "es": "Spain", + "lk": "Sri Lanka", + "sd": "Sudan", + "sr": "Suriname", + "sj": "Svalbard and Jan Mayen", + "sz": "Swaziland", + "se": "Sweden", + "ch": "Switzerland", + "sy": "Syrian Arab Republic", + "tw": "Taiwan, Province of China", + "tj": "Tajikistan", + "tz": "Tanzania, United Republic of", + "th": "Thailand", + "tl": "Timor-Leste", + "tg": "Togo", + "tk": "Tokelau", + "to": "Tonga", + "tt": "Trinidad and Tobago", + "tn": "Tunisia", + "tr": "Turkiye", + "tm": "Turkmenistan", + "tc": "Turks and Caicos Islands", + "tv": "Tuvalu", + "ug": "Uganda", + "ua": "Ukraine", + "ae": "United Arab Emirates", + "uk": "United Kingdom", + "gb": "United Kingdom", + "us": "United States", + "um": "United States Minor Outlying Islands", + "uy": "Uruguay", + "uz": "Uzbekistan", + "vu": "Vanuatu", + "ve": "Venezuela", + "vn": "Viet Nam", + "vg": "Virgin Islands, British", + "vi": "Virgin Islands, U.S.", + "wf": "Wallis and Futuna", + "eh": "Western Sahara", + "ye": "Yemen", + "zm": "Zambia", + "zw": "Zimbabwe", +} + + +LANGUAGE_CODES = { + "ar": "Arabic", + "bn": "Bengali", + "da": "Danish", + "de": "German", + "el": "Greek", + "en": "English", + "es": "Spanish", + "fi": "Finnish", + "fr": "French", + "hi": "Hindi", + "hu": "Hungarian", + "id": "Indonesian", + "it": "Italian", + "ja": "Japanese", + "ko": "Korean", + "nl": "Dutch", + "ms": "Malay", + "no": "Norwegian", + "pcm": "Nigerian Pidgin", + "pl": "Polish", + "pt": "Portuguese", + "pt-br": "Portuguese (Brazil)", + "pt-pt": "Portuguese (Portugal)", + "ru": "Russian", + "sv": "Swedish", + "tl": "Filipino", + "tr": "Turkish", + "uk": "Ukrainian", + "zh": "Chinese", + "zh-cn": "Chinese (Simplified)", + "zh-tw": "Chinese (Traditional)", +} diff --git a/toolkits/google_news/arcade_google_news/tools/__init__.py b/toolkits/google_news/arcade_google_news/tools/__init__.py new file mode 100644 index 00000000..d5d1050d --- /dev/null +++ b/toolkits/google_news/arcade_google_news/tools/__init__.py @@ -0,0 +1,3 @@ +from arcade_google_news.tools.google_news import search_news_stories + +__all__ = ["search_news_stories"] diff --git a/toolkits/google_news/arcade_google_news/tools/google_news.py b/toolkits/google_news/arcade_google_news/tools/google_news.py new file mode 100644 index 00000000..73c81a40 --- /dev/null +++ b/toolkits/google_news/arcade_google_news/tools/google_news.py @@ -0,0 +1,47 @@ +from typing import Annotated, Any + +from arcade_tdk import ToolContext, tool +from arcade_tdk.errors import ToolExecutionError + +from arcade_google_news.constants import DEFAULT_GOOGLE_NEWS_COUNTRY, DEFAULT_GOOGLE_NEWS_LANGUAGE +from arcade_google_news.exceptions import CountryNotFoundError, LanguageNotFoundError +from arcade_google_news.google_data import COUNTRY_CODES, LANGUAGE_CODES +from arcade_google_news.utils import call_serpapi, extract_news_results, prepare_params + + +@tool(requires_secrets=["SERP_API_KEY"]) +async def search_news_stories( + context: ToolContext, + keywords: Annotated[ + str, + "Keywords to search for news articles. E.g. 'Apple launches new iPhone'.", + ], + country_code: Annotated[ + str | None, + "2-character country code to search for news articles. E.g. 'us' (United States). " + f"Defaults to '{DEFAULT_GOOGLE_NEWS_COUNTRY}'.", + ] = None, + language_code: Annotated[ + str, + "2-character language code to search for news articles. E.g. 'en' (English). " + f"Defaults to '{DEFAULT_GOOGLE_NEWS_LANGUAGE}'.", + ] = DEFAULT_GOOGLE_NEWS_LANGUAGE, + limit: Annotated[ + int | None, + "Maximum number of news articles to return. Defaults to None " + "(returns all results found by the API).", + ] = None, +) -> Annotated[dict[str, list[dict[str, Any]]], "News results."]: + """Search for news articles related to a given query.""" + if not keywords: + raise ToolExecutionError("Keywords are required to search for news articles.") + + if country_code and country_code not in COUNTRY_CODES: + raise CountryNotFoundError(country_code) + + if language_code not in LANGUAGE_CODES: + raise LanguageNotFoundError(language_code) + + params = prepare_params("google_news", q=keywords, gl=country_code, hl=language_code) + results = call_serpapi(context, params) + return {"news_results": extract_news_results(results, limit=limit)} diff --git a/toolkits/google_news/arcade_google_news/utils.py b/toolkits/google_news/arcade_google_news/utils.py new file mode 100644 index 00000000..a401b7eb --- /dev/null +++ b/toolkits/google_news/arcade_google_news/utils.py @@ -0,0 +1,64 @@ +import re +from typing import Any, cast + +from arcade_tdk import ToolContext +from arcade_tdk.errors import ToolExecutionError +from serpapi import Client as SerpClient + + +def prepare_params(engine: str, **kwargs: Any) -> dict[str, Any]: + """ + Prepares a parameters dictionary for the SerpAPI call. + + Parameters: + engine: The engine name (e.g., "google", "google_finance"). + kwargs: Any additional parameters to include. + + Returns: + A dictionary containing the base parameters plus any extras, + excluding any parameters whose value is None. + """ + params = {"engine": engine} + params.update({k: v for k, v in kwargs.items() if v is not None}) + return params + + +def call_serpapi(context: ToolContext, params: dict) -> dict: + """ + Execute a search query using the SerpAPI client and return the results as a dictionary. + + Args: + context: The tool context containing required secrets. + params: A dictionary of parameters for the SerpAPI search. + + Returns: + The search results as a dictionary. + """ + api_key = context.get_secret("SERP_API_KEY") + client = SerpClient(api_key=api_key) + try: + search = client.search(params) + return cast(dict[str, Any], search.as_dict()) + except Exception as e: + # SerpAPI error messages sometimes contain the API key, so we need to sanitize it + sanitized_e = re.sub(r"(api_key=)[^ &]+", r"\1***", str(e)) + raise ToolExecutionError( + message="Failed to fetch search results", + developer_message=sanitized_e, + ) + + +def extract_news_results(results: dict[str, Any], limit: int | None = None) -> list[dict[str, Any]]: + news_results = [] + for result in results.get("news_results", []): + news_results.append({ + "title": result.get("title"), + "snippet": result.get("snippet"), + "link": result.get("link"), + "date": result.get("date"), + "source": result.get("source", {}).get("name"), + }) + + if limit: + return news_results[:limit] + return news_results diff --git a/toolkits/google_news/pyproject.toml b/toolkits/google_news/pyproject.toml new file mode 100644 index 00000000..36513703 --- /dev/null +++ b/toolkits/google_news/pyproject.toml @@ -0,0 +1,56 @@ +[build-system] +requires = [ "hatchling",] +build-backend = "hatchling.build" + +[project] +name = "arcade_google_news" +version = "2.0.0" +description = "Arcade.dev LLM tools for getting new via Google News" +requires-python = ">=3.10" +dependencies = [ "arcade-tdk>=2.0.0,<3.0.0", "serpapi>=0.1.5,<1.0.0",] +[[project.authors]] +name = "Arcade" +email = "dev@arcade.dev" + + +[project.optional-dependencies] +dev = [ + "arcade-ai[evals]>=2.0.0,<3.0.0", + "arcade-serve>=2.0.0,<3.0.0", + "pytest>=8.3.0,<9.0.0", + "pytest-cov>=4.0.0,<5.0.0", + "pytest-mock>=3.11.1,<4.0.0", + "pytest-asyncio>=0.24.0,<1.0.0", + "mypy>=1.5.1,<2.0.0", + "pre-commit>=3.4.0,<4.0.0", + "tox>=4.11.1,<5.0.0", + "ruff>=0.7.4,<1.0.0", +] + +# Use local path sources for arcade libs when working locally +[tool.uv.sources] +arcade-ai = { path = "../../", editable = true } +arcade-serve = { path = "../../libs/arcade-serve/", editable = true } +arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } + + +[tool.mypy] +files = [ "arcade_google_news/**/*.py",] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes = "True" +ignore_missing_imports = "True" + +[tool.pytest.ini_options] +testpaths = [ "tests",] + +[tool.coverage.report] +skip_empty = true + +[tool.hatch.build.targets.wheel] +packages = [ "arcade_google_news",] diff --git a/toolkits/google_search/.pre-commit-config.yaml b/toolkits/google_search/.pre-commit-config.yaml new file mode 100644 index 00000000..d4c8cef7 --- /dev/null +++ b/toolkits/google_search/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +files: ^.*/google_search/.* +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.4.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/toolkits/google_search/.ruff.toml b/toolkits/google_search/.ruff.toml new file mode 100644 index 00000000..f1aed90f --- /dev/null +++ b/toolkits/google_search/.ruff.toml @@ -0,0 +1,46 @@ +target-version = "py310" +line-length = 100 +fix = true + +[lint] +select = [ + # flake8-2020 + "YTT", + # flake8-bandit + "S", + # flake8-bugbear + "B", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-simplify + "SIM", + # isort + "I", + # mccabe + "C90", + # pycodestyle + "E", "W", + # pyflakes + "F", + # pygrep-hooks + "PGH", + # pyupgrade + "UP", + # ruff + "RUF", + # tryceratops + "TRY", +] + +[lint.per-file-ignores] +"*" = ["TRY003", "B904"] +"**/tests/*" = ["S101", "E501"] +"**/evals/*" = ["S101", "E501"] + +[format] +preview = true +skip-magic-trailing-comma = false diff --git a/toolkits/google_search/LICENSE b/toolkits/google_search/LICENSE new file mode 100644 index 00000000..45f53e20 --- /dev/null +++ b/toolkits/google_search/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Arcade + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/toolkits/google_search/Makefile b/toolkits/google_search/Makefile new file mode 100644 index 00000000..0a8969be --- /dev/null +++ b/toolkits/google_search/Makefile @@ -0,0 +1,55 @@ +.PHONY: help + +help: + @echo "🛠️ github Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the uv environment and install all packages with dependencies + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras --no-sources + @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: install-local +install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras + @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + uv build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + @uv run --no-sources coverage report + @echo "Generating coverage report" + @uv run --no-sources coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file by a patch version + @echo "🚀 Bumping version in pyproject.toml" + uv version --no-sources --bump patch + +.PHONY: check +check: ## Run code quality tools. + @if [ -f .pre-commit-config.yaml ]; then\ + echo "🚀 Linting code: Running pre-commit";\ + uv run --no-sources pre-commit run -a;\ + fi + @echo "🚀 Static type checking: Running mypy" + @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/google_search/arcade_google_search/__init__.py b/toolkits/google_search/arcade_google_search/__init__.py new file mode 100644 index 00000000..42837602 --- /dev/null +++ b/toolkits/google_search/arcade_google_search/__init__.py @@ -0,0 +1,3 @@ +from arcade_google_search.tools import search + +__all__ = ["search"] diff --git a/toolkits/google_search/arcade_google_search/tools/__init__.py b/toolkits/google_search/arcade_google_search/tools/__init__.py new file mode 100644 index 00000000..b8618dc0 --- /dev/null +++ b/toolkits/google_search/arcade_google_search/tools/__init__.py @@ -0,0 +1,3 @@ +from arcade_google_search.tools.google_search import search + +__all__ = ["search"] diff --git a/toolkits/google_search/arcade_google_search/tools/google_search.py b/toolkits/google_search/arcade_google_search/tools/google_search.py new file mode 100644 index 00000000..288e97cc --- /dev/null +++ b/toolkits/google_search/arcade_google_search/tools/google_search.py @@ -0,0 +1,21 @@ +import json +from typing import Annotated + +from arcade_tdk import ToolContext, tool + +from arcade_google_search.utils import call_serpapi, prepare_params + + +@tool(requires_secrets=["SERP_API_KEY"]) +async def search( + context: ToolContext, + query: Annotated[str, "Search query"], + n_results: Annotated[int, "Number of results to retrieve"] = 5, +) -> str: + """Search Google using SerpAPI and return organic search results.""" + + params = prepare_params("google", q=query) + results = call_serpapi(context, params) + organic_results = results.get("organic_results", []) + + return json.dumps(organic_results[:n_results]) diff --git a/toolkits/google_search/arcade_google_search/utils.py b/toolkits/google_search/arcade_google_search/utils.py new file mode 100644 index 00000000..00c0dcba --- /dev/null +++ b/toolkits/google_search/arcade_google_search/utils.py @@ -0,0 +1,48 @@ +import re +from typing import Any, cast + +from arcade_tdk import ToolContext +from arcade_tdk.errors import ToolExecutionError +from serpapi import Client as SerpClient + + +def prepare_params(engine: str, **kwargs: Any) -> dict[str, Any]: + """ + Prepares a parameters dictionary for the SerpAPI call. + + Parameters: + engine: The engine name (e.g., "google", "google_finance"). + kwargs: Any additional parameters to include. + + Returns: + A dictionary containing the base parameters plus any extras, + excluding any parameters whose value is None. + """ + params = {"engine": engine} + params.update({k: v for k, v in kwargs.items() if v is not None}) + return params + + +def call_serpapi(context: ToolContext, params: dict) -> dict: + """ + Execute a search query using the SerpAPI client and return the results as a dictionary. + + Args: + context: The tool context containing required secrets. + params: A dictionary of parameters for the SerpAPI search. + + Returns: + The search results as a dictionary. + """ + api_key = context.get_secret("SERP_API_KEY") + client = SerpClient(api_key=api_key) + try: + search = client.search(params) + return cast(dict[str, Any], search.as_dict()) + except Exception as e: + # SerpAPI error messages sometimes contain the API key, so we need to sanitize it + sanitized_e = re.sub(r"(api_key=)[^ &]+", r"\1***", str(e)) + raise ToolExecutionError( + message="Failed to fetch search results", + developer_message=sanitized_e, + ) diff --git a/toolkits/google_search/evals/eval_google_search.py b/toolkits/google_search/evals/eval_google_search.py new file mode 100644 index 00000000..662db4fb --- /dev/null +++ b/toolkits/google_search/evals/eval_google_search.py @@ -0,0 +1,240 @@ +from arcade_evals import ( + EvalRubric, + EvalSuite, + ExpectedToolCall, + NumericCritic, + SimilarityCritic, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_google_search +from arcade_google_search.tools import search + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.8, + warn_threshold=0.9, +) + +catalog = ToolCatalog() +# Register the Google Search tool +catalog.add_module(arcade_google_search) + + +@tool_eval() +def google_search_eval_suite() -> EvalSuite: + """Create an evaluation suite for the Google Search tool.""" + suite = EvalSuite( + name="Google Search Tool Evaluation", + system_message="You are an AI assistant that can perform web searches using the provided tools.", + catalog=catalog, + rubric=rubric, + ) + + # Simple search query with default results + suite.add_case( + name="Simple search query with default results", + user_message="Search for 'Climate change effects on polar bears' on Google.", + expected_tool_calls=[ + ExpectedToolCall( + func=search, + args={ + "query": "Climate change effects on polar bears", + "n_results": 5, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="query", weight=1.0), + ], + ) + + # Search query with specific number of results + suite.add_case( + name="Search query with specific number of results", + user_message="Find the top 3 articles about quantum computing.", + expected_tool_calls=[ + ExpectedToolCall( + func=search, + args={ + "query": "articles about quantum computing", + "n_results": 3, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="query", weight=0.7), + NumericCritic( + critic_field="n_results", + weight=0.3, + value_range=(1, 100), + ), + ], + ) + + # Search query with 'n' results specified in words + suite.add_case( + name="Search query with 'n' results specified in words", + user_message="Give me five recipes for vegan lasagna.", + expected_tool_calls=[ + ExpectedToolCall( + func=search, + args={ + "query": "recipes for vegan lasagna", + "n_results": 5, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="query", weight=0.7), + NumericCritic( + critic_field="n_results", + weight=0.3, + value_range=(1, 100), + ), + ], + ) + + # Ambiguous number of results + suite.add_case( + name="Ambiguous number of results", + user_message="Find articles about climate change impacts 10.", + expected_tool_calls=[ + ExpectedToolCall( + func=search, + args={ + "query": "articles about climate change impacts 10", + "n_results": 5, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="query", weight=1.0), + ], + ) + + # Search query with multiple instructions + suite.add_case( + name="Search query with multiple instructions", + user_message="Search for the latest news on electric cars, and tell me about Tesla's new model.", + expected_tool_calls=[ + ExpectedToolCall( + func=search, + args={ + "query": "latest news on electric cars", + "n_results": 5, + }, + ), + ExpectedToolCall( + func=search, + args={ + "query": "Tesla's new model", + "n_results": 5, + }, + ), + ], + critics=[ + SimilarityCritic(critic_field="query", weight=1.0), + ], + ) + + # Search with stop words and filler words + suite.add_case( + name="Search with stop words and filler words", + user_message="Could you please search for the best ways to learn French?", + expected_tool_calls=[ + ExpectedToolCall( + func=search, + args={ + "query": "best ways to learn French", + "n_results": 5, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="query", weight=1.0), + ], + ) + + # No clear query given + suite.add_case( + name="No clear query given", + user_message="Find it for me.", + expected_tool_calls=[], + critics=[], + ) + + # Search query with special characters + suite.add_case( + name="Search query with special characters", + user_message="Find me '@OpenAI's latest research papers'", + expected_tool_calls=[ + ExpectedToolCall( + func=search, + args={ + "query": "@OpenAI's latest research papers", + "n_results": 5, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="query", weight=1.0), + ], + ) + + # Search query with complex instructions + suite.add_case( + name="Search query with complex instructions", + user_message="I need information about the impact of deforestation in the Amazon over the past decade.", + expected_tool_calls=[ + ExpectedToolCall( + func=search, + args={ + "query": "impact of deforestation in the Amazon over the past decade", + "n_results": 5, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="query", weight=1.0), + ], + ) + + # Search query in a different language + suite.add_case( + name="Search query in a different language", + user_message="Busca información sobre la economía de España.", + expected_tool_calls=[ + ExpectedToolCall( + func=search, + args={ + "query": "economía de España", + "n_results": 5, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="query", weight=1.0), + ], + ) + + # Search query with numeric data + suite.add_case( + name="Search query with numeric data", + user_message="What was the population of Japan in 2020?", + expected_tool_calls=[ + ExpectedToolCall( + func=search, + args={ + "query": "population of Japan in 2020", + "n_results": 5, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="query", weight=1.0), + ], + ) + + return suite diff --git a/toolkits/google_search/pyproject.toml b/toolkits/google_search/pyproject.toml new file mode 100644 index 00000000..412b3915 --- /dev/null +++ b/toolkits/google_search/pyproject.toml @@ -0,0 +1,59 @@ +[build-system] +requires = [ "hatchling",] +build-backend = "hatchling.build" + +[project] +name = "arcade_google_search" +version = "2.0.0" +description = "Arcade.dev LLM tools for searching via Google" +requires-python = ">=3.10" +dependencies = [ + "arcade-tdk>=2.0.0,<3.0.0", + "serpapi>=0.1.5,<1.0.0", +] +[[project.authors]] +name = "Arcade" +email = "dev@arcade.dev" + + +[project.optional-dependencies] +dev = [ + "arcade-ai[evals]>=2.0.0,<3.0.0", + "arcade-serve>=2.0.0,<3.0.0", + "pytest>=8.3.0,<8.4.0", + "pytest-cov>=4.0.0,<4.1.0", + "pytest-mock>=3.11.1,<3.12.0", + "pytest-asyncio>=0.24.0,<0.25.0", + "mypy>=1.5.1,<1.6.0", + "pre-commit>=3.4.0,<3.5.0", + "tox>=4.11.1,<4.12.0", + "ruff>=0.7.4,<0.8.0", +] + +# Use local path sources for arcade libs when working locally +[tool.uv.sources] +arcade-ai = { path = "../../", editable = true } +arcade-serve = { path = "../../libs/arcade-serve/", editable = true } +arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } + + +[tool.mypy] +files = [ "arcade_google_search/**/*.py",] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes = "True" +ignore_missing_imports = "True" + +[tool.pytest.ini_options] +testpaths = [ "tests",] + +[tool.coverage.report] +skip_empty = true + +[tool.hatch.build.targets.wheel] +packages = [ "arcade_google_search",] diff --git a/toolkits/google_search/tests/__init__.py b/toolkits/google_search/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/toolkits/google_search/tests/test_google_search.py b/toolkits/google_search/tests/test_google_search.py new file mode 100644 index 00000000..69a90c96 --- /dev/null +++ b/toolkits/google_search/tests/test_google_search.py @@ -0,0 +1,49 @@ +import json +from unittest.mock import patch + +import pytest +from arcade_tdk import ToolContext, ToolSecretItem + +from arcade_google_search.tools import search + + +@pytest.fixture +def mock_context(): + return ToolContext(secrets=[ToolSecretItem(key="serp_api_key", value="fake_api_key")]) + + +@pytest.mark.asyncio +async def test_search_google_success(mock_context): + with ( + patch("arcade_google_search.utils.SerpClient") as MockClient, + ): + mock_client_instance = MockClient.return_value + mock_client_instance.search.return_value.as_dict.return_value = { + "organic_results": [ + {"title": "Result 1", "link": "http://example.com/1"}, + {"title": "Result 2", "link": "http://example.com/2"}, + {"title": "Result 3", "link": "http://example.com/3"}, + ] + } + + result = await search(mock_context, "test query", 2) + + expected_result = json.dumps([ + {"title": "Result 1", "link": "http://example.com/1"}, + {"title": "Result 2", "link": "http://example.com/2"}, + ]) + assert result == expected_result + + +@pytest.mark.asyncio +async def test_search_google_no_results(mock_context): + with ( + patch("arcade_google_search.utils.SerpClient") as MockClient, + ): + mock_client_instance = MockClient.return_value + mock_client_instance.search.return_value.as_dict.return_value = {"organic_results": []} + + result = await search(mock_context, "test query", 2) + + expected_result = json.dumps([]) + assert result == expected_result diff --git a/toolkits/google_search/tests/test_utils.py b/toolkits/google_search/tests/test_utils.py new file mode 100644 index 00000000..0bbbd3e9 --- /dev/null +++ b/toolkits/google_search/tests/test_utils.py @@ -0,0 +1,68 @@ +import pytest +import serpapi +from arcade_tdk.errors import ToolExecutionError + +from arcade_google_search.utils import call_serpapi, prepare_params + + +class DummyContext: + def get_secret(self, key: str) -> str | None: + if key.lower() == "serp_api_key": + return "dummy_key" + return None + + +@pytest.fixture +def dummy_context(): + return DummyContext() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "engine, kwargs, expected", + [ + ("google", {}, {"engine": "google"}), + ( + "google", + {"q": "test", "window": 10, "time": "00:12:12"}, + { + "engine": "google", + "q": "test", + "window": 10, + "time": "00:12:12", + }, + ), + ], +) +async def test_prepare_params(engine, kwargs, expected): + params = prepare_params(engine, **kwargs) + assert params == expected + + +@pytest.mark.parametrize( + "error_message, sanitized_message", + [ + ( + "You hit your rate limit", + "You hit your rate limit", + ), + ( + "Bad Request for url: https://serpapi.com/search?engine=google_hotels&api_key=ABC123456", + "Bad Request for url: https://serpapi.com/search?engine=google_hotels&api_key=***", + ), + ( + "Bad Request for url: https://serpapi.com/search?engine=google_hotels&api_key=ABC123456 make sure the api key is correct", + "Bad Request for url: https://serpapi.com/search?engine=google_hotels&api_key=*** make sure the api key is correct", + ), + ], +) +def test_call_serpapi_failure(monkeypatch, dummy_context, error_message, sanitized_message): + def fake_serpapi_search(self, params: dict) -> dict: + raise Exception(error_message) # noqa: TRY002 + + monkeypatch.setattr(serpapi.Client, "search", fake_serpapi_search) + + with pytest.raises(ToolExecutionError) as excinfo: + call_serpapi(dummy_context, {}) + + assert excinfo.value.developer_message == sanitized_message diff --git a/toolkits/google_sheets/.pre-commit-config.yaml b/toolkits/google_sheets/.pre-commit-config.yaml new file mode 100644 index 00000000..4baefff8 --- /dev/null +++ b/toolkits/google_sheets/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +files: ^.*/google_sheets/.* +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.4.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/toolkits/google_sheets/.ruff.toml b/toolkits/google_sheets/.ruff.toml new file mode 100644 index 00000000..f1aed90f --- /dev/null +++ b/toolkits/google_sheets/.ruff.toml @@ -0,0 +1,46 @@ +target-version = "py310" +line-length = 100 +fix = true + +[lint] +select = [ + # flake8-2020 + "YTT", + # flake8-bandit + "S", + # flake8-bugbear + "B", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-simplify + "SIM", + # isort + "I", + # mccabe + "C90", + # pycodestyle + "E", "W", + # pyflakes + "F", + # pygrep-hooks + "PGH", + # pyupgrade + "UP", + # ruff + "RUF", + # tryceratops + "TRY", +] + +[lint.per-file-ignores] +"*" = ["TRY003", "B904"] +"**/tests/*" = ["S101", "E501"] +"**/evals/*" = ["S101", "E501"] + +[format] +preview = true +skip-magic-trailing-comma = false diff --git a/toolkits/google_sheets/LICENSE b/toolkits/google_sheets/LICENSE new file mode 100644 index 00000000..45f53e20 --- /dev/null +++ b/toolkits/google_sheets/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Arcade + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/toolkits/google_sheets/Makefile b/toolkits/google_sheets/Makefile new file mode 100644 index 00000000..0a8969be --- /dev/null +++ b/toolkits/google_sheets/Makefile @@ -0,0 +1,55 @@ +.PHONY: help + +help: + @echo "🛠️ github Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the uv environment and install all packages with dependencies + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras --no-sources + @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: install-local +install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras + @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + uv build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + @uv run --no-sources coverage report + @echo "Generating coverage report" + @uv run --no-sources coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file by a patch version + @echo "🚀 Bumping version in pyproject.toml" + uv version --no-sources --bump patch + +.PHONY: check +check: ## Run code quality tools. + @if [ -f .pre-commit-config.yaml ]; then\ + echo "🚀 Linting code: Running pre-commit";\ + uv run --no-sources pre-commit run -a;\ + fi + @echo "🚀 Static type checking: Running mypy" + @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/google_sheets/arcade_google_sheets/__init__.py b/toolkits/google_sheets/arcade_google_sheets/__init__.py new file mode 100644 index 00000000..97ee1d19 --- /dev/null +++ b/toolkits/google_sheets/arcade_google_sheets/__init__.py @@ -0,0 +1,7 @@ +from arcade_google_sheets.tools import ( + create_spreadsheet, + get_spreadsheet, + write_to_cell, +) + +__all__ = ["create_spreadsheet", "get_spreadsheet", "write_to_cell"] diff --git a/toolkits/google_sheets/arcade_google_sheets/constants.py b/toolkits/google_sheets/arcade_google_sheets/constants.py new file mode 100644 index 00000000..20fd986e --- /dev/null +++ b/toolkits/google_sheets/arcade_google_sheets/constants.py @@ -0,0 +1,2 @@ +DEFAULT_SHEET_ROW_COUNT = 1000 +DEFAULT_SHEET_COLUMN_COUNT = 26 diff --git a/toolkits/google_sheets/arcade_google_sheets/decorators.py b/toolkits/google_sheets/arcade_google_sheets/decorators.py new file mode 100644 index 00000000..0760576c --- /dev/null +++ b/toolkits/google_sheets/arcade_google_sheets/decorators.py @@ -0,0 +1,24 @@ +import functools +from collections.abc import Callable +from typing import Any + +from arcade_tdk import ToolContext +from googleapiclient.errors import HttpError + +from arcade_google_sheets.file_picker import generate_google_file_picker_url + + +def with_filepicker_fallback(func: Callable[..., Any]) -> Callable[..., Any]: + """ """ + + @functools.wraps(func) + async def async_wrapper(context: ToolContext, *args: Any, **kwargs: Any) -> Any: + try: + return await func(context, *args, **kwargs) + except HttpError as e: + if e.status_code in [403, 404]: + file_picker_response = generate_google_file_picker_url(context) + return file_picker_response + raise + + return async_wrapper diff --git a/toolkits/google_sheets/arcade_google_sheets/enums.py b/toolkits/google_sheets/arcade_google_sheets/enums.py new file mode 100644 index 00000000..a836f0b8 --- /dev/null +++ b/toolkits/google_sheets/arcade_google_sheets/enums.py @@ -0,0 +1,25 @@ +from enum import Enum + + +class CellErrorType(str, Enum): + """The type of error in a cell + + Implementation of https://developers.google.com/sheets/api/reference/rest/v4/spreadsheets/other#ErrorType + """ + + ERROR_TYPE_UNSPECIFIED = "ERROR_TYPE_UNSPECIFIED" # The default error type, do not use this. + ERROR = "ERROR" # Corresponds to the #ERROR! error. + NULL_VALUE = "NULL_VALUE" # Corresponds to the #NULL! error. + DIVIDE_BY_ZERO = "DIVIDE_BY_ZERO" # Corresponds to the #DIV/0 error. + VALUE = "VALUE" # Corresponds to the #VALUE! error. + REF = "REF" # Corresponds to the #REF! error. + NAME = "NAME" # Corresponds to the #NAME? error. + NUM = "NUM" # Corresponds to the #NUM! error. + N_A = "N_A" # Corresponds to the #N/A error. + LOADING = "LOADING" # Corresponds to the Loading... state. + + +class NumberFormatType(str, Enum): + NUMBER = "NUMBER" + PERCENT = "PERCENT" + CURRENCY = "CURRENCY" diff --git a/toolkits/google_sheets/arcade_google_sheets/file_picker.py b/toolkits/google_sheets/arcade_google_sheets/file_picker.py new file mode 100644 index 00000000..193690ef --- /dev/null +++ b/toolkits/google_sheets/arcade_google_sheets/file_picker.py @@ -0,0 +1,49 @@ +import base64 +import json + +from arcade_tdk import ToolContext, ToolMetadataKey +from arcade_tdk.errors import ToolExecutionError + + +def generate_google_file_picker_url(context: ToolContext) -> dict: + """Generate a Google File Picker URL for user-driven file selection and authorization. + + Generates a URL that directs the end-user to a Google File Picker interface where + where they can select or upload Google Drive files. Users can grant permission to access their + Drive files, providing a secure and authorized way to interact with their files. + + This is particularly useful when prior tools (e.g., those accessing or modifying + Google Docs, Google Sheets, etc.) encountered failures due to file non-existence + (Requested entity was not found) or permission errors. Once the user completes the file + picker flow, the prior tool can be retried. + + Returns: + A dictionary containing the URL and instructions for the llm to instruct the user. + """ + client_id = context.get_metadata(ToolMetadataKey.CLIENT_ID) + client_id_parts = client_id.split("-") + if not client_id_parts: + raise ToolExecutionError( + message="Invalid Google Client ID", + developer_message=f"Google Client ID '{client_id}' is not valid", + ) + app_id = client_id_parts[0] + cloud_coordinator_url = context.get_metadata(ToolMetadataKey.COORDINATOR_URL).strip("/") + + config = { + "auth": { + "client_id": client_id, + "app_id": app_id, + }, + } + config_json = json.dumps(config) + config_base64 = base64.urlsafe_b64encode(config_json.encode("utf-8")).decode("utf-8") + url = f"{cloud_coordinator_url}/google/drive_picker?config={config_base64}" + + return { + "url": url, + "llm_instructions": ( + "Instruct the user to click the following link to open the Google Drive File Picker. " + f"This will allow them to select files and grant access permissions: {url}" + ), + } diff --git a/toolkits/google_sheets/arcade_google_sheets/models.py b/toolkits/google_sheets/arcade_google_sheets/models.py new file mode 100644 index 00000000..d2ea5566 --- /dev/null +++ b/toolkits/google_sheets/arcade_google_sheets/models.py @@ -0,0 +1,241 @@ +import json +from typing import Optional + +from pydantic import BaseModel, field_validator, model_validator + +from arcade_google_sheets.enums import CellErrorType, NumberFormatType +from arcade_google_sheets.types import CellValue + + +class CellErrorValue(BaseModel): + """An error in a cell + + Implementation of https://developers.google.com/sheets/api/reference/rest/v4/spreadsheets/other#ErrorValue + """ + + type: CellErrorType + message: str + + +class CellExtendedValue(BaseModel): + """The kinds of value that a cell in a spreadsheet can have + + Implementation of https://developers.google.com/sheets/api/reference/rest/v4/spreadsheets/other#ExtendedValue + """ + + numberValue: float | None = None + stringValue: str | None = None + boolValue: bool | None = None + formulaValue: str | None = None + errorValue: Optional["CellErrorValue"] = None + + @model_validator(mode="after") + def check_exactly_one_value(cls, instance): # type: ignore[no-untyped-def] + provided = [v for v in instance.__dict__.values() if v is not None] + if len(provided) != 1: + raise ValueError( + "Exactly one of numberValue, stringValue, boolValue, " + "formulaValue, or errorValue must be set." + ) + return instance + + +class NumberFormat(BaseModel): + """The format of a number + + Implementation of https://developers.google.com/sheets/api/reference/rest/v4/spreadsheets/cells#NumberFormat + """ + + pattern: str + type: NumberFormatType + + +class CellFormat(BaseModel): + """The format of a cell + + Partial implementation of https://developers.google.com/sheets/api/reference/rest/v4/spreadsheets/cells#CellFormat + """ + + numberFormat: NumberFormat + + +class CellData(BaseModel): + """Data about a specific cell + + A partial implementation of https://developers.google.com/sheets/api/reference/rest/v4/spreadsheets/cells#CellData + """ + + userEnteredValue: CellExtendedValue + userEnteredFormat: CellFormat | None = None + + +class RowData(BaseModel): + """Data about each cellin a row + + A partial implementation of https://developers.google.com/sheets/api/reference/rest/v4/spreadsheets/sheets#RowData + """ + + values: list[CellData] + + +class GridData(BaseModel): + """Data in the grid + + A partial implementation of https://developers.google.com/sheets/api/reference/rest/v4/spreadsheets/sheets#GridData + """ + + startRow: int + startColumn: int + rowData: list[RowData] + + +class GridProperties(BaseModel): + """Properties of a grid + + A partial implementation of https://developers.google.com/sheets/api/reference/rest/v4/spreadsheets/sheets#GridProperties + """ + + rowCount: int + columnCount: int + + +class SheetProperties(BaseModel): + """Properties of a Sheet + + A partial implementation of https://developers.google.com/sheets/api/reference/rest/v4/spreadsheets/sheets#SheetProperties + """ + + sheetId: int + title: str + gridProperties: GridProperties | None = None + + +class Sheet(BaseModel): + """A Sheet in a spreadsheet + + A partial implementation of https://developers.google.com/sheets/api/reference/rest/v4/spreadsheets/sheets#Sheet + """ + + properties: SheetProperties + data: list[GridData] | None = None + + +class SpreadsheetProperties(BaseModel): + """Properties of a spreadsheet + + A partial implementation of https://developers.google.com/sheets/api/reference/rest/v4/spreadsheets#SpreadsheetProperties + """ + + title: str + + +class Spreadsheet(BaseModel): + """A spreadsheet + + A partial implementation of https://developers.google.com/sheets/api/reference/rest/v4/spreadsheets + """ + + properties: SpreadsheetProperties + sheets: list[Sheet] + + +class SheetDataInput(BaseModel): + """ + SheetDataInput models the cell data of a spreadsheet in a custom format. + + It is a dictionary mapping row numbers (as ints) to dictionaries that map + column letters (as uppercase strings) to cell values (int, float, str, or bool). + + This model enforces that: + - The outer keys are convertible to int. + - The inner keys are alphabetic strings (normalized to uppercase). + - All cell values are only of type int, float, str, or bool. + + The model automatically serializes (via `json_data()`) + and validates the inner types. + """ + + data: dict[int, dict[str, CellValue]] + + @classmethod + def _parse_json_if_string(cls, value): # type: ignore[no-untyped-def] + """Parses the value if it is a JSON string, otherwise returns it. + + Helper method for when validating the `data` field. + """ + if isinstance(value, str): + try: + return json.loads(value) + except json.JSONDecodeError as e: + raise TypeError(f"Invalid JSON: {e}") + return value + + @classmethod + def _validate_row_key(cls, row_key) -> int: # type: ignore[no-untyped-def] + """Converts the row key to an integer, raising an error if conversion fails. + + Helper method for when validating the `data` field. + """ + try: + return int(row_key) + except (ValueError, TypeError): + raise TypeError(f"Row key '{row_key}' is not convertible to int.") + + @classmethod + def _validate_inner_cells(cls, cells, row_int: int) -> dict: # type: ignore[no-untyped-def] + """Validates that 'cells' is a dict mapping column letters to valid cell values + and normalizes the keys. + + Helper method for when validating the `data` field. + """ + if not isinstance(cells, dict): + raise TypeError( + f"Value for row '{row_int}' must be a dict mapping column letters to cell values." + ) + new_inner = {} + for col_key, cell_value in cells.items(): + if not isinstance(col_key, str): + raise TypeError(f"Column key '{col_key}' must be a string.") + col_string = col_key.upper() + if not col_string.isalpha(): + raise TypeError(f"Column key '{col_key}' is invalid. Must be alphabetic.") + if not isinstance(cell_value, int | float | str | bool): + raise TypeError( + f"Cell value for {col_string}{row_int} must be an int, float, str, or bool." + ) + new_inner[col_string] = cell_value + return new_inner + + @field_validator("data", mode="before") + @classmethod + def validate_and_convert_keys(cls, value): # type: ignore[no-untyped-def] + """ + Validates data when SheetDataInput is instantiated and converts it to the correct format. + Uses private helper methods to parse JSON, validate row keys, and validate inner cell data. + """ + if value is None: + return {} + + value = cls._parse_json_if_string(value) + if isinstance(value, dict): + new_value = {} + for row_key, cells in value.items(): + row_int = cls._validate_row_key(row_key) + inner_cells = cls._validate_inner_cells(cells, row_int) + new_value[row_int] = inner_cells + return new_value + + raise TypeError("data must be a dict or a valid JSON string representing a dict") + + def json_data(self) -> str: + """ + Serialize the sheet data to a JSON string. + """ + return json.dumps(self.data) + + @classmethod + def from_json(cls, json_str: str) -> "SheetDataInput": + """ + Create a SheetData instance from a JSON string. + """ + return cls.model_validate_json(json_str) diff --git a/toolkits/google_sheets/arcade_google_sheets/tools/__init__.py b/toolkits/google_sheets/arcade_google_sheets/tools/__init__.py new file mode 100644 index 00000000..e4158202 --- /dev/null +++ b/toolkits/google_sheets/arcade_google_sheets/tools/__init__.py @@ -0,0 +1,4 @@ +from arcade_google_sheets.tools.read import get_spreadsheet +from arcade_google_sheets.tools.write import create_spreadsheet, write_to_cell + +__all__ = ["create_spreadsheet", "get_spreadsheet", "write_to_cell"] diff --git a/toolkits/google_sheets/arcade_google_sheets/tools/read.py b/toolkits/google_sheets/arcade_google_sheets/tools/read.py new file mode 100644 index 00000000..baf00013 --- /dev/null +++ b/toolkits/google_sheets/arcade_google_sheets/tools/read.py @@ -0,0 +1,42 @@ +from typing import Annotated + +from arcade_tdk import ToolContext, ToolMetadataKey, tool +from arcade_tdk.auth import Google + +from arcade_google_sheets.decorators import with_filepicker_fallback +from arcade_google_sheets.utils import ( + build_sheets_service, + parse_get_spreadsheet_response, +) + + +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/drive.file"], + ), + requires_metadata=[ToolMetadataKey.CLIENT_ID, ToolMetadataKey.COORDINATOR_URL], +) +@with_filepicker_fallback +async def get_spreadsheet( + context: ToolContext, + spreadsheet_id: Annotated[str, "The id of the spreadsheet to get"], +) -> Annotated[ + dict, + "The spreadsheet properties and data for all sheets in the spreadsheet", +]: + """ + Get the user entered values and formatted values for all cells in all sheets in the spreadsheet + along with the spreadsheet's properties + """ + service = build_sheets_service(context.get_auth_token_or_empty()) + + response = ( + service.spreadsheets() + .get( + spreadsheetId=spreadsheet_id, + includeGridData=True, + fields="spreadsheetId,spreadsheetUrl,properties/title,sheets/properties,sheets/data/rowData/values/userEnteredValue,sheets/data/rowData/values/formattedValue,sheets/data/rowData/values/effectiveValue", + ) + .execute() + ) + return parse_get_spreadsheet_response(response) diff --git a/toolkits/google_sheets/arcade_google_sheets/tools/write.py b/toolkits/google_sheets/arcade_google_sheets/tools/write.py new file mode 100644 index 00000000..30179b38 --- /dev/null +++ b/toolkits/google_sheets/arcade_google_sheets/tools/write.py @@ -0,0 +1,114 @@ +from typing import Annotated + +from arcade_tdk import ToolContext, tool +from arcade_tdk.auth import Google +from arcade_tdk.errors import RetryableToolError + +from arcade_google_sheets.models import ( + SheetDataInput, + Spreadsheet, + SpreadsheetProperties, +) +from arcade_google_sheets.utils import ( + build_sheets_service, + create_sheet, + parse_write_to_cell_response, + validate_write_to_cell_params, +) + + +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/drive.file"], + ) +) +def create_spreadsheet( + context: ToolContext, + title: Annotated[str, "The title of the new spreadsheet"] = "Untitled spreadsheet", + data: Annotated[ + str | None, + "The data to write to the spreadsheet. A JSON string " + "(property names enclosed in double quotes) representing a dictionary that " + "maps row numbers to dictionaries that map column letters to cell values. " + "For example, data[23]['C'] would be the value of the cell in row 23, column C. " + "Type hint: dict[int, dict[str, Union[int, float, str, bool]]]", + ] = None, +) -> Annotated[dict, "The created spreadsheet's id and title"]: + """Create a new spreadsheet with the provided title and data in its first sheet + + Returns the newly created spreadsheet's id and title + """ + service = build_sheets_service(context.get_auth_token_or_empty()) + + try: + sheet_data = SheetDataInput(data=data) # type: ignore[arg-type] + except Exception as e: + msg = "Invalid JSON or unexpected data format for parameter `data`" + raise RetryableToolError( + message=msg, + additional_prompt_content=f"{msg}: {e}", + retry_after_ms=100, + ) + + spreadsheet = Spreadsheet( + properties=SpreadsheetProperties(title=title), + sheets=[create_sheet(sheet_data)], + ) + + body = spreadsheet.model_dump() + + response = ( + service.spreadsheets() + .create(body=body, fields="spreadsheetId,spreadsheetUrl,properties/title") + .execute() + ) + + return { + "title": response["properties"]["title"], + "spreadsheetId": response["spreadsheetId"], + "spreadsheetUrl": response["spreadsheetUrl"], + } + + +@tool( + requires_auth=Google( + scopes=["https://www.googleapis.com/auth/drive.file"], + ) +) +def write_to_cell( + context: ToolContext, + spreadsheet_id: Annotated[str, "The id of the spreadsheet to write to"], + column: Annotated[str, "The column string to write to. For example, 'A', 'F', or 'AZ'"], + row: Annotated[int, "The row number to write to"], + value: Annotated[str, "The value to write to the cell"], + sheet_name: Annotated[ + str, "The name of the sheet to write to. Defaults to 'Sheet1'" + ] = "Sheet1", +) -> Annotated[dict, "The status of the operation"]: + """ + Write a value to a single cell in a spreadsheet. + """ + service = build_sheets_service(context.get_auth_token_or_empty()) + validate_write_to_cell_params(service, spreadsheet_id, sheet_name, column, row) + + range_ = f"'{sheet_name}'!{column.upper()}{row}" + body = { + "range": range_, + "majorDimension": "ROWS", + "values": [[value]], + } + + sheet_properties = ( + service.spreadsheets() + .values() + .update( + spreadsheetId=spreadsheet_id, + range=range_, + valueInputOption="USER_ENTERED", + includeValuesInResponse=True, + body=body, + ) + .execute() + ) + + return parse_write_to_cell_response(sheet_properties) diff --git a/toolkits/google_sheets/arcade_google_sheets/types.py b/toolkits/google_sheets/arcade_google_sheets/types.py new file mode 100644 index 00000000..f42a5061 --- /dev/null +++ b/toolkits/google_sheets/arcade_google_sheets/types.py @@ -0,0 +1 @@ +CellValue = int | float | str | bool diff --git a/toolkits/google_sheets/arcade_google_sheets/utils.py b/toolkits/google_sheets/arcade_google_sheets/utils.py new file mode 100644 index 00000000..8495a029 --- /dev/null +++ b/toolkits/google_sheets/arcade_google_sheets/utils.py @@ -0,0 +1,548 @@ +import logging +from typing import Any + +from arcade_tdk.errors import RetryableToolError, ToolExecutionError +from google.oauth2.credentials import Credentials +from googleapiclient.discovery import Resource, build + +from arcade_google_sheets.constants import ( + DEFAULT_SHEET_COLUMN_COUNT, + DEFAULT_SHEET_ROW_COUNT, +) +from arcade_google_sheets.enums import NumberFormatType +from arcade_google_sheets.models import ( + CellData, + CellExtendedValue, + CellFormat, + GridData, + GridProperties, + NumberFormat, + RowData, + Sheet, + SheetDataInput, + SheetProperties, +) +from arcade_google_sheets.types import CellValue + +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) + +logger = logging.getLogger(__name__) + + +def build_sheets_service(auth_token: str | None) -> Resource: # type: ignore[no-any-unimported] + """ + Build a Sheets service object. + """ + auth_token = auth_token or "" + return build("sheets", "v4", credentials=Credentials(auth_token)) + + +def col_to_index(col: str) -> int: + """Convert a sheet's column string to a 0-indexed column index + + Args: + col (str): The column string to convert. e.g., "A", "AZ", "QED" + + Returns: + int: The 0-indexed column index. + """ + result = 0 + for char in col.upper(): + result = result * 26 + (ord(char) - ord("A") + 1) + return result - 1 + + +def index_to_col(index: int) -> str: + """Convert a 0-indexed column index to its corresponding column string + + Args: + index (int): The 0-indexed column index to convert. + + Returns: + str: The column string. e.g., "A", "AZ", "QED" + """ + result = "" + index += 1 + while index > 0: + index, rem = divmod(index - 1, 26) + result = chr(rem + ord("A")) + result + return result + + +def is_col_greater(col1: str, col2: str) -> bool: + """Determine if col1 represents a column that comes after col2 in a sheet + + This comparison is based on: + 1. The length of the column string (longer means greater). + 2. Lexicographical comparison if both strings are the same length. + + Args: + col1 (str): The first column string to compare. + col2 (str): The second column string to compare. + + Returns: + bool: True if col1 comes after col2, False otherwise. + """ + if len(col1) != len(col2): + return len(col1) > len(col2) + return col1.upper() > col2.upper() + + +def compute_sheet_data_dimensions( + sheet_data_input: SheetDataInput, +) -> tuple[tuple[int, int], tuple[int, int]]: + """ + Compute the dimensions of a sheet based on the data provided. + + Args: + sheet_data_input (SheetDataInput): + The data to compute the dimensions of. + + Returns: + tuple[tuple[int, int], tuple[int, int]]: The dimensions of the sheet. The first tuple + contains the row range (start, end) and the second tuple contains the column range + (start, end). + """ + max_row = 0 + min_row = 10_000_000 # max number of cells in a sheet + max_col_str = None + min_col_str = None + + for key, row in sheet_data_input.data.items(): + try: + row_num = int(key) + except ValueError: + continue + if row_num > max_row: + max_row = row_num + if row_num < min_row: + min_row = row_num + + if isinstance(row, dict): + for col in row: + # Update max column string + if max_col_str is None or is_col_greater(col, max_col_str): + max_col_str = col + # Update min column string + if min_col_str is None or is_col_greater(min_col_str, col): + min_col_str = col + + max_col_index = col_to_index(max_col_str) if max_col_str is not None else -1 + min_col_index = col_to_index(min_col_str) if min_col_str is not None else 0 + + return (min_row, max_row), (min_col_index, max_col_index) + + +def create_sheet(sheet_data_input: SheetDataInput) -> Sheet: + """Create a Google Sheet from a dictionary of data. + + Args: + sheet_data_input (SheetDataInput): The data to create the sheet from. + + Returns: + Sheet: The created sheet. + """ + (_, max_row), (min_col_index, max_col_index) = compute_sheet_data_dimensions(sheet_data_input) + sheet_data = create_sheet_data(sheet_data_input, min_col_index, max_col_index) + sheet_properties = create_sheet_properties( + row_count=max(DEFAULT_SHEET_ROW_COUNT, max_row), + column_count=max(DEFAULT_SHEET_COLUMN_COUNT, max_col_index + 1), + ) + + return Sheet(properties=sheet_properties, data=sheet_data) + + +def create_sheet_properties( + sheet_id: int = 1, + title: str = "Sheet1", + row_count: int = DEFAULT_SHEET_ROW_COUNT, + column_count: int = DEFAULT_SHEET_COLUMN_COUNT, +) -> SheetProperties: + """Create a SheetProperties object + + Args: + sheet_id (int): The ID of the sheet. + title (str): The title of the sheet. + row_count (int): The number of rows in the sheet. + column_count (int): The number of columns in the sheet. + + Returns: + SheetProperties: The created sheet properties object. + """ + return SheetProperties( + sheetId=sheet_id, + title=title, + gridProperties=GridProperties(rowCount=row_count, columnCount=column_count), + ) + + +def group_contiguous_rows(row_numbers: list[int]) -> list[list[int]]: + """Groups a sorted list of row numbers into contiguous groups + + A contiguous group is a list of row numbers that are consecutive integers. + For example, [1,2,3,5,6] is converted to [[1,2,3],[5,6]]. + + Args: + row_numbers (list[int]): The list of row numbers to group. + + Returns: + list[list[int]]: The grouped row numbers. + """ + if not row_numbers: + return [] + groups = [] + current_group = [row_numbers[0]] + for r in row_numbers[1:]: + if r == current_group[-1] + 1: + current_group.append(r) + else: + groups.append(current_group) + current_group = [r] + groups.append(current_group) + return groups + + +def create_cell_data(cell_value: CellValue) -> CellData: + """ + Create a CellData object based on the type of cell_value. + """ + if isinstance(cell_value, bool): + return _create_bool_cell(cell_value) + elif isinstance(cell_value, int): + return _create_int_cell(cell_value) + elif isinstance(cell_value, float): + return _create_float_cell(cell_value) + elif isinstance(cell_value, str): + return _create_string_cell(cell_value) + + +def _create_formula_cell(cell_value: str) -> CellData: + cell_val = CellExtendedValue(formulaValue=cell_value) + return CellData(userEnteredValue=cell_val) + + +def _create_currency_cell(cell_value: str) -> CellData: + value_without_symbol = cell_value[1:] + try: + num_value = int(value_without_symbol) + cell_format = CellFormat( + numberFormat=NumberFormat(type=NumberFormatType.CURRENCY, pattern="$#,##0") + ) + cell_val = CellExtendedValue(numberValue=num_value) + return CellData(userEnteredValue=cell_val, userEnteredFormat=cell_format) + except ValueError: + try: + num_value = float(value_without_symbol) # type: ignore[assignment] + cell_format = CellFormat( + numberFormat=NumberFormat(type=NumberFormatType.CURRENCY, pattern="$#,##0.00") + ) + cell_val = CellExtendedValue(numberValue=num_value) + return CellData(userEnteredValue=cell_val, userEnteredFormat=cell_format) + except ValueError: + return CellData(userEnteredValue=CellExtendedValue(stringValue=cell_value)) + + +def _create_percent_cell(cell_value: str) -> CellData: + try: + num_value = float(cell_value[:-1].strip()) + cell_format = CellFormat( + numberFormat=NumberFormat(type=NumberFormatType.PERCENT, pattern="0.00%") + ) + cell_val = CellExtendedValue(numberValue=num_value) + return CellData(userEnteredValue=cell_val, userEnteredFormat=cell_format) + except ValueError: + return CellData(userEnteredValue=CellExtendedValue(stringValue=cell_value)) + + +def _create_bool_cell(cell_value: bool) -> CellData: + return CellData(userEnteredValue=CellExtendedValue(boolValue=cell_value)) + + +def _create_int_cell(cell_value: int) -> CellData: + cell_format = CellFormat( + numberFormat=NumberFormat(type=NumberFormatType.NUMBER, pattern="#,##0") + ) + return CellData( + userEnteredValue=CellExtendedValue(numberValue=cell_value), userEnteredFormat=cell_format + ) + + +def _create_float_cell(cell_value: float) -> CellData: + cell_format = CellFormat( + numberFormat=NumberFormat(type=NumberFormatType.NUMBER, pattern="#,##0.00") + ) + return CellData( + userEnteredValue=CellExtendedValue(numberValue=cell_value), userEnteredFormat=cell_format + ) + + +def _create_string_cell(cell_value: str) -> CellData: + if cell_value.startswith("="): + return _create_formula_cell(cell_value) + elif cell_value.startswith("$") and len(cell_value) > 1: + return _create_currency_cell(cell_value) + elif cell_value.endswith("%") and len(cell_value) > 1: + return _create_percent_cell(cell_value) + + return CellData(userEnteredValue=CellExtendedValue(stringValue=cell_value)) + + +def create_row_data( + row_data: dict[str, CellValue], min_col_index: int, max_col_index: int +) -> RowData: + """Constructs RowData for a single row using the provided row_data. + + Args: + row_data (dict[str, CellValue]): The data to create the row from. + min_col_index (int): The minimum column index from the SheetDataInput. + max_col_index (int): The maximum column index from the SheetDataInput. + """ + row_cells = [] + for col_idx in range(min_col_index, max_col_index + 1): + col_letter = index_to_col(col_idx) + if col_letter in row_data: + cell_data = create_cell_data(row_data[col_letter]) + else: + cell_data = CellData(userEnteredValue=CellExtendedValue(stringValue="")) + row_cells.append(cell_data) + return RowData(values=row_cells) + + +def create_sheet_data( + sheet_data_input: SheetDataInput, + min_col_index: int, + max_col_index: int, +) -> list[GridData]: + """Create grid data from SheetDataInput by grouping contiguous rows and processing cells. + + Args: + sheet_data_input (SheetDataInput): The data to create the sheet from. + min_col_index (int): The minimum column index from the SheetDataInput. + max_col_index (int): The maximum column index from the SheetDataInput. + + Returns: + list[GridData]: The created grid data. + """ + row_numbers = list(sheet_data_input.data.keys()) + if not row_numbers: + return [] + + sorted_rows = sorted(row_numbers) + groups = group_contiguous_rows(sorted_rows) + + sheet_data = [] + for group in groups: + rows_data = [] + for r in group: + current_row_data = sheet_data_input.data.get(r, {}) + row = create_row_data(current_row_data, min_col_index, max_col_index) + rows_data.append(row) + grid_data = GridData( + startRow=group[0] - 1, # convert to 0-indexed + startColumn=min_col_index, + rowData=rows_data, + ) + sheet_data.append(grid_data) + + return sheet_data + + +def parse_get_spreadsheet_response(api_response: dict) -> dict: + """ + Parse the get spreadsheet Google Sheets API response into a structured dictionary. + """ + properties = api_response.get("properties", {}) + sheets = [parse_sheet(sheet) for sheet in api_response.get("sheets", [])] + + return { + "title": properties.get("title", ""), + "spreadsheetId": api_response.get("spreadsheetId", ""), + "spreadsheetUrl": api_response.get("spreadsheetUrl", ""), + "sheets": sheets, + } + + +def parse_sheet(api_sheet: dict) -> dict: + """ + Parse an individual sheet's data from the Google Sheets 'get spreadsheet' + API response into a structured dictionary. + """ + props = api_sheet.get("properties", {}) + grid_props = props.get("gridProperties", {}) + cell_data = convert_api_grid_data_to_dict(api_sheet.get("data", [])) + + return { + "sheetId": props.get("sheetId"), + "title": props.get("title", ""), + "rowCount": grid_props.get("rowCount", 0), + "columnCount": grid_props.get("columnCount", 0), + "data": cell_data, + } + + +def extract_user_entered_cell_value(cell: dict) -> Any: + """ + Extract the user entered value from a cell's 'userEnteredValue'. + + Args: + cell (dict): A cell dictionary from the grid data. + + Returns: + The extracted value if present, otherwise None. + """ + user_val = cell.get("userEnteredValue", {}) + for key in ["stringValue", "numberValue", "boolValue", "formulaValue"]: + if key in user_val: + return user_val[key] + + return "" + + +def process_row(row: dict, start_column_index: int) -> dict: + """ + Process a single row from grid data, converting non-empty cells into a dictionary + that maps column letters to cell values. + + Args: + row (dict): A row from the grid data. + start_column_index (int): The starting column index for this row. + + Returns: + dict: A mapping of column letters to cell values for non-empty cells. + """ + row_result = {} + for j, cell in enumerate(row.get("values", [])): + column_index = start_column_index + j + column_string = index_to_col(column_index) + user_entered_cell_value = extract_user_entered_cell_value(cell) + formatted_cell_value = cell.get("formattedValue", "") + + if user_entered_cell_value != "" or formatted_cell_value != "": + row_result[column_string] = { + "userEnteredValue": user_entered_cell_value, + "formattedValue": formatted_cell_value, + } + + return row_result + + +def convert_api_grid_data_to_dict(grids: list[dict]) -> dict: + """ + Convert a list of grid data dictionaries from the 'get spreadsheet' API + response into a structured cell dictionary. + + The returned dictionary maps row numbers to sub-dictionaries that map column letters + (e.g., 'A', 'B', etc.) to their corresponding non-empty cell values. + + Args: + grids (list[dict]): The list of grid data dictionaries from the API. + + Returns: + dict: A dictionary mapping row numbers to dictionaries of column letter/value pairs. + Only includes non-empty rows and non-empty cells. + """ + result = {} + for grid in grids: + start_row = grid.get("startRow", 0) + start_column = grid.get("startColumn", 0) + + for i, row in enumerate(grid.get("rowData", []), start=1): + current_row = start_row + i + row_data = process_row(row, start_column) + + if row_data: + result[current_row] = row_data + + return dict(sorted(result.items())) + + +def validate_write_to_cell_params( # type: ignore[no-any-unimported] + service: Resource, + spreadsheet_id: str, + sheet_name: str, + column: str, + row: int, +) -> None: + """Validates the input parameters for the write to cell tool. + + Args: + service (Resource): The Google Sheets service. + spreadsheet_id (str): The ID of the spreadsheet provided to the tool. + sheet_name (str): The name of the sheet provided to the tool. + column (str): The column to write to provided to the tool. + row (int): The row to write to provided to the tool. + + Raises: + RetryableToolError: + If the sheet name is not found in the spreadsheet + ToolExecutionError: + If the column is not alphabetical + If the row is not a positive number + If the row is out of bounds for the sheet + If the column is out of bounds for the sheet + """ + if not column.isalpha(): + raise ToolExecutionError( + message=( + f"Invalid column name {column}. " + "It must be a non-empty string containing only letters" + ), + ) + + if row < 1: + raise ToolExecutionError( + message=(f"Invalid row number {row}. It must be a positive integer greater than 0."), + ) + + sheet_properties = ( + service.spreadsheets() + .get( + spreadsheetId=spreadsheet_id, + includeGridData=True, + fields="sheets/properties/title,sheets/properties/gridProperties/rowCount,sheets/properties/gridProperties/columnCount", + ) + .execute() + ) + sheet_names = [sheet["properties"]["title"] for sheet in sheet_properties["sheets"]] + sheet_row_count = sheet_properties["sheets"][0]["properties"]["gridProperties"]["rowCount"] + sheet_column_count = sheet_properties["sheets"][0]["properties"]["gridProperties"][ + "columnCount" + ] + + if sheet_name not in sheet_names: + raise RetryableToolError( + message=f"Sheet name {sheet_name} not found in spreadsheet with id {spreadsheet_id}", + additional_prompt_content=f"Sheet names in the spreadsheet: {sheet_names}", + retry_after_ms=100, + ) + + if row > sheet_row_count: + raise ToolExecutionError( + message=( + f"Row {row} is out of bounds for sheet {sheet_name} " + f"in spreadsheet with id {spreadsheet_id}. " + f"Sheet only has {sheet_row_count} rows which is less than the requested row {row}" + ) + ) + + if col_to_index(column) > sheet_column_count: + raise ToolExecutionError( + message=( + f"Column {column} is out of bounds for sheet {sheet_name} " + f"in spreadsheet with id {spreadsheet_id}. " + f"Sheet only has {sheet_column_count} columns which " + f"is less than the requested column {column}" + ) + ) + + +def parse_write_to_cell_response(response: dict) -> dict: + return { + "spreadsheetId": response["spreadsheetId"], + "sheetTitle": response["updatedData"]["range"].split("!")[0], + "updatedCell": response["updatedData"]["range"].split("!")[1], + "value": response["updatedData"]["values"][0][0], + } diff --git a/toolkits/google_sheets/evals/eval_google_sheets.py b/toolkits/google_sheets/evals/eval_google_sheets.py new file mode 100644 index 00000000..5312a0b6 --- /dev/null +++ b/toolkits/google_sheets/evals/eval_google_sheets.py @@ -0,0 +1,169 @@ +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + SimilarityCritic, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_google_sheets +from arcade_google_sheets.tools import ( + create_spreadsheet, + get_spreadsheet, +) + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.9, + warn_threshold=0.95, +) + +catalog = ToolCatalog() +catalog.add_module(arcade_google_sheets) + +sheet_content_prompt = """name age email score gender city country registration_date +John Doe 28 johndoe@example.com 85 Male New York USA 2023-01-15 +Jane Smith 34 janesmith@example.com 92 Female Los Angeles USA 2023-02-20 +Alice Johnson 22 alicej@example.com 78 Female Chicago USA 2023-03-10 +Bob Brown 45 bobbrown@example.com 88 Male Houston USA 2023-04-05 +Charlie Davis 30 charlied@example.com 95 Male Phoenix USA 2023-05-12 +Eve White 27 evewhite@example.com 82 Female Philadelphia USA 2023-06-18 +Frank Black 40 frankb@example.com 90 Male San Antonio USA 2023-07-25 +Grace Green 29 graceg@example.com 76 Female Dallas USA 2023-08-30 +Hank Blue 35 hankb@example.com 89 Male San Diego USA 2023-09-15 +Ivy Red 31 ivyred@example.com 91 Female San Jose USA 2023-10-01 +Michael Grey 33 michaelg@example.com 87 Male Seattle USA 2023-10-05 +Nina Black 26 ninab@example.com 84 Female Miami USA 2023-10-10 +Oscar White 38 oscarw@example.com 90 Male Atlanta USA 2023-10-15 +Paula Green 32 paulag@example.com 93 Female Boston USA 2023-10-20 +Quentin Brown 29 quentinb@example.com 81 Male Denver USA 2023-10-25 +Rachel Blue 24 rachelb@example.com 79 Female Orlando USA 2023-10-30 +Steve Red 36 stever@example.com 88 Male Las Vegas USA 2023-11-01 +Tina Yellow 30 tinay@example.com 85 Female Portland USA 2023-11-05 +Ursula Pink 27 ursulap@example.com 82 Female San Francisco USA 2023-11-10 +Victor Grey 41 victorg@example.com 91 Male Charlotte USA 2023-11-15 +Wendy Black 34 wendyb@example.com 89 Female Detroit USA 2023-11-20 +Xander White 29 xanderw@example.com 86 Male Indianapolis USA 2023-11-25 +Yvonne Green 25 yvonnag@example.com 83 Female Columbus USA 2023-11-30 +Zachary Blue 37 zacharyb@example.com 90 Male Jacksonville USA 2023-12-01 +Alice Brown 28 aliceb@example.com 80 Female Memphis USA 2023-12-05 +Brian Black 39 brianb@example.com 92 Male Nashville USA 2023-12-10 +Cathy Green 31 cathyg@example.com 84 Female Virginia Beach USA 2023-12-15 +Daniel White 30 danielw@example.com 88 Male Atlanta USA 2023-12-20 +Eva Red 26 evar@example.com 81 Female New Orleans USA 2023-12-25 +Frankie Grey 35 frankieg@example.com 90 Male San Antonio USA 2023-12-30 +Gina Blue 29 ginab@example.com 87 Female San Diego USA 2024-01-01 +Henry Black 42 henryb@example.com 93 Male Philadelphia USA 2024-01-05 +Isla Green 24 islag@example.com 79 Female Chicago USA 2024-01-10 +Jack White 33 jackw@example.com 85 Male Los Angeles USA 2024-01-15 +Kathy Red 31 kathyr@example.com 82 Female Miami USA 2024-01-20 +Liam Grey 36 liamg@example.com 89 Male Seattle USA 2024-01-25 +Mia Black 27 miab@example.com 80 Female Denver USA 2024-01-30 +Nate Green 30 nateg@example.com 88 Male Orlando USA 2024-02-01 +- (empty row) +- (empty row) +- (empty row) +100, 300, 234, 399, 5039, 2345, 23526, 123, 54, 234, 54, 23, 12, 57, 1324, (the formula for sum of everything to the left) +- (empty row) +- (empty row) +- (empty row) +- (empty row) +- (empty row) +- (empty row) +- (empty row) +- (empty row) +- (empty row) +- (empty row) +- (empty row) +- (empty row) +- (empty row) +- (empty row) +- (empty row) +- (empty row) +456, 234, 234, 399, 234, 1234, 23526, 123, 54, 234, 4567, 23, 12, 234, 1324, (the formula for sum of everything to the left) +""" + + +@tool_eval() +def create_spreadsheet_eval() -> EvalSuite: + """Create an evaluation suite for Google Sheets create_spreadsheet tool.""" + + sheet_content_expected1 = """{"1": {"A": "name", "B": "age", "C": "email", "D": "score", "E": "gender", "F": "city", "G": "country", "H": "registration_date"}, "2": {"A": "John Doe", "B": 28, "C": "johndoe@example.com", "D": 85, "E": "Male", "F": "New York", "G": "USA", "H": "2023-01-15"}, "3": {"A": "Jane Smith", "B": 34, "C": "janesmith@example.com", "D": 92, "E": "Female", "F": "Los Angeles", "G": "USA", "H": "2023-02-20"}, "4": {"A": "Alice Johnson", "B": 22, "C": "alicej@example.com", "D": 78, "E": "Female", "F": "Chicago", "G": "USA", "H": "2023-03-10"}, "5": {"A": "Bob Brown", "B": 45, "C": "bobbrown@example.com", "D": 88, "E": "Male", "F": "Houston", "G": "USA", "H": "2023-04-05"}, "6": {"A": "Charlie Davis", "B": 30, "C": "charlied@example.com", "D": 95, "E": "Male", "F": "Phoenix", "G": "USA", "H": "2023-05-12"}, "7": {"A": "Eve White", "B": 27, "C": "evewhite@example.com", "D": 82, "E": "Female", "F": "Philadelphia", "G": "USA", "H": "2023-06-18"}, "8": {"A": "Frank Black", "B": 40, "C": "frankb@example.com", "D": 90, "E": "Male", "F": "San Antonio", "G": "USA", "H": "2023-07-25"}, "9": {"A": "Grace Green", "B": 29, "C": "graceg@example.com", "D": 76, "E": "Female", "F": "Dallas", "G": "USA", "H": "2023-08-30"}, "10": {"A": "Hank Blue", "B": 35, "C": "hankb@example.com", "D": 89, "E": "Male", "F": "San Diego", "G": "USA", "H": "2023-09-15"}, "11": {"A": "Ivy Red", "B": 31, "C": "ivyred@example.com", "D": 91, "E": "Female", "F": "San Jose", "G": "USA", "H": "2023-10-01"}, "12": {"A": "Michael Grey", "B": 33, "C": "michaelg@example.com", "D": 87, "E": "Male", "F": "Seattle", "G": "USA", "H": "2023-10-05"}, "13": {"A": "Nina Black", "B": 26, "C": "ninab@example.com", "D": 84, "E": "Female", "F": "Miami", "G": "USA", "H": "2023-10-10"}, "14": {"A": "Oscar White", "B": 38, "C": "oscarw@example.com", "D": 90, "E": "Male", "F": "Atlanta", "G": "USA", "H": "2023-10-15"}, "15": {"A": "Paula Green", "B": 32, "C": "paulag@example.com", "D": 93, "E": "Female", "F": "Boston", "G": "USA", "H": "2023-10-20"}, "16": {"A": "Quentin Brown", "B": 29, "C": "quentinb@example.com", "D": 81, "E": "Male", "F": "Denver", "G": "USA", "H": "2023-10-25"}, "17": {"A": "Rachel Blue", "B": 24, "C": "rachelb@example.com", "D": 79, "E": "Female", "F": "Orlando", "G": "USA", "H": "2023-10-30"}, "18": {"A": "Steve Red", "B": 36, "C": "stever@example.com", "D": 88, "E": "Male", "F": "Las Vegas", "G": "USA", "H": "2023-11-01"}, "19": {"A": "Tina Yellow", "B": 30, "C": "tinay@example.com", "D": 85, "E": "Female", "F": "Portland", "G": "USA", "H": "2023-11-05"}, "20": {"A": "Ursula Pink", "B": 27, "C": "ursulap@example.com", "D": 82, "E": "Female", "F": "San Francisco", "G": "USA", "H": "2023-11-10"}, "21": {"A": "Victor Grey", "B": 41, "C": "victorg@example.com", "D": 91, "E": "Male", "F": "Charlotte", "G": "USA", "H": "2023-11-15"}, "22": {"A": "Wendy Black", "B": 34, "C": "wendyb@example.com", "D": 89, "E": "Female", "F": "Detroit", "G": "USA", "H": "2023-11-20"}, "23": {"A": "Xander White", "B": 29, "C": "xanderw@example.com", "D": 86, "E": "Male", "F": "Indianapolis", "G": "USA", "H": "2023-11-25"}, "24": {"A": "Yvonne Green", "B": 25, "C": "yvonnag@example.com", "D": 83, "E": "Female", "F": "Columbus", "G": "USA", "H": "2023-11-30"}, "25": {"A": "Zachary Blue", "B": 37, "C": "zacharyb@example.com", "D": 90, "E": "Male", "F": "Jacksonville", "G": "USA", "H": "2023-12-01"}, "26": {"A": "Alice Brown", "B": 28, "C": "aliceb@example.com", "D": 80, "E": "Female", "F": "Memphis", "G": "USA", "H": "2023-12-05"}, "27": {"A": "Brian Black", "B": 39, "C": "brianb@example.com", "D": 92, "E": "Male", "F": "Nashville", "G": "USA", "H": "2023-12-10"}, "28": {"A": "Cathy Green", "B": 31, "C": "cathyg@example.com", "D": 84, "E": "Female", "F": "Virginia Beach", "G": "USA", "H": "2023-12-15"}, "29": {"A": "Daniel White", "B": 30, "C": "danielw@example.com", "D": 88, "E": "Male", "F": "Atlanta", "G": "USA", "H": "2023-12-20"}, "30": {"A": "Eva Red", "B": 26, "C": "evar@example.com", "D": 81, "E": "Female", "F": "New Orleans", "G": "USA", "H": "2023-12-25"}, "31": {"A": "Frankie Grey", "B": 35, "C": "frankieg@example.com", "D": 90, "E": "Male", "F": "San Antonio", "G": "USA", "H": "2023-12-30"}, "32": {"A": "Gina Blue", "B": 29, "C": "ginab@example.com", "D": 87, "E": "Female", "F": "San Diego", "G": "USA", "H": "2024-01-01"}, "33": {"A": "Henry Black", "B": 42, "C": "henryb@example.com", "D": 93, "E": "Male", "F": "Philadelphia", "G": "USA", "H": "2024-01-05"}, "34": {"A": "Isla Green", "B": 24, "C": "islag@example.com", "D": 79, "E": "Female", "F": "Chicago", "G": "USA", "H": "2024-01-10"}, "35": {"A": "Jack White", "B": 33, "C": "jackw@example.com", "D": 85, "E": "Male", "F": "Los Angeles", "G": "USA", "H": "2024-01-15"}, "36": {"A": "Kathy Red", "B": 31, "C": "kathyr@example.com", "D": 82, "E": "Female", "F": "Miami", "G": "USA", "H": "2024-01-20"}, "37": {"A": "Liam Grey", "B": 36, "C": "liamg@example.com", "D": 89, "E": "Male", "F": "Seattle", "G": "USA", "H": "2024-01-25"}, "38": {"A": "Mia Black", "B": 27, "C": "miab@example.com", "D": 80, "E": "Female", "F": "Denver", "G": "USA", "H": "2024-01-30"}, "39": {"A": "Nate Green", "B": 30, "C": "nateg@example.com", "D": 88, "E": "Male", "F": "Orlando", "G": "USA", "H": "2024-02-01"}, "40": {}, "41": {}, "42": {}, "43": {"A": 100, "B": 300, "C": 234, "D": 399, "E": 5039, "F": 2345, "G": 23526, "H": 123, "I": 54, "J": 234, "K": 54, "L": 23, "M": 12, "N": 57, "O": 1324, "P": "(the formula for sum of everything to the left)"}, "44": {}, "45": {}, "46": {}, "47": {}, "48": {}, "49": {}, "50": {}, "51": {}, "52": {}, "53": {}, "54": {}, "55": {}, "56": {}, "57": {}, "58": {}, "59": {}, "60": {"A": 456, "B": 234, "C": 234, "D": 399, "E": 234, "F": 1234, "G": 23526, "H": 123, "I": 54, "J": 234, "K": 4567, "L": 899, "M": 12, "N": 234, "O": 45, "P": "(the formula for sum of everything to the left)"}}""" + sheet_content_sparse_expected = """{"1": {"AA": "=SUM(A1,A2,A3)", "3782": {"A": 3783, "D": 3784, "AAZ": 3785, "ZZFS": 3786, "CA": 3787}}}""" + + suite = EvalSuite( + name="Google Sheets Tools Evaluation", + system_message="You are an AI assistant that can manage Google Sheets using the provided tools.", + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="Create a spreadsheet from large data payload", + user_message=f"Create a spreadsheet named 'Data' with the following content:\n{sheet_content_prompt}", + expected_tool_calls=[ + ExpectedToolCall( + func=create_spreadsheet, + args={ + "title": "Data", + "data": sheet_content_expected1, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="title", weight=0.1), + SimilarityCritic(critic_field="data", weight=0.9, similarity_threshold=0.99), + ], + ) + + suite.add_case( + name="Create a spreadsheet from sparse data payload", + user_message="Create a spreadsheet named 'Sparse Data' that fills the 27th column in the first row with the formula that sums A1, A2, and A3 cells. The 3782nd row should have its A, D, AAZ, ZZFS, and CA columns filled with the numbers 1, 2, 3, 4, and 5, respectively, summed with its row number.", + expected_tool_calls=[ + ExpectedToolCall( + func=create_spreadsheet, + args={ + "title": "Sparse Data", + "data": sheet_content_sparse_expected, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="title", weight=0.1), + SimilarityCritic(critic_field="data", weight=0.9, similarity_threshold=0.95), + ], + ) + + return suite + + +@tool_eval() +def get_spreadsheet_eval() -> EvalSuite: + """Create an evaluation suite for Google Sheets get_spreadsheet tool.""" + + suite = EvalSuite( + name="Google Sheets Tools Evaluation", + system_message="You are an AI assistant that can manage Google Sheets using the provided tools.", + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="Get a spreadsheet", + user_message="Get the data in the second sheet of the spreadsheet with the following id 1L2ovCUcRNOacoWxtLV3jgaidWZq4Bw_WXbIWJcxobN0", + expected_tool_calls=[ + ExpectedToolCall( + func=get_spreadsheet, + args={ + "spreadsheet_id": "1L2ovCUcRNOacoWxtLV3jgaidWZq4Bw_WXbIWJcxobN0", + }, + ) + ], + critics=[ + BinaryCritic(critic_field="spreadsheet_id", weight=1.0), + ], + ) + + return suite diff --git a/toolkits/google_sheets/pyproject.toml b/toolkits/google_sheets/pyproject.toml new file mode 100644 index 00000000..bf1a6515 --- /dev/null +++ b/toolkits/google_sheets/pyproject.toml @@ -0,0 +1,63 @@ +[build-system] +requires = [ "hatchling",] +build-backend = "hatchling.build" + +[project] +name = "arcade_google_sheets" +version = "2.0.0" +description = "Arcade.dev LLM tools for Google Sheets" +requires-python = ">=3.10" +dependencies = [ + "arcade-tdk>=2.0.0,<3.0.0", + "google-api-python-client>=2.137.0,<3.0.0", + "google-api-core>=2.19.1,<3.0.0", + "google-auth>=2.32.0,<3.0.0", + "google-auth-httplib2>=0.2.0,<1.0.0", + "googleapis-common-protos>=1.63.2,<2.0.0", +] +[[project.authors]] +name = "Arcade" +email = "dev@arcade.dev" + + +[project.optional-dependencies] +dev = [ + "arcade-ai[evals]>=2.0.0rc1,<3.0.0", + "arcade-serve>=2.0.0,<3.0.0", + "pytest>=8.3.0,<8.4.0", + "pytest-cov>=4.0.0,<4.1.0", + "pytest-mock>=3.11.1,<3.12.0", + "pytest-asyncio>=0.24.0,<0.25.0", + "mypy>=1.5.1,<1.6.0", + "pre-commit>=3.4.0,<3.5.0", + "tox>=4.11.1,<4.12.0", + "ruff>=0.7.4,<0.8.0", +] + +# Use local path sources for arcade libs when working locally +[tool.uv.sources] +arcade-ai = { path = "../../", editable = true } +arcade-serve = { path = "../../libs/arcade-serve/", editable = true } +arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } + + +[tool.mypy] +files = [ "arcade_google_sheets/**/*.py",] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes = "True" +ignore_missing_imports = "True" + +[tool.pytest.ini_options] +testpaths = [ "tests",] + +[tool.coverage.report] +skip_empty = true + +[tool.hatch.build.targets.wheel] +packages = [ "arcade_google_sheets",] diff --git a/toolkits/google_sheets/tests/__init__.py b/toolkits/google_sheets/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/toolkits/google_sheets/tests/test_sheets_models.py b/toolkits/google_sheets/tests/test_sheets_models.py new file mode 100644 index 00000000..428e8b66 --- /dev/null +++ b/toolkits/google_sheets/tests/test_sheets_models.py @@ -0,0 +1,84 @@ +from arcade_google_sheets.models import SheetDataInput + + +def test_sheet_input_data_init(): + data = '{"1":{"A":"name","B":"age","C":"email","D":"score","E":"gender","F":"city","G":"country","H":"registration_date"},"34":{"A":"Isla Green","B":24,"C":"islag@example.com","D":79,"E":"Female","F":"Chicago","G":"USA","H":"2024-01-10"},"38":{"A":"Mia Black","B":27,"C":"miab@example.com","D":80,"E":"Female","F":"Denver","G":"USA","H":"2024-01-30"},"39":{"A":"Nate Green","B":30,"C":"nateg@example.com","D":88,"E":"Male","F":"Orlando","G":"USA","H":"2024-02-01"},"43":{"A":100,"B":300,"C":234,"D":399,"E":5039,"F":2345,"G":23526,"H":123,"I":54,"J":234,"K":54,"L":23,"M":12,"N":57,"O":1324},"47":{"A":456,"B":234,"C":234,"D":399,"E":234,"F":1234,"G":23526,"H":123,"I":54,"J":234,"K":4567,"L":23,"M":12,"N":234,"O":1324}}' + expected_data = { + 1: { + "A": "name", + "B": "age", + "C": "email", + "D": "score", + "E": "gender", + "F": "city", + "G": "country", + "H": "registration_date", + }, + 34: { + "A": "Isla Green", + "B": 24, + "C": "islag@example.com", + "D": 79, + "E": "Female", + "F": "Chicago", + "G": "USA", + "H": "2024-01-10", + }, + 38: { + "A": "Mia Black", + "B": 27, + "C": "miab@example.com", + "D": 80, + "E": "Female", + "F": "Denver", + "G": "USA", + "H": "2024-01-30", + }, + 39: { + "A": "Nate Green", + "B": 30, + "C": "nateg@example.com", + "D": 88, + "E": "Male", + "F": "Orlando", + "G": "USA", + "H": "2024-02-01", + }, + 43: { + "A": 100, + "B": 300, + "C": 234, + "D": 399, + "E": 5039, + "F": 2345, + "G": 23526, + "H": 123, + "I": 54, + "J": 234, + "K": 54, + "L": 23, + "M": 12, + "N": 57, + "O": 1324, + }, + 47: { + "A": 456, + "B": 234, + "C": 234, + "D": 399, + "E": 234, + "F": 1234, + "G": 23526, + "H": 123, + "I": 54, + "J": 234, + "K": 4567, + "L": 23, + "M": 12, + "N": 234, + "O": 1324, + }, + } + + sheet_input_data = SheetDataInput(data=data) + assert sheet_input_data.data == expected_data diff --git a/toolkits/google_sheets/tests/test_sheets_utils.py b/toolkits/google_sheets/tests/test_sheets_utils.py new file mode 100644 index 00000000..d8ad179a --- /dev/null +++ b/toolkits/google_sheets/tests/test_sheets_utils.py @@ -0,0 +1,542 @@ +from unittest.mock import MagicMock, patch + +import pytest +from arcade_tdk.errors import RetryableToolError, ToolExecutionError + +from arcade_google_sheets.enums import NumberFormatType +from arcade_google_sheets.models import ( + CellData, + CellExtendedValue, + RowData, + SheetDataInput, +) +from arcade_google_sheets.utils import ( + col_to_index, + compute_sheet_data_dimensions, + convert_api_grid_data_to_dict, + create_cell_data, + create_row_data, + create_sheet_data, + create_sheet_properties, + extract_user_entered_cell_value, + group_contiguous_rows, + index_to_col, + is_col_greater, + process_row, + validate_write_to_cell_params, +) + + +@pytest.fixture +def sheet_data_input_fixture(): + data = { + 1: { + "A": "name", + "B": "age", + "C": "email", + "D": "score", + "E": "gender", + "F": "city", + "G": "country", + "H": "registration_date", + }, + 2: { + "A": "John Doe", + "B": 28, + "C": "johndoe@example.com", + "D": 85.4, + "E": "Male", + "F": "New York", + "G": "USA", + "H": "2023-01-15", + }, + 10: { + "A": "Nate Green", + "B": 30, + "C": "nateg@example.com", + "D": 88, + "E": "Male", + "F": "Orlando", + "G": "USA", + "H": "2024-02-01", + }, + 43: { + "A": 100, + "B": 300, + "H": 123, + "I": "=SUM(SEQUENCE(10))", + }, + 44: { + "A": 456, + "B": 234, + "H": 123, + "I": "=SUM(SEQUENCE(10))", + }, + } + return SheetDataInput(data=data) + + +@pytest.mark.parametrize( + "col, expected_index", + [ + ("A", 0), + ("B", 1), + ("Z", 25), + ("AA", 26 + 0), + ("AZ", (1 * 26) + 25), + ("BA", (2 * 26) + 0), + ("ZZ", (26 * 26) + 25), + ("AAA", (1 * 26 * 26) + (1 * 26) + 0), + ("AAB", (1 * 26 * 26) + (1 * 26) + 1), + ("QED", (17 * 26 * 26) + (5 * 26) + 3), + ], +) +def test_col_to_index(col, expected_index): + assert col_to_index(col) == expected_index + + +@pytest.mark.parametrize( + "index, expected_col", + [ + (0, "A"), + (1, "B"), + (25, "Z"), + (26 + 0, "AA"), + ((1 * 26) + 25, "AZ"), + ((2 * 26) + 0, "BA"), + ((26 * 26) + 25, "ZZ"), + ((1 * 26 * 26) + (1 * 26) + 0, "AAA"), + ((1 * 26 * 26) + (1 * 26) + 1, "AAB"), + ((17 * 26 * 26) + (5 * 26) + 3, "QED"), + ], +) +def test_index_to_col(index, expected_col): + assert index_to_col(index) == expected_col + + +@pytest.mark.parametrize( + "col1, col2, expected_result", + [ + ("A", "B", False), + ("B", "A", True), + ("AA", "AB", False), + ("AB", "AA", True), + ("A", "AA", False), + ("AA", "A", True), + ("Z", "AA", False), + ("AA", "Z", True), + ("AAA", "AAB", False), + ("AAB", "AAA", True), + ("QED", "QEE", False), + ("QEE", "QED", True), + ], +) +def test_is_col_greater(col1, col2, expected_result): + assert is_col_greater(col1, col2) == expected_result + + +def test_compute_sheet_data_dimensions(sheet_data_input_fixture): + (min_row, max_row), (min_col_index, max_col_index) = compute_sheet_data_dimensions( + sheet_data_input_fixture + ) + + expected_min_row = 1 + expected_max_row = 44 + expected_min_col_index = 0 # Column "A" + expected_max_col_index = 8 # Column "I" + + assert min_row == expected_min_row + assert max_row == expected_max_row + assert min_col_index == expected_min_col_index + assert max_col_index == expected_max_col_index + + +def test_create_sheet_properties(): + sheet_properties = create_sheet_properties( + sheet_id=1, + title="Sheet1", + row_count=10000, + column_count=260, + ) + + assert sheet_properties.sheetId == 1 + assert sheet_properties.title == "Sheet1" + assert sheet_properties.gridProperties.rowCount == 10000 + assert sheet_properties.gridProperties.columnCount == 260 + + +@pytest.mark.parametrize( + "row_numbers, expected_groups", + [ + ([], []), + ([5, 6, 7], [[5, 6, 7]]), + ( + [1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 18, 19, 20], + [[1, 2, 3], [5, 6, 7, 8, 9, 10, 11], [18, 19, 20]], + ), + ], +) +def test_group_contiguous_rows(row_numbers, expected_groups): + grouped_rows = group_contiguous_rows(row_numbers) + assert grouped_rows == expected_groups + + +@pytest.mark.parametrize( + "input_value, expected_key, expected_value, expected_type, expected_pattern", + [ + (1234, "numberValue", 1234, NumberFormatType.NUMBER, "#,##0"), + (1.234, "numberValue", 1.234, NumberFormatType.NUMBER, "#,##0.00"), + ("$100", "numberValue", 100, NumberFormatType.CURRENCY, "$#,##0"), + ("$100.50", "numberValue", 100.50, NumberFormatType.CURRENCY, "$#,##0.00"), + ("75%", "numberValue", 75.00, NumberFormatType.PERCENT, "0.00%"), + ("75.34%", "numberValue", 75.34, NumberFormatType.PERCENT, "0.00%"), + ("$1abc", "stringValue", "$1abc", None, None), + ("abc7%", "stringValue", "abc7%", None, None), + ("=SUM(A1:B1)", "formulaValue", "=SUM(A1:B1)", None, None), + (True, "boolValue", True, None, None), + ], +) +def test_create_cell_data( + input_value, expected_key, expected_value, expected_type, expected_pattern +): + cell_data = create_cell_data(input_value) + expected_cell_value = CellExtendedValue(**{expected_key: expected_value}) + assert cell_data.userEnteredValue == expected_cell_value + if expected_type is None: + assert cell_data.userEnteredFormat is None + else: + assert cell_data.userEnteredFormat is not None + assert cell_data.userEnteredFormat.numberFormat.type == expected_type + assert cell_data.userEnteredFormat.numberFormat.pattern == expected_pattern + + +def test_create_row_data(): + row_data = { + "A": 1, # Column index 0 + "B": 2.5, # Column index 1 + "AA": "test", # Column index 26 + "BA": True, # Column index 52 + "BB": "=SUM(A1:B1)", # Column index 53 + } + min_col_index = 0 # Column "A" + max_col_index = 53 # Column "BB" + + expected_row_data = RowData( + values=[ + CellData(userEnteredValue=CellExtendedValue(stringValue="")) + for _ in range(max_col_index + 1) + ] + ) + expected_row_data.values[0].userEnteredValue = CellExtendedValue(numberValue=1) + expected_row_data.values[1].userEnteredValue = CellExtendedValue(numberValue=2.5) + expected_row_data.values[26].userEnteredValue = CellExtendedValue(stringValue="test") + expected_row_data.values[52].userEnteredValue = CellExtendedValue(boolValue=True) + expected_row_data.values[53].userEnteredValue = CellExtendedValue(formulaValue="=SUM(A1:B1)") + + row_data = create_row_data(row_data, min_col_index, max_col_index) + + assert len(row_data.values) == len(expected_row_data.values) + for cell, expected in zip(row_data.values, expected_row_data.values, strict=False): + assert cell.userEnteredValue == expected.userEnteredValue + + +def test_create_sheet_data(): + from arcade_google_sheets.models import CellData, CellExtendedValue, SheetDataInput + from arcade_google_sheets.utils import create_cell_data + + test_data = { + 2: {"B": "row2B", "C": 200}, + 3: {"B": "row3B"}, + 5: {"A": "=SUM(A1:A1)", "C": "row5C"}, + } + sheet_data_input = SheetDataInput(data=test_data) + min_col_index = 0 # Column "A" + max_col_index = 2 # Column "C" + + grid_data_list = create_sheet_data(sheet_data_input, min_col_index, max_col_index) + + assert len(grid_data_list) == 2, "Should have two groups of contiguous rows" + + group1 = grid_data_list[0] + assert group1.startRow == 1 + assert group1.startColumn == min_col_index + assert len(group1.rowData) == 2 + + row2_cells = group1.rowData[0].values + expected_row2 = [ + CellData(userEnteredValue=CellExtendedValue(stringValue="")), + create_cell_data("row2B"), + create_cell_data(200), + ] + for cell, expected in zip(row2_cells, expected_row2, strict=False): + assert cell.userEnteredValue == expected.userEnteredValue + + row3_cells = group1.rowData[1].values + expected_row3 = [ + CellData(userEnteredValue=CellExtendedValue(stringValue="")), + create_cell_data("row3B"), + CellData(userEnteredValue=CellExtendedValue(stringValue="")), + ] + for cell, expected in zip(row3_cells, expected_row3, strict=False): + assert cell.userEnteredValue == expected.userEnteredValue + + group2 = grid_data_list[1] + assert group2.startRow == 4 + assert group2.startColumn == min_col_index + assert len(group2.rowData) == 1 + + row5_cells = group2.rowData[0].values + expected_row5 = [ + create_cell_data("=SUM(A1:A1)"), + CellData(userEnteredValue=CellExtendedValue(stringValue="")), + create_cell_data("row5C"), + ] + for cell, expected in zip(row5_cells, expected_row5, strict=False): + assert cell.userEnteredValue == expected.userEnteredValue + + +@pytest.mark.parametrize( + "cell, expected", + [ + ({}, ""), + ({"userEnteredValue": {}}, ""), + ({"userEnteredValue": {"stringValue": "hello"}}, "hello"), + ({"userEnteredValue": {"numberValue": 123}}, 123), + ({"userEnteredValue": {"boolValue": True}}, True), + ({"userEnteredValue": {"formulaValue": "=SUM(A1:A2)"}}, "=SUM(A1:A2)"), + ], +) +def test_extract_user_entered_cell_value(cell, expected): + result = extract_user_entered_cell_value(cell) + assert result == expected + + +def test_process_row_empty(): + row = {} + assert process_row(row, 0) == {} + + +def test_process_row_non_empty(): + row = { + "values": [ + {"userEnteredValue": {"stringValue": "cell1"}, "formattedValue": "cell1"}, + {"userEnteredValue": {}}, # should be ignored + {"userEnteredValue": {"formulaValue": "=C1+D4"}, "formattedValue": 42}, + {"userEnteredValue": {"stringValue": ""}, "formattedValue": ""}, # should be ignored + {"userEnteredValue": {"boolValue": False}, "formattedValue": False}, + ] + } + expected = { + "A": {"userEnteredValue": "cell1", "formattedValue": "cell1"}, + "C": {"userEnteredValue": "=C1+D4", "formattedValue": 42}, + "E": {"userEnteredValue": False, "formattedValue": False}, + } + + assert process_row(row, 0) == expected + + +def test_process_row_with_start_index(): + row = { + "values": [ + {"userEnteredValue": {"stringValue": "x"}, "formattedValue": "x"}, + {"userEnteredValue": {"formulaValue": "=C1+D4"}, "formattedValue": "$10.00"}, + ] + } + expected = { + "C": {"userEnteredValue": "x", "formattedValue": "x"}, + "D": {"userEnteredValue": "=C1+D4", "formattedValue": "$10.00"}, + } + + assert process_row(row, 2) == expected + + +def test_convert_api_grid_data_to_dict_single_grid(): + data = [ + { + "startRow": 0, + "startColumn": 0, + "rowData": [ + { + "values": [ + {"userEnteredValue": {"stringValue": "A1"}, "formattedValue": "A1"}, + {"userEnteredValue": {"numberValue": 1}, "formattedValue": 1}, + ] + }, + { + "values": [ + {"userEnteredValue": {"stringValue": "A2"}, "formattedValue": "A2"}, + {"userEnteredValue": {"numberValue": 2}, "formattedValue": 2}, + ] + }, + { + "values": [ + {"userEnteredValue": {}}, + { + "userEnteredValue": {"stringValue": "ignored"}, + "formattedValue": "ignored", + }, + {"userEnteredValue": {"numberValue": 3}, "formattedValue": 3}, + ] + }, + ], + } + ] + result = convert_api_grid_data_to_dict(data) + expected = { + 1: { + "A": {"userEnteredValue": "A1", "formattedValue": "A1"}, + "B": {"userEnteredValue": 1, "formattedValue": 1}, + }, + 2: { + "A": {"userEnteredValue": "A2", "formattedValue": "A2"}, + "B": {"userEnteredValue": 2, "formattedValue": 2}, + }, + 3: { + "B": {"userEnteredValue": "ignored", "formattedValue": "ignored"}, + "C": {"userEnteredValue": 3, "formattedValue": 3}, + }, + } + + assert result == expected + + +def test_convert_api_grid_data_to_dict_multiple_grids(): + data = [ + { + "startRow": 5, + "startColumn": 1, + "rowData": [ + { + "values": [ + {"userEnteredValue": {"numberValue": 100}, "formattedValue": 100}, + {"userEnteredValue": {"stringValue": "=SUM(A1:A2)"}, "formattedValue": 23}, + ] + } + ], + }, + { + "startRow": 0, + "startColumn": 0, + "rowData": [ + { + "values": [ + {"userEnteredValue": {"stringValue": "First"}, "formattedValue": "First"}, + {"userEnteredValue": {"numberValue": 10}, "formattedValue": 10}, + ] + } + ], + }, + ] + result = convert_api_grid_data_to_dict(data) + expected = { + 1: { + "A": {"userEnteredValue": "First", "formattedValue": "First"}, + "B": {"userEnteredValue": 10, "formattedValue": 10}, + }, + 6: { + "B": {"userEnteredValue": 100, "formattedValue": 100}, + "C": {"userEnteredValue": "=SUM(A1:A2)", "formattedValue": 23}, + }, + } + + assert result == expected + + +def test_convert_api_grid_data_to_dict_empty_rows(): + data = [ + { + "startRow": 10, + "startColumn": 0, + "rowData": [ + {"values": [{"userEnteredValue": {}, "formattedValue": ""}]}, + {"values": []}, + ], + } + ] + result = convert_api_grid_data_to_dict(data) + expected = {} + + assert result == expected + + +FAKE_SHEET_RESPONSE = { + "sheets": [ + {"properties": {"title": "Sheet1", "gridProperties": {"rowCount": 10, "columnCount": 6}}} + ] +} + + +@patch("arcade_google_sheets.utils.build_sheets_service") +def test_validate_write_to_cell_params_valid(mock_build): + mock_service = MagicMock() + mock_service.spreadsheets().get().execute.return_value = FAKE_SHEET_RESPONSE + mock_build.return_value = mock_service + + service = mock_build("dummy_token") + + validate_write_to_cell_params( + service=service, + spreadsheet_id="dummy_id", + sheet_name="Sheet1", + column="B", + row=5, + ) + + +@patch("arcade_google_sheets.utils.build_sheets_service") +def test_validate_write_to_cell_params_invalid_sheet_name(mock_build): + mock_service = MagicMock() + mock_service.spreadsheets().get().execute.return_value = FAKE_SHEET_RESPONSE + mock_build.return_value = mock_service + + service = mock_build("dummy_token") + + with pytest.raises(RetryableToolError) as excinfo: + validate_write_to_cell_params( + service=service, + spreadsheet_id="dummy_id", + sheet_name="NonExistentSheet", + column="A", + row=5, + ) + assert "Sheet name NonExistentSheet not found" in str(excinfo.value) + + +@patch("arcade_google_sheets.utils.build_sheets_service") +def test_validate_write_to_cell_params_row_out_of_bounds(mock_build): + mock_service = MagicMock() + mock_service.spreadsheets().get().execute.return_value = FAKE_SHEET_RESPONSE + mock_build.return_value = mock_service + + service = mock_build("dummy_token") + + out_of_bounds_row = 15 + with pytest.raises(ToolExecutionError) as excinfo: + validate_write_to_cell_params( + service=service, + spreadsheet_id="dummy_id", + sheet_name="Sheet1", + column="A", + row=out_of_bounds_row, + ) + assert f"Row {out_of_bounds_row} is out of bounds" in str(excinfo.value) + + +@patch("arcade_google_sheets.utils.build_sheets_service") +def test_validate_write_to_cell_params_column_out_of_bounds(mock_build): + mock_service = MagicMock() + mock_service.spreadsheets().get().execute.return_value = FAKE_SHEET_RESPONSE + mock_build.return_value = mock_service + + service = mock_build("dummy_token") + + out_of_bounds_column = "Z" + with pytest.raises(ToolExecutionError) as excinfo: + validate_write_to_cell_params( + service=service, + spreadsheet_id="dummy_id", + sheet_name="Sheet1", + column=out_of_bounds_column, + row=5, + ) + assert f"Column {out_of_bounds_column} is out of bounds" in str(excinfo.value) diff --git a/toolkits/google_shopping/.pre-commit-config.yaml b/toolkits/google_shopping/.pre-commit-config.yaml new file mode 100644 index 00000000..cbf1287c --- /dev/null +++ b/toolkits/google_shopping/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +files: ^.*/google_shopping/.* +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.4.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/toolkits/google_shopping/.ruff.toml b/toolkits/google_shopping/.ruff.toml new file mode 100644 index 00000000..f1aed90f --- /dev/null +++ b/toolkits/google_shopping/.ruff.toml @@ -0,0 +1,46 @@ +target-version = "py310" +line-length = 100 +fix = true + +[lint] +select = [ + # flake8-2020 + "YTT", + # flake8-bandit + "S", + # flake8-bugbear + "B", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-simplify + "SIM", + # isort + "I", + # mccabe + "C90", + # pycodestyle + "E", "W", + # pyflakes + "F", + # pygrep-hooks + "PGH", + # pyupgrade + "UP", + # ruff + "RUF", + # tryceratops + "TRY", +] + +[lint.per-file-ignores] +"*" = ["TRY003", "B904"] +"**/tests/*" = ["S101", "E501"] +"**/evals/*" = ["S101", "E501"] + +[format] +preview = true +skip-magic-trailing-comma = false diff --git a/toolkits/google_shopping/LICENSE b/toolkits/google_shopping/LICENSE new file mode 100644 index 00000000..45f53e20 --- /dev/null +++ b/toolkits/google_shopping/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Arcade + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/toolkits/google_shopping/Makefile b/toolkits/google_shopping/Makefile new file mode 100644 index 00000000..0a8969be --- /dev/null +++ b/toolkits/google_shopping/Makefile @@ -0,0 +1,55 @@ +.PHONY: help + +help: + @echo "🛠️ github Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the uv environment and install all packages with dependencies + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras --no-sources + @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: install-local +install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras + @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + uv build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + @uv run --no-sources coverage report + @echo "Generating coverage report" + @uv run --no-sources coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file by a patch version + @echo "🚀 Bumping version in pyproject.toml" + uv version --no-sources --bump patch + +.PHONY: check +check: ## Run code quality tools. + @if [ -f .pre-commit-config.yaml ]; then\ + echo "🚀 Linting code: Running pre-commit";\ + uv run --no-sources pre-commit run -a;\ + fi + @echo "🚀 Static type checking: Running mypy" + @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/google_shopping/arcade_google_shopping/__init__.py b/toolkits/google_shopping/arcade_google_shopping/__init__.py new file mode 100644 index 00000000..27cf60fb --- /dev/null +++ b/toolkits/google_shopping/arcade_google_shopping/__init__.py @@ -0,0 +1,3 @@ +from arcade_google_shopping.tools import search_products + +__all__ = ["search_products"] diff --git a/toolkits/google_shopping/arcade_google_shopping/constants.py b/toolkits/google_shopping/arcade_google_shopping/constants.py new file mode 100644 index 00000000..600b7307 --- /dev/null +++ b/toolkits/google_shopping/arcade_google_shopping/constants.py @@ -0,0 +1,10 @@ +import os + +DEFAULT_GOOGLE_LANGUAGE = os.getenv("ARCADE_GOOGLE_LANGUAGE", "en") +DEFAULT_GOOGLE_COUNTRY = os.getenv("ARCADE_GOOGLE_COUNTRY") +DEFAULT_GOOGLE_SHOPPING_LANGUAGE = os.getenv( + "ARCADE_GOOGLE_SHOPPING_LANGUAGE", DEFAULT_GOOGLE_LANGUAGE +) +DEFAULT_GOOGLE_SHOPPING_COUNTRY = os.getenv( + "ARCADE_GOOGLE_SHOPPING_COUNTRY", DEFAULT_GOOGLE_COUNTRY +) diff --git a/toolkits/google_shopping/arcade_google_shopping/exceptions.py b/toolkits/google_shopping/arcade_google_shopping/exceptions.py new file mode 100644 index 00000000..3c039160 --- /dev/null +++ b/toolkits/google_shopping/arcade_google_shopping/exceptions.py @@ -0,0 +1,25 @@ +import json + +from arcade_tdk.errors import RetryableToolError + +from arcade_google_shopping.google_data import COUNTRY_CODES, LANGUAGE_CODES + + +class GoogleRetryableError(RetryableToolError): + pass + + +class CountryNotFoundError(GoogleRetryableError): + def __init__(self, country: str | None) -> None: + valid_countries = json.dumps(COUNTRY_CODES, default=str) + message = f"Country not found: '{country}'." + additional_message = f"Valid countries are: {valid_countries}" + super().__init__(message, additional_prompt_content=additional_message) + + +class LanguageNotFoundError(GoogleRetryableError): + def __init__(self, language: str | None) -> None: + valid_languages = json.dumps(LANGUAGE_CODES, default=str) + message = f"Language not found: '{language}'." + additional_message = f"Valid languages are: {valid_languages}" + super().__init__(message, additional_prompt_content=additional_message) diff --git a/toolkits/google_shopping/arcade_google_shopping/google_data.py b/toolkits/google_shopping/arcade_google_shopping/google_data.py new file mode 100644 index 00000000..8fff7a1e --- /dev/null +++ b/toolkits/google_shopping/arcade_google_shopping/google_data.py @@ -0,0 +1,468 @@ +COUNTRY_CODES = { + "af": "Afghanistan", + "al": "Albania", + "dz": "Algeria", + "as": "American Samoa", + "ad": "Andorra", + "ao": "Angola", + "ai": "Anguilla", + "aq": "Antarctica", + "ag": "Antigua and Barbuda", + "ar": "Argentina", + "am": "Armenia", + "aw": "Aruba", + "au": "Australia", + "at": "Austria", + "az": "Azerbaijan", + "bs": "Bahamas", + "bh": "Bahrain", + "bd": "Bangladesh", + "bb": "Barbados", + "by": "Belarus", + "be": "Belgium", + "bz": "Belize", + "bj": "Benin", + "bm": "Bermuda", + "bt": "Bhutan", + "bo": "Bolivia", + "ba": "Bosnia and Herzegovina", + "bw": "Botswana", + "bv": "Bouvet Island", + "br": "Brazil", + "io": "British Indian Ocean Territory", + "bn": "Brunei Darussalam", + "bg": "Bulgaria", + "bf": "Burkina Faso", + "bi": "Burundi", + "kh": "Cambodia", + "cm": "Cameroon", + "ca": "Canada", + "cv": "Cape Verde", + "ky": "Cayman Islands", + "cf": "Central African Republic", + "td": "Chad", + "cl": "Chile", + "cn": "China", + "cx": "Christmas Island", + "cc": "Cocos (Keeling) Islands", + "co": "Colombia", + "km": "Comoros", + "cg": "Congo", + "cd": "Congo, the Democratic Republic of the", + "ck": "Cook Islands", + "cr": "Costa Rica", + "ci": "Cote D'ivoire", + "hr": "Croatia", + "cu": "Cuba", + "cy": "Cyprus", + "cz": "Czech Republic", + "dk": "Denmark", + "dj": "Djibouti", + "dm": "Dominica", + "do": "Dominican Republic", + "ec": "Ecuador", + "eg": "Egypt", + "sv": "El Salvador", + "gq": "Equatorial Guinea", + "er": "Eritrea", + "ee": "Estonia", + "et": "Ethiopia", + "fk": "Falkland Islands (Malvinas)", + "fo": "Faroe Islands", + "fj": "Fiji", + "fi": "Finland", + "fr": "France", + "gf": "French Guiana", + "pf": "French Polynesia", + "tf": "French Southern Territories", + "ga": "Gabon", + "gm": "Gambia", + "ge": "Georgia", + "de": "Germany", + "gh": "Ghana", + "gi": "Gibraltar", + "gr": "Greece", + "gl": "Greenland", + "gd": "Grenada", + "gp": "Guadeloupe", + "gu": "Guam", + "gt": "Guatemala", + "gg": "Guernsey", + "gn": "Guinea", + "gw": "Guinea-Bissau", + "gy": "Guyana", + "ht": "Haiti", + "hm": "Heard Island and Mcdonald Islands", + "va": "Holy See (Vatican City State)", + "hn": "Honduras", + "hk": "Hong Kong", + "hu": "Hungary", + "is": "Iceland", + "in": "India", + "id": "Indonesia", + "ir": "Iran, Islamic Republic of", + "iq": "Iraq", + "ie": "Ireland", + "im": "Isle of Man", + "il": "Israel", + "it": "Italy", + "je": "Jersey", + "jm": "Jamaica", + "jp": "Japan", + "jo": "Jordan", + "kz": "Kazakhstan", + "ke": "Kenya", + "ki": "Kiribati", + "kp": "Korea, Democratic People's Republic of", + "kr": "Korea, Republic of", + "kw": "Kuwait", + "kg": "Kyrgyzstan", + "la": "Lao People's Democratic Republic", + "lv": "Latvia", + "lb": "Lebanon", + "ls": "Lesotho", + "lr": "Liberia", + "ly": "Libyan Arab Jamahiriya", + "li": "Liechtenstein", + "lt": "Lithuania", + "lu": "Luxembourg", + "mo": "Macao", + "mk": "Macedonia, the Former Yugosalv Republic of", + "mg": "Madagascar", + "mw": "Malawi", + "my": "Malaysia", + "mv": "Maldives", + "ml": "Mali", + "mt": "Malta", + "mh": "Marshall Islands", + "mq": "Martinique", + "mr": "Mauritania", + "mu": "Mauritius", + "yt": "Mayotte", + "mx": "Mexico", + "fm": "Micronesia, Federated States of", + "md": "Moldova, Republic of", + "mc": "Monaco", + "mn": "Mongolia", + "me": "Montenegro", + "ms": "Montserrat", + "ma": "Morocco", + "mz": "Mozambique", + "mm": "Myanmar", + "na": "Namibia", + "nr": "Nauru", + "np": "Nepal", + "nl": "Netherlands", + "an": "Netherlands Antilles", + "nc": "New Caledonia", + "nz": "New Zealand", + "ni": "Nicaragua", + "ne": "Niger", + "ng": "Nigeria", + "nu": "Niue", + "nf": "Norfolk Island", + "mp": "Northern Mariana Islands", + "no": "Norway", + "om": "Oman", + "pk": "Pakistan", + "pw": "Palau", + "ps": "Palestinian Territory, Occupied", + "pa": "Panama", + "pg": "Papua New Guinea", + "py": "Paraguay", + "pe": "Peru", + "ph": "Philippines", + "pn": "Pitcairn", + "pl": "Poland", + "pt": "Portugal", + "pr": "Puerto Rico", + "qa": "Qatar", + "re": "Reunion", + "ro": "Romania", + "ru": "Russian Federation", + "rw": "Rwanda", + "sh": "Saint Helena", + "kn": "Saint Kitts and Nevis", + "lc": "Saint Lucia", + "pm": "Saint Pierre and Miquelon", + "vc": "Saint Vincent and the Grenadines", + "ws": "Samoa", + "sm": "San Marino", + "st": "Sao Tome and Principe", + "sa": "Saudi Arabia", + "sn": "Senegal", + "rs": "Serbia", + "sc": "Seychelles", + "sl": "Sierra Leone", + "sg": "Singapore", + "sk": "Slovakia", + "si": "Slovenia", + "sb": "Solomon Islands", + "so": "Somalia", + "za": "South Africa", + "gs": "South Georgia and the South Sandwich Islands", + "es": "Spain", + "lk": "Sri Lanka", + "sd": "Sudan", + "sr": "Suriname", + "sj": "Svalbard and Jan Mayen", + "sz": "Swaziland", + "se": "Sweden", + "ch": "Switzerland", + "sy": "Syrian Arab Republic", + "tw": "Taiwan, Province of China", + "tj": "Tajikistan", + "tz": "Tanzania, United Republic of", + "th": "Thailand", + "tl": "Timor-Leste", + "tg": "Togo", + "tk": "Tokelau", + "to": "Tonga", + "tt": "Trinidad and Tobago", + "tn": "Tunisia", + "tr": "Turkiye", + "tm": "Turkmenistan", + "tc": "Turks and Caicos Islands", + "tv": "Tuvalu", + "ug": "Uganda", + "ua": "Ukraine", + "ae": "United Arab Emirates", + "uk": "United Kingdom", + "gb": "United Kingdom", + "us": "United States", + "um": "United States Minor Outlying Islands", + "uy": "Uruguay", + "uz": "Uzbekistan", + "vu": "Vanuatu", + "ve": "Venezuela", + "vn": "Viet Nam", + "vg": "Virgin Islands, British", + "vi": "Virgin Islands, U.S.", + "wf": "Wallis and Futuna", + "eh": "Western Sahara", + "ye": "Yemen", + "zm": "Zambia", + "zw": "Zimbabwe", +} + + +LANGUAGE_CODES = { + "ar": "Arabic", + "bn": "Bengali", + "da": "Danish", + "de": "German", + "el": "Greek", + "en": "English", + "es": "Spanish", + "fi": "Finnish", + "fr": "French", + "hi": "Hindi", + "hu": "Hungarian", + "id": "Indonesian", + "it": "Italian", + "ja": "Japanese", + "ko": "Korean", + "nl": "Dutch", + "ms": "Malay", + "no": "Norwegian", + "pcm": "Nigerian Pidgin", + "pl": "Polish", + "pt": "Portuguese", + "pt-br": "Portuguese (Brazil)", + "pt-pt": "Portuguese (Portugal)", + "ru": "Russian", + "sv": "Swedish", + "tl": "Filipino", + "tr": "Turkish", + "uk": "Ukrainian", + "zh": "Chinese", + "zh-cn": "Chinese (Simplified)", + "zh-tw": "Chinese (Traditional)", +} + +GOOGLE_DOMAIN_BY_COUNTRY_CODE = { + "ad": "google.ad", + "ae": "google.ae", + "al": "google.al", + "am": "google.am", + "as": "google.as", + "at": "google.at", + "az": "google.az", + "ba": "google.ba", + "be": "google.be", + "bf": "google.bf", + "bg": "google.bg", + "bi": "google.bi", + "bj": "google.bj", + "bs": "google.bs", + "bt": "google.bt", + "by": "google.by", + "ca": "google.ca", + "cg": "google.cg", + "cf": "google.cf", + "ch": "google.ch", + "ci": "google.ci", + "cl": "google.cl", + "cm": "google.cm", + "ao": "google.co.ao", + "bw": "google.co.bw", + "ck": "google.co.ck", + "cr": "google.co.cr", + "id": "google.co.id", + "il": "google.co.il", + "in": "google.co.in", + "jp": "google.co.jp", + "ke": "google.co.ke", + "kr": "google.co.kr", + "ls": "google.co.ls", + "ma": "google.co.ma", + "mz": "google.co.mz", + "nz": "google.co.nz", + "th": "google.co.th", + "tz": "google.co.tz", + "ug": "google.co.ug", + "uk": "google.co.uk", + "uz": "google.co.uz", + "ve": "google.co.ve", + "vi": "google.co.vi", + "za": "google.co.za", + "zm": "google.co.zm", + "zw": "google.co.zw", + "us": "google.com", + "af": "google.com.af", + "ag": "google.com.ag", + "ai": "google.com.ai", + "ar": "google.com.ar", + "au": "google.com.au", + "bd": "google.com.bd", + "bh": "google.com.bh", + "bn": "google.com.bn", + "bo": "google.com.bo", + "br": "google.com.br", + "bz": "google.com.bz", + "co": "google.com.co", + "cu": "google.com.cu", + "cy": "google.com.cy", + "do": "google.com.do", + "ec": "google.com.ec", + "eg": "google.com.eg", + "et": "google.com.et", + "fj": "google.com.fj", + "gh": "google.com.gh", + "gi": "google.com.gi", + "gt": "google.com.gt", + "hk": "google.com.hk", + "jm": "google.com.jm", + "kh": "google.com.kh", + "kw": "google.com.kw", + "lb": "google.com.lb", + "ly": "google.com.ly", + "mm": "google.com.mm", + "mt": "google.com.mt", + "mx": "google.com.mx", + "my": "google.com.my", + "na": "google.com.na", + "ng": "google.com.ng", + "ni": "google.com.ni", + "np": "google.com.np", + "om": "google.com.om", + "pa": "google.com.pa", + "pe": "google.com.pe", + "pg": "google.com.pg", + "ph": "google.com.ph", + "pk": "google.com.pk", + "pr": "google.com.pr", + "py": "google.com.py", + "qa": "google.com.qa", + "sa": "google.com.sa", + "sb": "google.com.sb", + "sg": "google.com.sg", + "sl": "google.com.sl", + "sv": "google.com.sv", + "tj": "google.com.tj", + "tr": "google.com.tr", + "tw": "google.com.tw", + "ua": "google.com.ua", + "uy": "google.com.uy", + "vc": "google.com.vc", + "vn": "google.com.vn", + "cv": "google.cv", + "cz": "google.cz", + "de": "google.de", + "dj": "google.dj", + "dk": "google.dk", + "dm": "google.dm", + "dz": "google.dz", + "ee": "google.ee", + "es": "google.es", + "fi": "google.fi", + "fm": "google.fm", + "fr": "google.fr", + "ga": "google.ga", + "ge": "google.ge", + "gl": "google.gl", + "gm": "google.gm", + "gp": "google.gp", + "gr": "google.gr", + "gy": "google.gy", + "hn": "google.hn", + "hr": "google.hr", + "ht": "google.ht", + "hu": "google.hu", + "ie": "google.ie", + "iq": "google.iq", + "is": "google.is", + "it": "google.it", + "je": "google.je", + "jo": "google.jo", + "kg": "google.kg", + "ki": "google.ki", + "kz": "google.kz", + "la": "google.la", + "li": "google.li", + "lk": "google.lk", + "lt": "google.lt", + "lu": "google.lu", + "lv": "google.lv", + "md": "google.md", + "mg": "google.mg", + "mk": "google.mk", + "ml": "google.ml", + "mn": "google.mn", + "ms": "google.ms", + "mu": "google.mu", + "mv": "google.mv", + "mw": "google.mw", + "ne": "google.ne", + "nl": "google.nl", + "no": "google.no", + "nr": "google.nr", + "nu": "google.nu", + "pl": "google.pl", + "ps": "google.ps", + "pt": "google.pt", + "ro": "google.ro", + "rs": "google.rs", + "ru": "google.ru", + "rw": "google.rw", + "sc": "google.sc", + "se": "google.se", + "sh": "google.sh", + "si": "google.si", + "sk": "google.sk", + "sm": "google.sm", + "sn": "google.sn", + "so": "google.so", + "sr": "google.sr", + "td": "google.td", + "tg": "google.tg", + "tk": "google.tk", + "tl": "google.tl", + "tm": "google.tm", + "tn": "google.tn", + "to": "google.to", + "tt": "google.tt", + "vg": "google.vg", + "vu": "google.vu", + "ws": "google.ws", +} diff --git a/toolkits/google_shopping/arcade_google_shopping/tools/__init__.py b/toolkits/google_shopping/arcade_google_shopping/tools/__init__.py new file mode 100644 index 00000000..a6d408ed --- /dev/null +++ b/toolkits/google_shopping/arcade_google_shopping/tools/__init__.py @@ -0,0 +1,3 @@ +from arcade_google_shopping.tools.google_shopping import search_products + +__all__ = ["search_products"] diff --git a/toolkits/google_shopping/arcade_google_shopping/tools/google_shopping.py b/toolkits/google_shopping/arcade_google_shopping/tools/google_shopping.py new file mode 100644 index 00000000..0a7e18c0 --- /dev/null +++ b/toolkits/google_shopping/arcade_google_shopping/tools/google_shopping.py @@ -0,0 +1,66 @@ +from typing import Annotated, Any + +from arcade_tdk import ToolContext, tool +from arcade_tdk.errors import ToolExecutionError + +from arcade_google_shopping.constants import ( + DEFAULT_GOOGLE_SHOPPING_COUNTRY, + DEFAULT_GOOGLE_SHOPPING_LANGUAGE, +) +from arcade_google_shopping.google_data import GOOGLE_DOMAIN_BY_COUNTRY_CODE +from arcade_google_shopping.utils import ( + call_serpapi, + extract_shopping_results, + prepare_params, + resolve_country_code, + resolve_language_code, +) + + +@tool(requires_secrets=["SERP_API_KEY"]) +async def search_products( + context: ToolContext, + keywords: Annotated[ + str, + "Keywords to search for products in Google Shopping. E.g. 'Apple iPhone'.", + ], + country_code: Annotated[ + str | None, + "2-character country code to search for products in Google Shopping. " + f"E.g. 'us' (United States). Defaults to '{DEFAULT_GOOGLE_SHOPPING_COUNTRY or 'us'}'.", + ] = DEFAULT_GOOGLE_SHOPPING_COUNTRY, + language_code: Annotated[ + str | None, + "2-character language code to search for products on Google Shopping. E.g. 'en' (English). " + f"Defaults to '{DEFAULT_GOOGLE_SHOPPING_LANGUAGE or 'en'}'.", + ] = DEFAULT_GOOGLE_SHOPPING_LANGUAGE, +) -> Annotated[dict[str, list[dict[str, Any]]], "Products on Google Shopping."]: + """Search for products on Google Shopping related to a given query.""" + country_code = resolve_country_code(country_code, DEFAULT_GOOGLE_SHOPPING_COUNTRY) + language_code = resolve_language_code(language_code, DEFAULT_GOOGLE_SHOPPING_LANGUAGE) + + if not isinstance(country_code, str): + country_code = "us" + + if not isinstance(language_code, str): + language_code = "en" + + google_domain = GOOGLE_DOMAIN_BY_COUNTRY_CODE.get(country_code, "google.com") + + params = prepare_params( + "google_shopping", + q=keywords, + gl=country_code, + hl=language_code, + google_domain=google_domain, + ) + + response = call_serpapi(context, params) + + if response.get("error"): + error_msg = response.get("error") or "Unknown Google Shopping Error" + raise ToolExecutionError(error_msg) + + return { + "products": extract_shopping_results(response.get("shopping_results", [])), + } diff --git a/toolkits/google_shopping/arcade_google_shopping/utils.py b/toolkits/google_shopping/arcade_google_shopping/utils.py new file mode 100644 index 00000000..9ef8433d --- /dev/null +++ b/toolkits/google_shopping/arcade_google_shopping/utils.py @@ -0,0 +1,117 @@ +import re +from typing import Any, cast + +from arcade_tdk import ToolContext +from arcade_tdk.errors import ToolExecutionError +from serpapi import Client as SerpClient + +from arcade_google_shopping.constants import ( + DEFAULT_GOOGLE_COUNTRY, + DEFAULT_GOOGLE_LANGUAGE, +) +from arcade_google_shopping.exceptions import CountryNotFoundError, LanguageNotFoundError +from arcade_google_shopping.google_data import COUNTRY_CODES, LANGUAGE_CODES + + +def prepare_params(engine: str, **kwargs: Any) -> dict[str, Any]: + """ + Prepares a parameters dictionary for the SerpAPI call. + + Parameters: + engine: The engine name (e.g., "google", "google_finance"). + kwargs: Any additional parameters to include. + + Returns: + A dictionary containing the base parameters plus any extras, + excluding any parameters whose value is None. + """ + params = {"engine": engine} + params.update({k: v for k, v in kwargs.items() if v is not None}) + return params + + +def call_serpapi(context: ToolContext, params: dict) -> dict: + """ + Execute a search query using the SerpAPI client and return the results as a dictionary. + + Args: + context: The tool context containing required secrets. + params: A dictionary of parameters for the SerpAPI search. + + Returns: + The search results as a dictionary. + """ + api_key = context.get_secret("SERP_API_KEY") + client = SerpClient(api_key=api_key) + try: + search = client.search(params) + return cast(dict[str, Any], search.as_dict()) + except Exception as e: + # SerpAPI error messages sometimes contain the API key, so we need to sanitize it + sanitized_e = re.sub(r"(api_key=)[^ &]+", r"\1***", str(e)) + raise ToolExecutionError( + message="Failed to fetch search results", + developer_message=sanitized_e, + ) + + +def default_language_code(default_service_language_code: str | None = None) -> str | None: + if isinstance(default_service_language_code, str): + return default_service_language_code.lower() + elif isinstance(DEFAULT_GOOGLE_LANGUAGE, str): + return DEFAULT_GOOGLE_LANGUAGE.lower() + return None + + +def default_country_code(default_service_country_code: str | None = None) -> str | None: + if isinstance(default_service_country_code, str): + return default_service_country_code.lower() + elif isinstance(DEFAULT_GOOGLE_COUNTRY, str): + return DEFAULT_GOOGLE_COUNTRY.lower() + return None + + +def resolve_language_code( + language_code: str | None = None, + default_service_language_code: str | None = None, +) -> str | None: + language_code = language_code or default_language_code(default_service_language_code) + + if isinstance(language_code, str): + language_code = language_code.lower() + if language_code not in LANGUAGE_CODES: + raise LanguageNotFoundError(language_code) + + return language_code + + +def resolve_country_code( + country_code: str | None = None, + default_service_country_code: str | None = None, +) -> str | None: + country_code = country_code or default_country_code(default_service_country_code) + + if isinstance(country_code, str): + country_code = country_code.lower() + if country_code not in COUNTRY_CODES: + raise CountryNotFoundError(country_code) + + return country_code + + +def extract_shopping_results(results: list[dict[str, Any]]) -> list[dict[str, Any]]: + return [ + { + "title": result.get("title"), + "direct_link": result.get("link"), + "google_link": result.get("product_link"), + "source": result.get("source"), + "price": result.get("price"), + "product_rating": result.get("rating"), + "product_reviews": result.get("reviews"), + "store_rating": result.get("store_rating"), + "store_reviews": result.get("store_reviews"), + "delivery": result.get("delivery"), + } + for result in results + ] diff --git a/toolkits/google_shopping/pyproject.toml b/toolkits/google_shopping/pyproject.toml new file mode 100644 index 00000000..787c6ea5 --- /dev/null +++ b/toolkits/google_shopping/pyproject.toml @@ -0,0 +1,59 @@ +[build-system] +requires = [ "hatchling",] +build-backend = "hatchling.build" + +[project] +name = "arcade_google_shopping" +version = "2.0.0" +description = "Arcade.dev LLM tools for shopping via Google Shopping" +requires-python = ">=3.10" +dependencies = [ + "arcade-tdk>=2.0.0,<3.0.0", + "serpapi>=0.1.5,<1.0.0", +] +[[project.authors]] +name = "Arcade" +email = "dev@arcade.dev" + + +[project.optional-dependencies] +dev = [ + "arcade-ai[evals]>=2.0.0,<3.0.0", + "arcade-serve>=2.0.0,<3.0.0", + "pytest>=8.3.0,<8.4.0", + "pytest-cov>=4.0.0,<4.1.0", + "pytest-mock>=3.11.1,<3.12.0", + "pytest-asyncio>=0.24.0,<0.25.0", + "mypy>=1.5.1,<1.6.0", + "pre-commit>=3.4.0,<3.5.0", + "tox>=4.11.1,<4.12.0", + "ruff>=0.7.4,<0.8.0", +] + +# Use local path sources for arcade libs when working locally +[tool.uv.sources] +arcade-ai = { path = "../../", editable = true } +arcade-serve = { path = "../../libs/arcade-serve/", editable = true } +arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } + + +[tool.mypy] +files = [ "arcade_google_shopping/**/*.py",] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes = "True" +ignore_missing_imports = "True" + +[tool.pytest.ini_options] +testpaths = [ "tests",] + +[tool.coverage.report] +skip_empty = true + +[tool.hatch.build.targets.wheel] +packages = [ "arcade_google_shopping",] diff --git a/toolkits/outlook_calendar/.pre-commit-config.yaml b/toolkits/outlook_calendar/.pre-commit-config.yaml new file mode 100644 index 00000000..7bbdadc3 --- /dev/null +++ b/toolkits/outlook_calendar/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +files: ^.*/outlook_calendar/.* +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.4.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/toolkits/outlook_calendar/.ruff.toml b/toolkits/outlook_calendar/.ruff.toml new file mode 100644 index 00000000..19364180 --- /dev/null +++ b/toolkits/outlook_calendar/.ruff.toml @@ -0,0 +1,47 @@ +target-version = "py310" +line-length = 100 +fix = true + +[lint] +select = [ + # flake8-2020 + "YTT", + # flake8-bandit + "S", + # flake8-bugbear + "B", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-simplify + "SIM", + # isort + "I", + # mccabe + "C90", + # pycodestyle + "E", "W", + # pyflakes + "F", + # pygrep-hooks + "PGH", + # pyupgrade + "UP", + # ruff + "RUF", + # tryceratops + "TRY", +] + +[lint.per-file-ignores] +"*" = ["TRY003", "B904"] +"**/tests/*" = ["S101", "E501"] +"**/evals/*" = ["S101", "E501"] + + +[format] +preview = true +skip-magic-trailing-comma = false diff --git a/toolkits/outlook_calendar/LICENSE b/toolkits/outlook_calendar/LICENSE new file mode 100644 index 00000000..dfbb8b76 --- /dev/null +++ b/toolkits/outlook_calendar/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025, Arcade AI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/toolkits/outlook_calendar/Makefile b/toolkits/outlook_calendar/Makefile new file mode 100644 index 00000000..0a8969be --- /dev/null +++ b/toolkits/outlook_calendar/Makefile @@ -0,0 +1,55 @@ +.PHONY: help + +help: + @echo "🛠️ github Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the uv environment and install all packages with dependencies + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras --no-sources + @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: install-local +install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras + @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + uv build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + @uv run --no-sources coverage report + @echo "Generating coverage report" + @uv run --no-sources coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file by a patch version + @echo "🚀 Bumping version in pyproject.toml" + uv version --no-sources --bump patch + +.PHONY: check +check: ## Run code quality tools. + @if [ -f .pre-commit-config.yaml ]; then\ + echo "🚀 Linting code: Running pre-commit";\ + uv run --no-sources pre-commit run -a;\ + fi + @echo "🚀 Static type checking: Running mypy" + @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/outlook_calendar/arcade_outlook_calendar/__init__.py b/toolkits/outlook_calendar/arcade_outlook_calendar/__init__.py new file mode 100644 index 00000000..430f39c2 --- /dev/null +++ b/toolkits/outlook_calendar/arcade_outlook_calendar/__init__.py @@ -0,0 +1,7 @@ +from arcade_outlook_calendar.tools import ( + create_event, + get_event, + list_events_in_time_range, +) + +__all__ = ["create_event", "get_event", "list_events_in_time_range"] diff --git a/toolkits/outlook_calendar/arcade_outlook_calendar/_utils.py b/toolkits/outlook_calendar/arcade_outlook_calendar/_utils.py new file mode 100644 index 00000000..c6ce1577 --- /dev/null +++ b/toolkits/outlook_calendar/arcade_outlook_calendar/_utils.py @@ -0,0 +1,225 @@ +import re +from datetime import datetime +from typing import Any + +import pytz +from arcade_tdk.errors import ToolExecutionError +from kiota_abstractions.base_request_configuration import RequestConfiguration +from kiota_abstractions.headers_collection import HeadersCollection +from msgraph import GraphServiceClient +from msgraph.generated.users.item.mailbox_settings.mailbox_settings_request_builder import ( + MailboxSettingsRequestBuilder, +) + +from arcade_outlook_calendar.constants import WINDOWS_TO_IANA + + +def validate_date_times(start_date_time: str, end_date_time: str) -> None: + """ + Validate date times are in ISO 8601 format and + that end time is after start time (ignoring timezone offsets). + + Args: + start_date_time: The start date time string to validate. + end_date_time: The end date time string to validate. + + Raises: + ValueError: If the date times are not in ISO 8601 format + ToolExecutionError: If end time is not after start time. + + Note: + This function ignores timezone offsets. + """ + # parse into offset-aware datetimes + start_aware = datetime.fromisoformat(start_date_time) + end_aware = datetime.fromisoformat(end_date_time) + + # drop tzinfo to treat both as naïve local times + start_naive = start_aware.replace(tzinfo=None) + end_naive = end_aware.replace(tzinfo=None) + + if start_naive >= end_naive: + raise ToolExecutionError( + message="Start time must be before end time", + developer_message=( + f"The start time '{start_naive}' is not before the end time '{end_naive}'" + ), + ) + + +def prepare_meeting_body( + body: str, custom_meeting_url: str | None, is_online_meeting: bool +) -> tuple[str, bool]: + """Prepare meeting body and determine final online meeting status. + + Args: + body: The original meeting body text + custom_meeting_url: Custom URL for the meeting, if one exists + is_online_meeting: Whether this should be an online meeting + + Returns: + tuple: (Updated meeting body, final online meeting status) + + Note: + If a custom meeting URL is provided, is_online_meeting will be set to False + to prevent Microsoft from generating its own meeting URL. The custom meeting + URL will then be added to the body of the meeting. + """ + is_online_meeting = not custom_meeting_url and is_online_meeting + + if custom_meeting_url: + body = f"""{body}\n +......................................................................... +Join online meeting +{custom_meeting_url}""" + + return body, is_online_meeting + + +def validate_emails(emails: list[str]) -> None: + """Validate a list of email addresses. + + Args: + emails: The list of email addresses to validate. + + Raises: + ToolExecutionError: If any email address is invalid. + """ + invalid_emails = [] + for email in emails: + if not is_valid_email(email): + invalid_emails.append(email) + if invalid_emails: + raise ToolExecutionError(message=f"Invalid email address(es): {', '.join(invalid_emails)}") + + +def is_valid_email(email: str) -> bool: + """Simple check to see if an email address is valid. + + Args: + email: The email address to check. + + Returns: + True if the email address is valid, False otherwise. + """ + pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$" + return re.match(pattern, email) is not None + + +def remove_timezone_offset(date_time: str) -> str: + """Remove the timezone offset from the date_time string.""" + return re.sub(r"[+-][0-9]{2}:[0-9]{2}$|Z$", "", date_time) + + +def replace_timezone_offset(date_time: str, time_zone_offset: str) -> str: + """Replace the timezone offset in the date_time string with the time_zone_offset. + + If the date_time str already contains a timezone offset, it will be replaced. + If the date_time str does not contain a timezone offset, the time_zone_offset will be appended + + Args: + date_time: The date_time string to replace the timezone offset in. + time_zone_offset: The timezone offset to replace the existing timezone offset with. + + Returns: + The date_time string with the timezone offset replaced or appended. + """ + date_time = remove_timezone_offset(date_time) + return f"{date_time}{time_zone_offset}" + + +def convert_timezone_to_offset(time_zone: str) -> str: + """ + Convert a timezone (Windows or IANA) to ISO 8601 offset. + First tries Windows timezone format, then IANA, then falls back to UTC if both fail. + + Args: + time_zone: The timezone (Windows or IANA) to convert to ISO 8601 offset. + + Returns: + The timezone offset in ISO 8601 format (e.g. '+08:00', '-07:00', or 'Z' for UTC) + """ + # Try Windows timezone format + iana_timezone = WINDOWS_TO_IANA.get(time_zone) + if iana_timezone: + try: + tz = pytz.timezone(iana_timezone) + now = datetime.now(tz) + tz_offset = now.strftime("%z") + + if len(tz_offset) == 5: # +HHMM format + tz_offset = f"{tz_offset[:3]}:{tz_offset[3:]}" # +HH:MM format + return tz_offset # noqa: TRY300 + except (pytz.exceptions.UnknownTimeZoneError, ValueError): + pass + + # Try IANA timezone format + try: + tz = pytz.timezone(time_zone) + now = datetime.now(tz) + tz_offset = now.strftime("%z") + + if len(tz_offset) == 5: # +HHMM format + tz_offset = f"{tz_offset[:3]}:{tz_offset[3:]}" # +HH:MM format + return tz_offset # noqa: TRY300 + except (pytz.exceptions.UnknownTimeZoneError, ValueError): + # Fallback to UTC + return "Z" + + +async def get_default_calendar_timezone(client: GraphServiceClient) -> str: + """Get the authenticated user's default calendar's timezone. + + Args: + client: The GraphServiceClient to use to get + the authenticated user's default calendar's timezone. + + Returns: + The timezone in "Windows timezone format" or "IANA timezone format". + """ + query_params = MailboxSettingsRequestBuilder.MailboxSettingsRequestBuilderGetQueryParameters( + select=["timeZone"] + ) + request_config = RequestConfiguration( + query_parameters=query_params, + ) + response = await client.me.mailbox_settings.get(request_config) + + if response and response.time_zone: + return response.time_zone + return "UTC" + + +def create_timezone_headers(time_zone: str) -> HeadersCollection: + """ + Create headers with timezone preference. + + Args: + time_zone: The timezone to set in the headers. + + Returns: + Headers collection with timezone preference set. + """ + headers = HeadersCollection() + headers.try_add("Prefer", f'outlook.timezone="{time_zone}"') + return headers + + +def create_timezone_request_config( + time_zone: str, query_parameters: Any | None = None +) -> RequestConfiguration: + """ + Create a request configuration with timezone headers and optional query parameters. + + Args: + time_zone: The timezone to set in the headers. + query_parameters: Optional query parameters to include in the configuration. + + Returns: + Request configuration with timezone headers and optional query parameters. + """ + headers = create_timezone_headers(time_zone) + return RequestConfiguration( + headers=headers, + query_parameters=query_parameters, + ) diff --git a/toolkits/outlook_calendar/arcade_outlook_calendar/client.py b/toolkits/outlook_calendar/arcade_outlook_calendar/client.py new file mode 100644 index 00000000..e11d257a --- /dev/null +++ b/toolkits/outlook_calendar/arcade_outlook_calendar/client.py @@ -0,0 +1,26 @@ +import datetime +from typing import Any + +from azure.core.credentials import AccessToken, TokenCredential +from msgraph import GraphServiceClient + +DEFAULT_SCOPE = "https://graph.microsoft.com/.default" + + +class StaticTokenCredential(TokenCredential): + """Implementation of TokenCredential protocol to be provided to the MSGraph SDK client""" + + def __init__(self, token: str): + self._token = token + + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + # An expiration is required by MSGraph SDK. Set to 1 hour from now. + expires_on = int(datetime.datetime.now(datetime.timezone.utc).timestamp()) + 3600 + return AccessToken(self._token, expires_on) + + +def get_client(token: str) -> GraphServiceClient: + """Create and return a MSGraph SDK client, given the provided token.""" + token_credential = StaticTokenCredential(token) + + return GraphServiceClient(token_credential, scopes=[DEFAULT_SCOPE]) diff --git a/toolkits/outlook_calendar/arcade_outlook_calendar/constants.py b/toolkits/outlook_calendar/arcade_outlook_calendar/constants.py new file mode 100644 index 00000000..678e3771 --- /dev/null +++ b/toolkits/outlook_calendar/arcade_outlook_calendar/constants.py @@ -0,0 +1,138 @@ +# Maps "Windows timezone format" to "IANA timezone format" +# Does not include all Windows timezones. +WINDOWS_TO_IANA = { + "Dateline Standard Time": "Etc/GMT+12", + "UTC-11": "Etc/GMT+11", + "Aleutian Standard Time": "America/Adak", + "Hawaiian Standard Time": "Pacific/Honolulu", + "Marquesas Standard Time": "Pacific/Marquesas", + "Alaskan Standard Time": "America/Anchorage", + "UTC-09": "Etc/GMT+9", + "Pacific Standard Time (Mexico)": "America/Tijuana", + "UTC-08": "Etc/GMT+8", + "Pacific Standard Time": "America/Los_Angeles", + "US Mountain Standard Time": "America/Phoenix", + "Mountain Standard Time (Mexico)": "America/Chihuahua", + "Mountain Standard Time": "America/Denver", + "Central America Standard Time": "America/Guatemala", + "Central Standard Time": "America/Chicago", + "Easter Island Standard Time": "Pacific/Easter", + "Central Standard Time (Mexico)": "America/Mexico_City", + "Canada Central Standard Time": "America/Regina", + "SA Pacific Standard Time": "America/Bogota", + "Eastern Standard Time (Mexico)": "America/Cancun", + "Eastern Standard Time": "America/New_York", + "Haiti Standard Time": "America/Port-au-Prince", + "Cuba Standard Time": "America/Havana", + "US Eastern Standard Time": "America/Indianapolis", + "Turks And Caicos Standard Time": "America/Grand_Turk", + "Paraguay Standard Time": "America/Asuncion", + "Atlantic Standard Time": "America/Halifax", + "Venezuela Standard Time": "America/Caracas", + "Central Brazilian Standard Time": "America/Cuiaba", + "SA Western Standard Time": "America/La_Paz", + "Pacific SA Standard Time": "America/Santiago", + "Newfoundland Standard Time": "America/St_Johns", + "Tocantins Standard Time": "America/Araguaina", + "E. South America Standard Time": "America/Sao_Paulo", + "SA Eastern Standard Time": "America/Cayenne", + "Argentina Standard Time": "America/Buenos_Aires", + "Greenland Standard Time": "America/Godthab", + "Montevideo Standard Time": "America/Montevideo", + "Magallanes Standard Time": "America/Punta_Arenas", + "Saint Pierre Standard Time": "America/Miquelon", + "Bahia Standard Time": "America/Bahia", + "UTC-02": "Etc/GMT+2", + "Azores Standard Time": "Atlantic/Azores", + "Cape Verde Standard Time": "Atlantic/Cape_Verde", + "UTC": "Etc/UTC", + "GMT Standard Time": "Europe/London", + "Greenwich Standard Time": "Atlantic/Reykjavik", + "W. Europe Standard Time": "Europe/Berlin", + "Central Europe Standard Time": "Europe/Budapest", + "Romance Standard Time": "Europe/Paris", + "Central European Standard Time": "Europe/Warsaw", + "W. Central Africa Standard Time": "Africa/Lagos", + "Jordan Standard Time": "Asia/Amman", + "GTB Standard Time": "Europe/Bucharest", + "Middle East Standard Time": "Asia/Beirut", + "Egypt Standard Time": "Africa/Cairo", + "E. Europe Standard Time": "Europe/Chisinau", + "Syria Standard Time": "Asia/Damascus", + "West Bank Standard Time": "Asia/Hebron", + "South Africa Standard Time": "Africa/Johannesburg", + "FLE Standard Time": "Europe/Kiev", + "Israel Standard Time": "Asia/Jerusalem", + "Kaliningrad Standard Time": "Europe/Kaliningrad", + "Sudan Standard Time": "Africa/Khartoum", + "Libya Standard Time": "Africa/Tripoli", + "Namibia Standard Time": "Africa/Windhoek", + "Arabic Standard Time": "Asia/Baghdad", + "Turkey Standard Time": "Europe/Istanbul", + "Arab Standard Time": "Asia/Riyadh", + "Belarus Standard Time": "Europe/Minsk", + "Russian Standard Time": "Europe/Moscow", + "E. Africa Standard Time": "Africa/Nairobi", + "Iran Standard Time": "Asia/Tehran", + "Arabian Standard Time": "Asia/Dubai", + "Astrakhan Standard Time": "Europe/Astrakhan", + "Azerbaijan Standard Time": "Asia/Baku", + "Russia Time Zone 3": "Europe/Samara", + "Mauritius Standard Time": "Indian/Mauritius", + "Saratov Standard Time": "Europe/Saratov", + "Georgian Standard Time": "Asia/Tbilisi", + "Volgograd Standard Time": "Europe/Volgograd", + "Caucasus Standard Time": "Asia/Yerevan", + "Afghanistan Standard Time": "Asia/Kabul", + "West Asia Standard Time": "Asia/Tashkent", + "Ekaterinburg Standard Time": "Asia/Yekaterinburg", + "Pakistan Standard Time": "Asia/Karachi", + "India Standard Time": "Asia/Calcutta", + "Sri Lanka Standard Time": "Asia/Colombo", + "Nepal Standard Time": "Asia/Kathmandu", + "Central Asia Standard Time": "Asia/Almaty", + "Bangladesh Standard Time": "Asia/Dhaka", + "Omsk Standard Time": "Asia/Omsk", + "Myanmar Standard Time": "Asia/Rangoon", + "SE Asia Standard Time": "Asia/Bangkok", + "Altai Standard Time": "Asia/Barnaul", + "W. Mongolia Standard Time": "Asia/Hovd", + "North Asia Standard Time": "Asia/Krasnoyarsk", + "N. Central Asia Standard Time": "Asia/Novosibirsk", + "Tomsk Standard Time": "Asia/Tomsk", + "China Standard Time": "Asia/Shanghai", + "North Asia East Standard Time": "Asia/Irkutsk", + "Singapore Standard Time": "Asia/Singapore", + "W. Australia Standard Time": "Australia/Perth", + "Taipei Standard Time": "Asia/Taipei", + "Ulaanbaatar Standard Time": "Asia/Ulaanbaatar", + "North Korea Standard Time": "Asia/Pyongyang", + "Aus Central W. Standard Time": "Australia/Eucla", + "Transbaikal Standard Time": "Asia/Chita", + "Tokyo Standard Time": "Asia/Tokyo", + "Korea Standard Time": "Asia/Seoul", + "Yakutsk Standard Time": "Asia/Yakutsk", + "Cen. Australia Standard Time": "Australia/Adelaide", + "AUS Central Standard Time": "Australia/Darwin", + "E. Australia Standard Time": "Australia/Brisbane", + "AUS Eastern Standard Time": "Australia/Sydney", + "West Pacific Standard Time": "Pacific/Port_Moresby", + "Tasmania Standard Time": "Australia/Hobart", + "Vladivostok Standard Time": "Asia/Vladivostok", + "Lord Howe Standard Time": "Australia/Lord_Howe", + "Bougainville Standard Time": "Pacific/Bougainville", + "Russia Time Zone 10": "Asia/Srednekolymsk", + "Magadan Standard Time": "Asia/Magadan", + "Norfolk Standard Time": "Pacific/Norfolk", + "Sakhalin Standard Time": "Asia/Sakhalin", + "Central Pacific Standard Time": "Pacific/Guadalcanal", + "Russia Time Zone 11": "Asia/Kamchatka", + "New Zealand Standard Time": "Pacific/Auckland", + "UTC+12": "Etc/GMT-12", + "Fiji Standard Time": "Pacific/Fiji", + "Chatham Islands Standard Time": "Pacific/Chatham", + "UTC+13": "Etc/GMT-13", + "Tonga Standard Time": "Pacific/Tongatapu", + "Samoa Standard Time": "Pacific/Apia", + "Line Islands Standard Time": "Pacific/Kiritimati", +} diff --git a/toolkits/outlook_calendar/arcade_outlook_calendar/models.py b/toolkits/outlook_calendar/arcade_outlook_calendar/models.py new file mode 100644 index 00000000..b65201b9 --- /dev/null +++ b/toolkits/outlook_calendar/arcade_outlook_calendar/models.py @@ -0,0 +1,288 @@ +import re +from dataclasses import dataclass, field +from typing import Any + +from bs4 import BeautifulSoup +from msgraph.generated.models.attendee import Attendee as GraphAttendee +from msgraph.generated.models.date_time_time_zone import DateTimeTimeZone as GraphDateTimeTimeZone +from msgraph.generated.models.email_address import EmailAddress as GraphEmailAddress +from msgraph.generated.models.event import Event as GraphEvent +from msgraph.generated.models.event_type import EventType as GraphEventType +from msgraph.generated.models.free_busy_status import FreeBusyStatus as GraphFreeBusyStatus +from msgraph.generated.models.importance import Importance as GraphImportance +from msgraph.generated.models.item_body import ItemBody as GraphItemBody +from msgraph.generated.models.location import Location as GraphLocation +from msgraph.generated.models.recipient import Recipient as GraphRecipient +from msgraph.generated.models.response_status import ResponseStatus as GraphResponseStatus +from msgraph.generated.models.response_type import ResponseType as GraphResponseType + + +@dataclass +class Attendee: + """An attendee of a calendar event.""" + + name: str = "" + address: str = "" + response: str = "" + + @classmethod + def from_sdk(cls, attendee: GraphAttendee) -> "Attendee": + """Convert a Microsoft Graph SDK Attendee object to an Attendee dataclass.""" + return cls( + name=attendee.email_address.name + if attendee.email_address and attendee.email_address.name + else "", + address=attendee.email_address.address + if attendee.email_address and attendee.email_address.address + else "", + response=attendee.status.response + if attendee.status and attendee.status.response + else "", + ) + + def to_dict(self) -> dict[str, str]: + return { + "name": self.name, + "address": self.address, + "response": self.response, + } + + def to_sdk(self) -> GraphAttendee: + """Convert an Attendee dataclass to a Microsoft Graph SDK Attendee object.""" + return GraphAttendee( + email_address=GraphEmailAddress(name=self.name, address=self.address), + status=GraphResponseStatus( + response=GraphResponseType(self.response) + if self.response + else GraphResponseType.None_ + ), + ) + + +@dataclass +class Organizer: + """The organizer of an event.""" + + name: str = "" + address: str = "" + + @classmethod + def from_sdk(cls, organizer: GraphRecipient) -> "Organizer": + """Convert a Microsoft Graph SDK Organizer object to an Organizer dataclass.""" + return cls( + name=organizer.email_address.name + if organizer.email_address and organizer.email_address.name + else "", + address=organizer.email_address.address + if organizer.email_address and organizer.email_address.address + else "", + ) + + def to_dict(self) -> dict[str, str]: + return { + "name": self.name, + "address": self.address, + } + + def to_sdk(self) -> GraphRecipient: + """Convert an Organizer dataclass to a Microsoft Graph SDK Organizer object.""" + recipient = GraphRecipient( + email_address=GraphEmailAddress(name=self.name, address=self.address) + ) + return recipient + + +@dataclass +class DateTimeTimeZone: + """Time information for an event.""" + + date_time: str = "" + time_zone: str = "" + + @classmethod + def from_sdk(cls, date_time_time_zone: GraphDateTimeTimeZone) -> "DateTimeTimeZone": + """Convert a Microsoft Graph SDK DateTimeTimeZone object to a TimeInfo dataclass.""" + return cls( + date_time=date_time_time_zone.date_time or "", + time_zone=date_time_time_zone.time_zone or "", + ) + + def to_dict(self) -> dict[str, str]: + return { + "dateTime": self.date_time, + "timeZone": self.time_zone, + } + + def to_sdk(self) -> GraphDateTimeTimeZone: + """Convert a TimeInfo dataclass to a Microsoft Graph SDK DateTimeTimeZone object.""" + return GraphDateTimeTimeZone(date_time=self.date_time, time_zone=self.time_zone) + + +@dataclass +class ResponseStatus: + """The response status for an event.""" + + response: str = "" + + @classmethod + def from_sdk(cls, response_status: GraphResponseStatus) -> "ResponseStatus": + """Convert a Microsoft Graph SDK ResponseStatus object to a ResponseStatus dataclass.""" + response_value = ( + str(response_status.response.value) + if response_status.response and hasattr(response_status.response, "value") + else "" + ) + return cls(response=response_value) + + def to_dict(self) -> dict[str, str]: + return { + "response": self.response, + } + + def to_sdk(self) -> GraphResponseStatus: + """Convert a ResponseStatus dataclass to a Microsoft Graph SDK ResponseStatus object.""" + return GraphResponseStatus(response=GraphResponseType(self.response)) + + +@dataclass +class Event: + """A calendar event in Outlook.""" + + attendees: list[Attendee] = field(default_factory=list) + body: str = "" + end: DateTimeTimeZone | None = None + has_attachments: bool = False + importance: str = "" + is_all_day: bool = False + is_cancelled: bool = False + is_draft: bool = False + is_online_meeting: bool = False + is_organizer: bool = False + location: str = "" + online_meeting_url: str = "" + organizer: Organizer | None = None + id: str = "" + response_status: ResponseStatus | None = None + show_as: str = "" + start: DateTimeTimeZone | None = None + subject: str = "" + type: str = "" + web_link: str = "" + event_id: str = "" # The unique identifier of the event. Read-only. + + @staticmethod + def _safe_str(value: Any) -> str: + if not value: + return "" + if isinstance(value, bytes | bytearray): + return value.decode("utf-8", errors="ignore") + return str(value) + + @staticmethod + def _safe_bool(value: Any) -> bool: + return bool(value) + + @staticmethod + def _parse_body(mime: str) -> str: + if not mime: + return "" + soup = BeautifulSoup(mime, "html.parser") + text = soup.get_text(separator=" ") + # Replace multiple newlines with a single newline + text = re.sub(r"\n+", "\n", text) + # Replace multiple spaces with a single space + text = re.sub(r"\s+", " ", text) + # Replace sequences of dots (likely from horizontal lines) with a single newline + text = re.sub(r"\.{3,}", "\n---\n", text) + # Remove leading/trailing whitespace from each line + text = "\n".join(line.strip() for line in text.split("\n")) + return text + + @classmethod + def from_sdk(cls, event: GraphEvent) -> "Event": + """Convert a Microsoft Graph SDK Event object to an Event dataclass.""" + body_mime = event.body.content if event.body and event.body.content else "" + body = cls._parse_body(body_mime) + + attendees = [Attendee.from_sdk(a) for a in event.attendees if a] if event.attendees else [] + start = DateTimeTimeZone.from_sdk(event.start) if event.start else None + end = DateTimeTimeZone.from_sdk(event.end) if event.end else None + organizer = Organizer.from_sdk(event.organizer) if event.organizer else None + response_status = ( + ResponseStatus.from_sdk(event.response_status) if event.response_status else None + ) + + return cls( + attendees=attendees, + body=body, + end=end, + has_attachments=cls._safe_bool(event.has_attachments), + importance=cls._safe_str(str(event.importance.value)) if event.importance else "", + is_all_day=cls._safe_bool(event.is_all_day), + is_cancelled=cls._safe_bool(event.is_cancelled), + is_draft=cls._safe_bool(event.is_draft), + is_online_meeting=cls._safe_bool(event.is_online_meeting), + is_organizer=cls._safe_bool(event.is_organizer), + location=cls._safe_str(event.location.display_name if event.location else ""), + online_meeting_url=cls._safe_str(event.online_meeting_url), + organizer=organizer, + id=cls._safe_str(event.id), + response_status=response_status, + show_as=cls._safe_str(str(event.show_as.value)) if event.show_as else "", + start=start, + subject=cls._safe_str(event.subject), + type=cls._safe_str(str(event.type.value)) if event.type else "", + web_link=cls._safe_str(event.web_link), + event_id=cls._safe_str(event.id), + ) + + def to_dict(self) -> dict[str, Any]: + """Converts the Event dataclass to a dictionary.""" + return { + "attendees": [attendee.to_dict() for attendee in self.attendees], + "body": self.body, + "end": self.end.to_dict() if self.end else None, + "has_attachments": self.has_attachments, + "importance": self.importance, + "is_all_day": self.is_all_day, + "is_cancelled": self.is_cancelled, + "is_draft": self.is_draft, + "is_online_meeting": self.is_online_meeting, + "is_organizer": self.is_organizer, + "location": self.location, + "online_meeting_url": self.online_meeting_url, + "organizer": self.organizer.to_dict() if self.organizer else None, + "id": self.id, + "response_status": self.response_status.to_dict() if self.response_status else None, + "show_as": self.show_as, + "start": self.start.to_dict() if self.start else None, + "subject": self.subject, + "type": self.type, + "web_link": self.web_link, + "event_id": self.event_id, + } + + def to_sdk(self) -> GraphEvent: + """Convert an Event dataclass to a Microsoft Graph SDK Event object.""" + return GraphEvent( + attendees=[attendee.to_sdk() for attendee in self.attendees], + body=GraphItemBody(content=self.body), + end=self.end.to_sdk() if self.end else None, + has_attachments=self.has_attachments, + importance=GraphImportance(self.importance) if self.importance else None, + is_all_day=self.is_all_day, + is_cancelled=self.is_cancelled, + is_draft=self.is_draft, + is_online_meeting=self.is_online_meeting, + is_organizer=self.is_organizer, + location=GraphLocation(display_name=self.location), + online_meeting_url=self.online_meeting_url, + organizer=self.organizer.to_sdk() if self.organizer else None, + id=self.id, + response_status=self.response_status.to_sdk() if self.response_status else None, + show_as=GraphFreeBusyStatus(self.show_as) if self.show_as else None, + start=self.start.to_sdk() if self.start else None, + subject=self.subject, + type=GraphEventType(self.type) if self.type else None, + web_link=self.web_link, + ) diff --git a/toolkits/outlook_calendar/arcade_outlook_calendar/tools/__init__.py b/toolkits/outlook_calendar/arcade_outlook_calendar/tools/__init__.py new file mode 100644 index 00000000..f4dde4c0 --- /dev/null +++ b/toolkits/outlook_calendar/arcade_outlook_calendar/tools/__init__.py @@ -0,0 +1,7 @@ +from arcade_outlook_calendar.tools.create_event import create_event +from arcade_outlook_calendar.tools.get_event import get_event +from arcade_outlook_calendar.tools.list_events_in_time_range import ( + list_events_in_time_range, +) + +__all__ = ["create_event", "get_event", "list_events_in_time_range"] diff --git a/toolkits/outlook_calendar/arcade_outlook_calendar/tools/create_event.py b/toolkits/outlook_calendar/arcade_outlook_calendar/tools/create_event.py new file mode 100644 index 00000000..562a1acf --- /dev/null +++ b/toolkits/outlook_calendar/arcade_outlook_calendar/tools/create_event.py @@ -0,0 +1,81 @@ +from typing import Annotated + +from arcade_tdk import ToolContext, tool +from arcade_tdk.auth import Microsoft + +from arcade_outlook_calendar._utils import ( + create_timezone_request_config, + get_default_calendar_timezone, + prepare_meeting_body, + remove_timezone_offset, + validate_date_times, + validate_emails, +) +from arcade_outlook_calendar.client import get_client +from arcade_outlook_calendar.models import ( + Attendee, + DateTimeTimeZone, + Event, +) + + +@tool(requires_auth=Microsoft(scopes=["MailboxSettings.Read", "Calendars.ReadWrite"])) +async def create_event( + context: ToolContext, + subject: Annotated[str, "The text of the event's subject (title) line."], + body: Annotated[str, "The body of the event"], + start_date_time: Annotated[ + str, + "The datetime of the event's start, represented in " + "ISO 8601 format. Timezone offset is ignored. For example, 2025-04-25T13:00:00", + ], + end_date_time: Annotated[ + str, + "The datetime of the event's end, represented in " + "ISO 8601 format. Timezone offset is ignored. For example, 2025-04-25T13:30:00", + ], + location: Annotated[str | None, "The location of the event"] = None, + attendee_emails: Annotated[ + list[str] | None, + "The email addresses of the attendees of the event. " + "Must be valid email addresses e.g., username@domain.com.", + ] = None, + is_online_meeting: Annotated[ + bool, "Whether the event is an online meeting. Defaults to False" + ] = False, + custom_meeting_url: Annotated[ + str | None, + "The URL of the online meeting. If not provided and is_online_meeting is True, " + "then a url will be generated for you", + ] = None, +) -> Annotated[dict, "A dictionary containing the created event details"]: + """Create an event in the authenticated user's default calendar. + + Ignores timezone offsets provided in the start_date_time and end_date_time parameters. + Instead, uses the user's default calendar timezone to filter events. + If the user has not set a timezone for their calendar, then the timezone will be UTC. + """ + # Validate & cleanse inputs + validate_emails(attendee_emails or []) + validate_date_times(start_date_time, end_date_time) + body, is_online_meeting = prepare_meeting_body(body, custom_meeting_url, is_online_meeting) + + client = get_client(context.get_auth_token_or_empty()) + + time_zone = await get_default_calendar_timezone(client) + start_date_time = remove_timezone_offset(start_date_time) + end_date_time = remove_timezone_offset(end_date_time) + event = Event( + subject=subject, + body=body, + start=DateTimeTimeZone(date_time=start_date_time, time_zone=time_zone), + end=DateTimeTimeZone(date_time=end_date_time, time_zone=time_zone), + location=location or "", + attendees=[Attendee(address=attendee) for attendee in attendee_emails or []], + is_online_meeting=is_online_meeting, + ).to_sdk() + request_config = create_timezone_request_config(time_zone) + + response = await client.me.events.post(body=event, request_configuration=request_config) + + return Event.from_sdk(response).to_dict() # type: ignore[arg-type] diff --git a/toolkits/outlook_calendar/arcade_outlook_calendar/tools/get_event.py b/toolkits/outlook_calendar/arcade_outlook_calendar/tools/get_event.py new file mode 100644 index 00000000..901ce251 --- /dev/null +++ b/toolkits/outlook_calendar/arcade_outlook_calendar/tools/get_event.py @@ -0,0 +1,29 @@ +from typing import Annotated + +from arcade_tdk import ToolContext, tool +from arcade_tdk.auth import Microsoft + +from arcade_outlook_calendar._utils import ( + create_timezone_request_config, + get_default_calendar_timezone, +) +from arcade_outlook_calendar.client import get_client +from arcade_outlook_calendar.models import Event + + +@tool(requires_auth=Microsoft(scopes=["MailboxSettings.Read", "Calendars.ReadBasic"])) +async def get_event( + context: ToolContext, + event_id: Annotated[str, "The ID of the event to get"], +) -> Annotated[dict, "A dictionary containing the event details"]: + """Get an event by its ID from the user's calendar.""" + client = get_client(context.get_auth_token_or_empty()) + + time_zone = await get_default_calendar_timezone(client) + request_config = create_timezone_request_config(time_zone) + + response = await client.me.events.by_event_id(event_id).get( + request_configuration=request_config + ) + + return Event.from_sdk(response).to_dict() # type: ignore[arg-type] diff --git a/toolkits/outlook_calendar/arcade_outlook_calendar/tools/list_events_in_time_range.py b/toolkits/outlook_calendar/arcade_outlook_calendar/tools/list_events_in_time_range.py new file mode 100644 index 00000000..d33d32fd --- /dev/null +++ b/toolkits/outlook_calendar/arcade_outlook_calendar/tools/list_events_in_time_range.py @@ -0,0 +1,59 @@ +from typing import Annotated + +from arcade_tdk import ToolContext, tool +from arcade_tdk.auth import Microsoft +from msgraph.generated.users.item.calendar.calendar_view.calendar_view_request_builder import ( + CalendarViewRequestBuilder, +) + +from arcade_outlook_calendar._utils import ( + convert_timezone_to_offset, + create_timezone_request_config, + get_default_calendar_timezone, + replace_timezone_offset, + validate_date_times, +) +from arcade_outlook_calendar.client import get_client +from arcade_outlook_calendar.models import Event + + +@tool(requires_auth=Microsoft(scopes=["MailboxSettings.Read", "Calendars.ReadBasic"])) +async def list_events_in_time_range( + context: ToolContext, + start_date_time: Annotated[ + str, + "The start date and time of the time range, represented in " + "ISO 8601 format. Timezone offset is ignored. For example, 2025-04-24T19:00:00", + ], + end_date_time: Annotated[ + str, + "The end date and time of the time range, represented in " + "ISO 8601 format. Timezone offset is ignored. For example, 2025-04-24T19:30:00", + ], + limit: Annotated[int, "The maximum number of events to return. Max 1000. Defaults to 10"] = 10, +) -> Annotated[dict, "A dictionary containing a list of events"]: + """List events in the user's calendar in a specific time range. + + Ignores timezone offsets provided in the start_date_time and end_date_time parameters. + Instead, uses the user's default calendar timezone to filter events. + If the user has not set a timezone for their calendar, then the timezone will be UTC. + """ + # Validate inputs + validate_date_times(start_date_time, end_date_time) + + client = get_client(context.get_auth_token_or_empty()) + time_zone = await get_default_calendar_timezone(client) + time_zone_offset = convert_timezone_to_offset(time_zone) + start_date_time = replace_timezone_offset(start_date_time, time_zone_offset) + end_date_time = replace_timezone_offset(end_date_time, time_zone_offset) + query_params = CalendarViewRequestBuilder.CalendarViewRequestBuilderGetQueryParameters( + start_date_time=start_date_time, + end_date_time=end_date_time, + top=max(1, min(limit, 1000)), + ) + request_config = create_timezone_request_config(time_zone, query_params) + + response = await client.me.calendar.calendar_view.get(request_config) + events = [Event.from_sdk(event).to_dict() for event in response.value] # type: ignore[union-attr] + + return {"events": events, "num_events": len(events)} diff --git a/toolkits/outlook_calendar/evals/additional_messages.py b/toolkits/outlook_calendar/evals/additional_messages.py new file mode 100644 index 00000000..ab54780e --- /dev/null +++ b/toolkits/outlook_calendar/evals/additional_messages.py @@ -0,0 +1,28 @@ +get_event_additional_messages = [ + {"role": "system", "content": "Today is 2025-04-22, Tuesday."}, + {"role": "user", "content": "show me my meetings for today"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_def456", + "type": "function", + "function": { + "name": "Microsoft_ListEventsInTimeRange", + "arguments": '{"start_date_time":"2025-04-22T00:00:00","end_date_time":"2025-04-22T23:59:59"}', + }, + } + ], + }, + { + "role": "tool", + "content": '{"events":[{"attendees":[{"email":"john@example.com","name":"John Smith"},{"email":"alice@example.com","name":"Alice Johnson"}],"body":"Quarterly review meeting","end":{"date_time":"2025-04-22T15:00:00.0000000","time_zone":"Pacific Standard Time"},"has_attachments":true,"id":"AAMkADAwATM0MDAAMi04Y2Y1LTQ3MTEALTAwAi0wMAoARgAAAyXxSd3UxTpCkDpGouEg0JMBAFuxokOLZRtDncM4","importance":"high","is_all_day":false,"is_cancelled":false,"is_online_meeting":true,"is_organizer":true,"location":"Online","online_meeting_url":"https://teams.microsoft.com/l/meetup-join/meeting_id","organizer":{"email":"user@example.com","name":"User Name"},"start":{"date_time":"2025-04-22T14:00:00.0000000","time_zone":"Pacific Standard Time"},"subject":"Q1 Review","web_link":"https://outlook.office365.com/owa/?itemid=AAMkADAwATM0MDAAMi04Y2Y1LTQ3MTEALTAwAi0wMAoARgAAAyXxSd3UxTpCkDpGouEg0JMBAFuxokOLZRtDncM4&exvsurl=1&path=/calendar/item"}],"num_events":1}', + "tool_call_id": "call_def456", + "name": "Microsoft_ListEventsInTimeRange", + }, + { + "role": "assistant", + "content": "You have 1 meeting scheduled for today:\n\n1. **Q1 Review** - Today, 2:00 PM - 3:00 PM\n Location: Online\n Attendees: John Smith, Alice Johnson\n This is a high importance meeting with attachments.", + }, +] diff --git a/toolkits/outlook_calendar/evals/eval_create_event.py b/toolkits/outlook_calendar/evals/eval_create_event.py new file mode 100644 index 00000000..b1eefd13 --- /dev/null +++ b/toolkits/outlook_calendar/evals/eval_create_event.py @@ -0,0 +1,94 @@ +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + SimilarityCritic, + tool_eval, +) +from arcade_tdk import ToolCatalog + +from arcade_outlook_calendar import create_event + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.9, + warn_threshold=0.95, +) + +catalog = ToolCatalog() +catalog.add_tool(create_event, "OutlookCalendar") + + +@tool_eval() +def outlook_calendar_create_event_eval_suite() -> EvalSuite: + """Create an evaluation suite for Outlook Calendar create event tool.""" + suite = EvalSuite( + name="Outlook Calendar Create Event Evaluation", + system_message=( + "You are an AI that has access to tools to view and manage calendar events. " + "The current time date and time is April 25, 2025, 5:18 PM PST." + ), + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="Create virtual event", + user_message=( + "schedule a virtual team meeting 'Standup' tomorrow at 3pm for 1 hour. " + "john@example.com and sarah@example.com need to be there" + ), + expected_tool_calls=[ + ExpectedToolCall( + func=create_event, + args={ + "subject": "Standup", + "start_date_time": "2025-04-26T15:00:00", + "end_date_time": "2025-04-26T16:00:00", + "attendee_emails": ["john@example.com", "sarah@example.com"], + "is_online_meeting": True, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="subject", weight=1 / 5), + BinaryCritic(critic_field="start_date_time", weight=1 / 5), + BinaryCritic(critic_field="end_date_time", weight=1 / 5), + BinaryCritic(critic_field="attendee_emails", weight=1 / 5), + BinaryCritic(critic_field="is_online_meeting", weight=1 / 5), + ], + ) + + suite.add_case( + name="Create event with physical location and virtual link", + user_message=( + "schedule a team meeting 'All hands' tomorrow at 3pm for 1 hour. " + "john@example.com and sarah@example.com need to be there. " + "The meeting will be in Conference Room A, but there will be a virtual link " + "for those who cannot attend in person." + ), + expected_tool_calls=[ + ExpectedToolCall( + func=create_event, + args={ + "subject": "All hands", + "start_date_time": "2025-04-26T15:00:00", + "end_date_time": "2025-04-26T16:00:00", + "location": "Conference Room A", + "attendee_emails": ["john@example.com", "sarah@example.com"], + "is_online_meeting": True, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="subject", weight=1 / 6), + BinaryCritic(critic_field="start_date_time", weight=1 / 6), + BinaryCritic(critic_field="end_date_time", weight=1 / 6), + SimilarityCritic(critic_field="location", weight=1 / 6), + BinaryCritic(critic_field="attendee_emails", weight=1 / 6), + BinaryCritic(critic_field="is_online_meeting", weight=1 / 6), + ], + ) + + return suite diff --git a/toolkits/outlook_calendar/evals/eval_get_event.py b/toolkits/outlook_calendar/evals/eval_get_event.py new file mode 100644 index 00000000..d2274c70 --- /dev/null +++ b/toolkits/outlook_calendar/evals/eval_get_event.py @@ -0,0 +1,53 @@ +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + tool_eval, +) +from arcade_tdk import ToolCatalog + +from arcade_outlook_calendar import get_event +from evals.additional_messages import get_event_additional_messages + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.9, + warn_threshold=0.95, +) + +catalog = ToolCatalog() +catalog.add_tool(get_event, "OutlookCalendar") + + +@tool_eval() +def outlook_calendar_get_event_eval_suite() -> EvalSuite: + """Create an evaluation suite for Outlook Calendar get event tool.""" + suite = EvalSuite( + name="Outlook Calendar Get Event Evaluation", + system_message=( + "You are an AI that has access to tools to view and manage calendar events. " + "The current time date and time is April 25, 2025, 5:18 PM PST." + ), + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="Get event by id after listing events", + user_message="tell me more about the first event", + expected_tool_calls=[ + ExpectedToolCall( + func=get_event, + args={ + "event_id": "AAMkADAwATM0MDAAMi04Y2Y1LTQ3MTEALTAwAi0wMAoARgAAAyXxSd3UxTpCkDpGouEg0JMBAFuxokOLZRtDncM4", + }, + ) + ], + critics=[ + BinaryCritic(critic_field="event_id", weight=1.0), + ], + additional_messages=get_event_additional_messages, + ) + + return suite diff --git a/toolkits/outlook_calendar/evals/eval_list_events_in_time_range.py b/toolkits/outlook_calendar/evals/eval_list_events_in_time_range.py new file mode 100644 index 00000000..a917516e --- /dev/null +++ b/toolkits/outlook_calendar/evals/eval_list_events_in_time_range.py @@ -0,0 +1,77 @@ +from arcade_evals import ( + BinaryCritic, + DatetimeCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + tool_eval, +) +from arcade_tdk import ToolCatalog + +from arcade_outlook_calendar import list_events_in_time_range + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.9, + warn_threshold=0.95, +) + +catalog = ToolCatalog() +catalog.add_tool(list_events_in_time_range, "OutlookCalendar") + + +@tool_eval() +def outlook_calendar_list_events_in_time_range_eval_suite() -> EvalSuite: + """Create an evaluation suite for Outlook Calendar list events tool.""" + suite = EvalSuite( + name="Outlook Calendar List Events Evaluation", + system_message=( + "You are an AI that has access to tools to view and manage calendar events. " + "The current time date and time is Friday, April 25, 2025, 5:18 PM PST." + ), + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="List events in time range", + user_message="what are my meetings on monday", + expected_tool_calls=[ + ExpectedToolCall( + func=list_events_in_time_range, + args={ + "start_date_time": "2025-04-28T00:00:00", + "end_date_time": "2025-04-28T23:59:59", + }, + ) + ], + critics=[ + DatetimeCritic(critic_field="start_date_time", weight=0.5), + DatetimeCritic(critic_field="end_date_time", weight=0.5), + ], + ) + + suite.add_case( + name="List events in time range with limit", + user_message=( + "get my first 10 meetings for the next work-week through thursday, " + "starting tuesday (mon is holiday)" + ), + expected_tool_calls=[ + ExpectedToolCall( + func=list_events_in_time_range, + args={ + "start_date_time": "2025-04-29T00:00:00", + "end_date_time": "2025-05-01T23:59:59", + "limit": 10, + }, + ) + ], + critics=[ + DatetimeCritic(critic_field="start_date_time", weight=0.3), + DatetimeCritic(critic_field="end_date_time", weight=0.3), + BinaryCritic(critic_field="limit", weight=0.4), + ], + ) + + return suite diff --git a/toolkits/outlook_calendar/pyproject.toml b/toolkits/outlook_calendar/pyproject.toml new file mode 100644 index 00000000..a9f7fb48 --- /dev/null +++ b/toolkits/outlook_calendar/pyproject.toml @@ -0,0 +1,61 @@ +[build-system] +requires = [ "hatchling",] +build-backend = "hatchling.build" + +[project] +name = "arcade_outlook_calendar" +version = "1.0.0" +description = "rcade.dev LLM tools for Outlook Calendar" +requires-python = ">=3.10" +dependencies = [ + "arcade-tdk>=2.0.0,<3.0.0", + "msgraph-sdk>=1.28.0,<2.0.0", + "beautifulsoup4>=4.10.0,<5.0.0", + "pytz>=2024.2,<2025.0.0", +] +[[project.authors]] +name = "Arcade" +email = "dev@arcade.dev" + + +[project.optional-dependencies] +dev = [ + "arcade-ai[evals]>=2.0.0,<3.0.0", + "arcade-serve>=2.0.0,<3.0.0", + "pytest>=8.3.0,<9.0.0", + "pytest-cov>=4.0.0,<5.0.0", + "pytest-mock>=3.11.1,<4.0.0", + "pytest-asyncio>=0.24.0,<1.0.0", + "mypy>=1.5.1,<2.0.0", + "pre-commit>=3.4.0,<4.0.0", + "tox>=4.11.1,<5.0.0", + "ruff>=0.7.4,<1.0.0", +] + +# Use local path sources for arcade libs when working locally +[tool.uv.sources] +arcade-ai = { path = "../../", editable = true } +arcade-serve = { path = "../../libs/arcade-serve/", editable = true } +arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } + + +[tool.mypy] +files = [ "arcade_outlook_calendar/**/*.py",] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes = "True" +ignore_missing_imports = "True" + +[tool.pytest.ini_options] +testpaths = [ "tests",] + +[tool.coverage.report] +skip_empty = true + +[tool.hatch.build.targets.wheel] +packages = [ "arcade_outlook_calendar",] diff --git a/toolkits/outlook_calendar/tests/__init__.py b/toolkits/outlook_calendar/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/toolkits/outlook_calendar/tests/test_models.py b/toolkits/outlook_calendar/tests/test_models.py new file mode 100644 index 00000000..480229c9 --- /dev/null +++ b/toolkits/outlook_calendar/tests/test_models.py @@ -0,0 +1,385 @@ +import pytest +from msgraph.generated.models.attendee import Attendee as GraphAttendee +from msgraph.generated.models.date_time_time_zone import DateTimeTimeZone as GraphDateTimeTimeZone +from msgraph.generated.models.email_address import EmailAddress as GraphEmailAddress +from msgraph.generated.models.event import Event as GraphEvent +from msgraph.generated.models.location import Location as GraphLocation +from msgraph.generated.models.recipient import Recipient as GraphRecipient +from msgraph.generated.models.response_status import ResponseStatus as GraphResponseStatus +from msgraph.generated.models.response_type import ResponseType as GraphResponseType + +from arcade_outlook_calendar.models import ( + Attendee, + DateTimeTimeZone, + Event, + Organizer, + ResponseStatus, +) + + +class DummyBody: + def __init__(self, content): + self.content = content + + +class DummyEventType: + def __init__(self, value): + self.value = value + + +class DummyImportance: + def __init__(self, value): + self.value = value + + +class DummyFreeBusyStatus: + def __init__(self, value): + self.value = value + + +class DummyResponseType: + def __init__(self, value): + self.value = value + + +@pytest.mark.parametrize( + "input_data, expected", + [ + ( + {"name": "John Doe", "address": "john.doe@example.com", "response": "accepted"}, + {"name": "John Doe", "address": "john.doe@example.com", "response": "accepted"}, + ), + ( + {"name": "", "address": "anonymous@example.com", "response": "tentativelyAccepted"}, + {"name": "", "address": "anonymous@example.com", "response": "tentativelyAccepted"}, + ), + ( + {"name": None, "address": None, "response": "none"}, + {"name": "", "address": "", "response": "none"}, + ), + ], +) +def test_attendee_conversion(input_data, expected): + sdk_attendee = GraphAttendee() + sdk_attendee.email_address = GraphEmailAddress() + sdk_attendee.email_address.name = input_data["name"] + sdk_attendee.email_address.address = input_data["address"] + sdk_attendee.status = GraphResponseStatus() + sdk_attendee.status.response = GraphResponseType(input_data["response"]) + + # Test from_sdk method + attendee = Attendee.from_sdk(sdk_attendee) + assert attendee.name == expected["name"] + assert attendee.address == expected["address"] + assert attendee.response == expected["response"] + + # Test to_dict method + dict_result = attendee.to_dict() + assert dict_result == expected + + # Test to_sdk method + sdk_result = attendee.to_sdk() + assert sdk_result.email_address.name == expected["name"] + assert sdk_result.email_address.address == expected["address"] + assert sdk_result.status.response == GraphResponseType(expected["response"]) + + +@pytest.mark.parametrize( + "input_data, expected", + [ + ( + {"name": "Jane Smith", "address": "jane.smith@example.com"}, + {"name": "Jane Smith", "address": "jane.smith@example.com"}, + ), + ( + {"name": "", "address": "unknown@example.com"}, + {"name": "", "address": "unknown@example.com"}, + ), + ({"name": None, "address": None}, {"name": "", "address": ""}), + ], +) +def test_organizer_conversion(input_data, expected): + sdk_organizer = GraphRecipient() + sdk_organizer.email_address = GraphEmailAddress() + sdk_organizer.email_address.name = input_data["name"] + sdk_organizer.email_address.address = input_data["address"] + + # Test from_sdk method + organizer = Organizer.from_sdk(sdk_organizer) + assert organizer.name == expected["name"] + assert organizer.address == expected["address"] + + # Test to_dict method + dict_result = organizer.to_dict() + assert dict_result == expected + + # Test to_sdk method + sdk_result = organizer.to_sdk() + assert sdk_result.email_address.name == expected["name"] + assert sdk_result.email_address.address == expected["address"] + + +@pytest.mark.parametrize( + "input_data, expected", + [ + ( + {"date_time": "2023-05-10T14:00:00", "time_zone": "Pacific Standard Time"}, + {"dateTime": "2023-05-10T14:00:00", "timeZone": "Pacific Standard Time"}, + ), + ({"date_time": "", "time_zone": "UTC"}, {"dateTime": "", "timeZone": "UTC"}), + ({"date_time": None, "time_zone": None}, {"dateTime": "", "timeZone": ""}), + ], +) +def test_date_time_time_zone_conversion(input_data, expected): + sdk_date_time = GraphDateTimeTimeZone() + sdk_date_time.date_time = input_data["date_time"] + sdk_date_time.time_zone = input_data["time_zone"] + + # Test from_sdk method + date_time_tz = DateTimeTimeZone.from_sdk(sdk_date_time) + assert date_time_tz.date_time == (input_data["date_time"] or "") + assert date_time_tz.time_zone == (input_data["time_zone"] or "") + + # Test to_dict method + dict_result = date_time_tz.to_dict() + assert dict_result == expected + + # Test to_sdk method + sdk_result = date_time_tz.to_sdk() + assert sdk_result.date_time == date_time_tz.date_time + assert sdk_result.time_zone == date_time_tz.time_zone + + +@pytest.mark.parametrize( + "input_data, expected", + [ + ({"response": "accepted"}, {"response": "accepted"}), + ({"response": "declined"}, {"response": "declined"}), + ({"response": "none"}, {"response": "none"}), + ], +) +def test_response_status_conversion(input_data, expected): + sdk_response_status = GraphResponseStatus() + sdk_response_status.response = GraphResponseType(input_data["response"]) + + # Test from_sdk method + response_status = ResponseStatus.from_sdk(sdk_response_status) + assert response_status.response == expected["response"] + + # Test to_dict method + dict_result = response_status.to_dict() + assert dict_result == expected + + # Test to_sdk method + sdk_result = response_status.to_sdk() + assert sdk_result.response == GraphResponseType(expected["response"]) + + +@pytest.mark.parametrize( + "input_data, expected", + [ + ( + { + "body_content": "

Team Meeting

", + "has_attachments": True, + "importance": "high", + "is_all_day": False, + "is_cancelled": False, + "is_draft": False, + "is_online_meeting": True, + "is_organizer": True, + "location": "Conference Room A", + "online_meeting_url": "https://teams.microsoft.com/l/meetup-join/123", + "id": "event-123", + "show_as": "busy", + "subject": "Weekly Team Sync", + "type": "singleInstance", + "web_link": "https://outlook.office.com/calendar/item/123", + "attendees": [ + {"name": "Alice", "address": "alice@example.com", "response": "accepted"}, + { + "name": "Bob", + "address": "bob@example.com", + "response": "tentativelyAccepted", + }, + ], + "organizer": {"name": "Manager", "address": "manager@example.com"}, + "start": {"date_time": "2023-05-10T10:00:00", "time_zone": "Eastern Standard Time"}, + "end": {"date_time": "2023-05-10T11:00:00", "time_zone": "Eastern Standard Time"}, + "response_status": {"response": "accepted"}, + }, + { + "body": "Team Meeting", + "has_attachments": True, + "importance": "high", + "is_all_day": False, + "is_cancelled": False, + "is_draft": False, + "is_online_meeting": True, + "is_organizer": True, + "location": "Conference Room A", + "online_meeting_url": "https://teams.microsoft.com/l/meetup-join/123", + "id": "event-123", + "show_as": "busy", + "subject": "Weekly Team Sync", + "type": "singleInstance", + "web_link": "https://outlook.office.com/calendar/item/123", + "event_id": "event-123", + "attendees": [ + {"name": "Alice", "address": "alice@example.com", "response": "accepted"}, + { + "name": "Bob", + "address": "bob@example.com", + "response": "tentativelyAccepted", + }, + ], + "organizer": {"name": "Manager", "address": "manager@example.com"}, + "start": {"dateTime": "2023-05-10T10:00:00", "timeZone": "Eastern Standard Time"}, + "end": {"dateTime": "2023-05-10T11:00:00", "timeZone": "Eastern Standard Time"}, + "response_status": {"response": "accepted"}, + }, + ), + ( + { + "body_content": "

All day event description

", + "has_attachments": False, + "importance": "normal", + "is_all_day": True, + "is_cancelled": True, + "is_draft": True, + "is_online_meeting": False, + "is_organizer": False, + "location": "", + "online_meeting_url": "", + "id": "event-456", + "show_as": "free", + "subject": "Company Holiday", + "type": "occurrence", + "web_link": "https://outlook.office.com/calendar/item/456", + "attendees": [], + "organizer": {"name": "HR Department", "address": "hr@example.com"}, + "start": {"date_time": "2023-07-04T00:00:00", "time_zone": "UTC"}, + "end": {"date_time": "2023-07-05T00:00:00", "time_zone": "UTC"}, + "response_status": {"response": "notResponded"}, + }, + { + "body": "All day event description", + "has_attachments": False, + "importance": "normal", + "is_all_day": True, + "is_cancelled": True, + "is_draft": True, + "is_online_meeting": False, + "is_organizer": False, + "location": "", + "online_meeting_url": "", + "id": "event-456", + "show_as": "free", + "subject": "Company Holiday", + "type": "occurrence", + "web_link": "https://outlook.office.com/calendar/item/456", + "event_id": "event-456", + "attendees": [], + "organizer": {"name": "HR Department", "address": "hr@example.com"}, + "start": {"dateTime": "2023-07-04T00:00:00", "timeZone": "UTC"}, + "end": {"dateTime": "2023-07-05T00:00:00", "timeZone": "UTC"}, + "response_status": {"response": "notResponded"}, + }, + ), + ], +) +def test_event_conversion(input_data, expected): + def make_graph_attendee(attendee_data): + attendee = GraphAttendee() + attendee.email_address = GraphEmailAddress() + attendee.email_address.name = attendee_data.get("name", "") + attendee.email_address.address = attendee_data.get("address", "") + attendee.status = GraphResponseStatus() + attendee.status.response = GraphResponseType(attendee_data.get("response", "")) + return attendee + + def make_graph_organizer(organizer_data): + organizer = GraphRecipient() + organizer.email_address = GraphEmailAddress() + organizer.email_address.name = organizer_data.get("name", "") + organizer.email_address.address = organizer_data.get("address", "") + return organizer + + def make_graph_date_time(date_time_data): + date_time = GraphDateTimeTimeZone() + date_time.date_time = date_time_data.get("date_time", "") + date_time.time_zone = date_time_data.get("time_zone", "") + return date_time + + sdk_event = GraphEvent() + sdk_event.body = DummyBody(input_data["body_content"]) + sdk_event.has_attachments = input_data["has_attachments"] + sdk_event.importance = DummyImportance(input_data["importance"]) + sdk_event.is_all_day = input_data["is_all_day"] + sdk_event.is_cancelled = input_data["is_cancelled"] + sdk_event.is_draft = input_data["is_draft"] + sdk_event.is_online_meeting = input_data["is_online_meeting"] + sdk_event.is_organizer = input_data["is_organizer"] + sdk_event.location = GraphLocation(display_name=input_data["location"]) + sdk_event.online_meeting_url = input_data["online_meeting_url"] + sdk_event.id = input_data["id"] + sdk_event.show_as = DummyFreeBusyStatus(input_data["show_as"]) + sdk_event.subject = input_data["subject"] + sdk_event.type = DummyEventType(input_data["type"]) + sdk_event.web_link = input_data["web_link"] + sdk_event.attendees = [make_graph_attendee(a) for a in input_data["attendees"]] + sdk_event.organizer = make_graph_organizer(input_data["organizer"]) + sdk_event.start = make_graph_date_time(input_data["start"]) + sdk_event.end = make_graph_date_time(input_data["end"]) + sdk_event.response_status = GraphResponseStatus() + sdk_event.response_status.response = GraphResponseType( + input_data["response_status"]["response"] + ) + + # Test from_sdk method + event = Event.from_sdk(sdk_event) + assert event.body == expected["body"] + assert event.has_attachments == expected["has_attachments"] + assert event.importance == expected["importance"] + assert event.is_all_day == expected["is_all_day"] + assert event.is_cancelled == expected["is_cancelled"] + assert event.is_draft == expected["is_draft"] + assert event.is_online_meeting == expected["is_online_meeting"] + assert event.is_organizer == expected["is_organizer"] + assert event.location == expected["location"] + assert event.online_meeting_url == expected["online_meeting_url"] + assert event.id == expected["id"] + assert event.show_as == expected["show_as"] + assert event.subject == expected["subject"] + assert event.type == expected["type"] + assert event.web_link == expected["web_link"] + assert event.event_id == expected["event_id"] + assert len(event.attendees) == len(expected["attendees"]) + for i, attendee in enumerate(event.attendees): + assert attendee.name == expected["attendees"][i]["name"] + assert attendee.address == expected["attendees"][i]["address"] + assert attendee.response == expected["attendees"][i]["response"] + if event.start: + assert event.start.date_time == expected["start"]["dateTime"] + assert event.start.time_zone == expected["start"]["timeZone"] + if event.end: + assert event.end.date_time == expected["end"]["dateTime"] + assert event.end.time_zone == expected["end"]["timeZone"] + if event.organizer: + assert event.organizer.name == expected["organizer"]["name"] + assert event.organizer.address == expected["organizer"]["address"] + if event.response_status: + assert event.response_status.response == expected["response_status"]["response"] + + # Test to_dict method + dict_result = event.to_dict() + assert dict_result["body"] == expected["body"] + assert dict_result["subject"] == expected["subject"] + assert dict_result["event_id"] == expected["event_id"] + + # Test to_sdk method + sdk_result = event.to_sdk() + assert sdk_result.subject == event.subject + assert sdk_result.is_all_day == event.is_all_day + assert sdk_result.location.display_name == event.location + assert len(sdk_result.attendees) == len(event.attendees) diff --git a/toolkits/outlook_calendar/tests/test_utils.py b/toolkits/outlook_calendar/tests/test_utils.py new file mode 100644 index 00000000..6fc59b18 --- /dev/null +++ b/toolkits/outlook_calendar/tests/test_utils.py @@ -0,0 +1,118 @@ +import pytest +from arcade_tdk.errors import ToolExecutionError + +from arcade_outlook_calendar._utils import ( + convert_timezone_to_offset, + is_valid_email, + remove_timezone_offset, + replace_timezone_offset, + validate_date_times, + validate_emails, +) + + +@pytest.mark.parametrize( + "start_date_time, end_date_time, error_type", + [ + ( + "2026-01-01T10:00:00", + "2026-01-01T17:00:00", + None, + ), + # end_date_time before start_date_time + ( + "2026-01-01T10:00:00", + "2026-01-01T10:00:00", + ToolExecutionError, + ), + # end_date_time before start_date_time because timezone offset is ignored + ( + "2026-01-01T10:00:00-07:00", + "2026-01-01T09:00:00-08:00", + ToolExecutionError, + ), + # not ISO 8601 format + ( + "20260101T10:00:00", + "2026-01-0109:00:00", + ValueError, + ), + ], +) +def test_validate_date_times(start_date_time, end_date_time, error_type): + if error_type: + with pytest.raises(error_type): + validate_date_times(start_date_time, end_date_time) + else: + validate_date_times(start_date_time, end_date_time) + + +@pytest.mark.parametrize( + "emails, expect_error", + [ + (["test@test.com"], False), + (["test@test.com", "test@test.com.au"], False), + (["test@test.com", "test@test.com.au."], True), + (["#$&*@test.com"], True), + ], +) +def test_validate_emails(emails, expect_error): + if expect_error: + with pytest.raises(ToolExecutionError): + validate_emails(emails) + else: + validate_emails(emails) + + +@pytest.mark.parametrize( + "email, is_valid", + [ + ("test@test.com", True), + ("test@test", False), + ("test@test.com.au", True), + ("test@test.com.au.", False), + ], +) +def test_is_valid_email(email, is_valid): + assert is_valid_email(email) == is_valid + + +@pytest.mark.parametrize( + "input_date_time, expected_date_time", + [ + ("2021-01-01T10:00:00+07:00", "2021-01-01T10:00:00"), + ("2021-01-01T10:00:00-07:00", "2021-01-01T10:00:00"), + ("2021-01-01T10:00:00Z", "2021-01-01T10:00:00"), + ], +) +def test_remove_timezone_offset(input_date_time, expected_date_time): + assert remove_timezone_offset(input_date_time) == expected_date_time + + +@pytest.mark.parametrize( + "input_date_time, time_zone_offset, expected_date_time", + [ + # without existing offset + ("2021-01-01T10:00:00", "+07:00", "2021-01-01T10:00:00+07:00"), + ("2021-01-01T10:00:00", "-07:00", "2021-01-01T10:00:00-07:00"), + ("2021-01-01T10:00:00", "Z", "2021-01-01T10:00:00Z"), + # with existing offset + ("2021-01-01T10:00:00+07:00", "+04:00", "2021-01-01T10:00:00+04:00"), + ("2021-01-01T10:00:00-07:00", "-09:00", "2021-01-01T10:00:00-09:00"), + ("2021-01-01T10:00:00-07:00", "Z", "2021-01-01T10:00:00Z"), + ], +) +def test_replace_timezone_offset(input_date_time, time_zone_offset, expected_date_time): + assert replace_timezone_offset(input_date_time, time_zone_offset) == expected_date_time + + +@pytest.mark.parametrize( + "time_zone, expected_offset", + [ + ("Central Asia Standard Time", "+05:00"), # Windows timezone format + ("America/New_York", "-04:00"), # IANA timezone format + ("Not a valid timezone", "Z"), # Fallback to UTC + ], +) +def test_convert_timezone_to_offset(time_zone, expected_offset): + assert convert_timezone_to_offset(time_zone) == expected_offset diff --git a/toolkits/outlook_mail/.pre-commit-config.yaml b/toolkits/outlook_mail/.pre-commit-config.yaml new file mode 100644 index 00000000..0672d232 --- /dev/null +++ b/toolkits/outlook_mail/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +files: ^.*/outlook_mail/.* +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.4.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/toolkits/outlook_mail/.ruff.toml b/toolkits/outlook_mail/.ruff.toml new file mode 100644 index 00000000..19364180 --- /dev/null +++ b/toolkits/outlook_mail/.ruff.toml @@ -0,0 +1,47 @@ +target-version = "py310" +line-length = 100 +fix = true + +[lint] +select = [ + # flake8-2020 + "YTT", + # flake8-bandit + "S", + # flake8-bugbear + "B", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-simplify + "SIM", + # isort + "I", + # mccabe + "C90", + # pycodestyle + "E", "W", + # pyflakes + "F", + # pygrep-hooks + "PGH", + # pyupgrade + "UP", + # ruff + "RUF", + # tryceratops + "TRY", +] + +[lint.per-file-ignores] +"*" = ["TRY003", "B904"] +"**/tests/*" = ["S101", "E501"] +"**/evals/*" = ["S101", "E501"] + + +[format] +preview = true +skip-magic-trailing-comma = false diff --git a/toolkits/outlook_mail/LICENSE b/toolkits/outlook_mail/LICENSE new file mode 100644 index 00000000..dfbb8b76 --- /dev/null +++ b/toolkits/outlook_mail/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025, Arcade AI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/toolkits/outlook_mail/Makefile b/toolkits/outlook_mail/Makefile new file mode 100644 index 00000000..0a8969be --- /dev/null +++ b/toolkits/outlook_mail/Makefile @@ -0,0 +1,55 @@ +.PHONY: help + +help: + @echo "🛠️ github Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the uv environment and install all packages with dependencies + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras --no-sources + @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: install-local +install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras + @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + uv build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + @uv run --no-sources coverage report + @echo "Generating coverage report" + @uv run --no-sources coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file by a patch version + @echo "🚀 Bumping version in pyproject.toml" + uv version --no-sources --bump patch + +.PHONY: check +check: ## Run code quality tools. + @if [ -f .pre-commit-config.yaml ]; then\ + echo "🚀 Linting code: Running pre-commit";\ + uv run --no-sources pre-commit run -a;\ + fi + @echo "🚀 Static type checking: Running mypy" + @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/outlook_mail/arcade_outlook_mail/__init__.py b/toolkits/outlook_mail/arcade_outlook_mail/__init__.py new file mode 100644 index 00000000..545f4d7c --- /dev/null +++ b/toolkits/outlook_mail/arcade_outlook_mail/__init__.py @@ -0,0 +1,24 @@ +from arcade_outlook_mail.tools import ( + create_and_send_email, + create_draft_email, + list_emails, + list_emails_by_property, + list_emails_in_folder, + reply_to_email, + send_draft_email, + update_draft_email, +) + +__all__ = [ + # Read + "list_emails", + "list_emails_by_property", + "list_emails_in_folder", + # Send + "create_and_send_email", + "send_draft_email", + "reply_to_email", + # Write + "create_draft_email", + "update_draft_email", +] diff --git a/toolkits/outlook_mail/arcade_outlook_mail/_utils.py b/toolkits/outlook_mail/arcade_outlook_mail/_utils.py new file mode 100644 index 00000000..3f9f2795 --- /dev/null +++ b/toolkits/outlook_mail/arcade_outlook_mail/_utils.py @@ -0,0 +1,120 @@ +from arcade_tdk import ToolContext +from msgraph.generated.models.message_collection_response import MessageCollectionResponse +from msgraph.generated.users.item.mail_folders.item.messages.messages_request_builder import ( + MessagesRequestBuilder as MailFolderMessagesRequestBuilder, +) +from msgraph.generated.users.item.messages.item.reply.reply_post_request_body import ( + ReplyPostRequestBody, +) +from msgraph.generated.users.item.messages.item.reply_all.reply_all_post_request_body import ( + ReplyAllPostRequestBody, +) +from msgraph.generated.users.item.messages.messages_request_builder import ( + MessagesRequestBuilder as UserMessagesRequestBuilder, +) + +from arcade_outlook_mail.client import get_client +from arcade_outlook_mail.constants import DEFAULT_MESSAGE_FIELDS +from arcade_outlook_mail.enums import ( + EmailFilterProperty, + FilterOperator, + ReplyType, +) + + +def remove_none_values(data: dict) -> dict: + """Remove all keys with None values from the dictionary.""" + return {k: v for k, v in data.items() if v is not None} + + +def _create_filter_expression( + property_: EmailFilterProperty | None = None, + operator: FilterOperator | None = None, + value: str | None = None, +) -> str | None: + if property_ and operator and value: + property_value = property_.value + operator_value = operator.value + + # Never use quotes around 'value' for booleans and numerics + value_quote = "'" + if value.lower() in ["true", "false"] or value.isdigit(): + value_quote = "" + + # Handle function operators (e.g., contains, startsWith, endsWith) + if operator.is_function(): + filter_expr = f"{operator_value}({property_value}, {value_quote}{value}{value_quote})" + else: + # Handle comparison operators (e.g., eq, ne, gt, ge, lt, le) + filter_expr = f"{property_value} {operator_value} {value_quote}{value}{value_quote}" + + if property_value == EmailFilterProperty.RECEIVED_DATE_TIME: + filter_expr = filter_expr + else: # Since receivedDateTime is in orderby, it must be in filter: https://learn.microsoft.com/en-us/graph/api/user-list-messages?view=graph-rest-1.0&tabs=http#optional-query-parameters + filter_expr = f"receivedDateTime ge 1900-01-01T00:00:00Z and {filter_expr}" + + return filter_expr + + return None + + +def prepare_list_emails_request_config( + limit: int, + property_: EmailFilterProperty | None = None, + operator: FilterOperator | None = None, + value: str | None = None, +) -> MailFolderMessagesRequestBuilder.MessagesRequestBuilderGetRequestConfiguration: + """Prepare a request configuration for listing emails.""" + limit = max(1, min(limit, 100)) # limit must be between 1 and 100 + + orderby = ["receivedDateTime DESC"] + filter_expr = _create_filter_expression(property_, operator, value) + + query_params = MailFolderMessagesRequestBuilder.MessagesRequestBuilderGetQueryParameters( + count=True, + select=DEFAULT_MESSAGE_FIELDS, + orderby=orderby, + filter=filter_expr, + top=limit, + ) + return MailFolderMessagesRequestBuilder.MessagesRequestBuilderGetRequestConfiguration( + query_parameters=query_params, + ) + + +async def fetch_emails( + message_builder: MailFolderMessagesRequestBuilder | UserMessagesRequestBuilder, + pagination_token: str | None = None, + request_config: MailFolderMessagesRequestBuilder.MessagesRequestBuilderGetRequestConfiguration + | None = None, +) -> MessageCollectionResponse: + """Fetch emails from the user's mailbox. + + Microsoft Graph Python SDK does not support pagination (as of 2025-04-17), + so we use raw URL for pagination if a pagination token is provided. + """ + if pagination_token: + return await message_builder.with_url(pagination_token).get() # type: ignore[return-value] + return await message_builder.get(request_configuration=request_config) # type: ignore[return-value, arg-type] + + +async def send_reply_email( + context: ToolContext, + message_id: str, + body: str, + reply_type: ReplyType, +) -> dict: + """Send a reply email to the sender or all recipients of an existing email.""" + client = get_client(context.get_auth_token_or_empty()) + + if reply_type == ReplyType.REPLY: + reply_request_body = ReplyPostRequestBody(comment=body) + await client.me.messages.by_message_id(message_id).reply.post(reply_request_body) + elif reply_type == ReplyType.REPLY_ALL: + reply_all_request_body = ReplyAllPostRequestBody(comment=body) + await client.me.messages.by_message_id(message_id).reply_all.post(reply_all_request_body) + + return { + "success": True, + "message": "Email sent successfully", + } diff --git a/toolkits/outlook_mail/arcade_outlook_mail/client.py b/toolkits/outlook_mail/arcade_outlook_mail/client.py new file mode 100644 index 00000000..e11d257a --- /dev/null +++ b/toolkits/outlook_mail/arcade_outlook_mail/client.py @@ -0,0 +1,26 @@ +import datetime +from typing import Any + +from azure.core.credentials import AccessToken, TokenCredential +from msgraph import GraphServiceClient + +DEFAULT_SCOPE = "https://graph.microsoft.com/.default" + + +class StaticTokenCredential(TokenCredential): + """Implementation of TokenCredential protocol to be provided to the MSGraph SDK client""" + + def __init__(self, token: str): + self._token = token + + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + # An expiration is required by MSGraph SDK. Set to 1 hour from now. + expires_on = int(datetime.datetime.now(datetime.timezone.utc).timestamp()) + 3600 + return AccessToken(self._token, expires_on) + + +def get_client(token: str) -> GraphServiceClient: + """Create and return a MSGraph SDK client, given the provided token.""" + token_credential = StaticTokenCredential(token) + + return GraphServiceClient(token_credential, scopes=[DEFAULT_SCOPE]) diff --git a/toolkits/outlook_mail/arcade_outlook_mail/constants.py b/toolkits/outlook_mail/arcade_outlook_mail/constants.py new file mode 100644 index 00000000..6daef4d6 --- /dev/null +++ b/toolkits/outlook_mail/arcade_outlook_mail/constants.py @@ -0,0 +1,18 @@ +DEFAULT_MESSAGE_FIELDS = [ + "bccRecipients", + "body", + "ccRecipients", + "conversationId", + "conversationIndex", + "flag", + "from", + "hasAttachments", + "importance", + "isDraft", + "isRead", + "receivedDateTime", + "replyTo", + "subject", + "toRecipients", + "webLink", +] diff --git a/toolkits/outlook_mail/arcade_outlook_mail/enums.py b/toolkits/outlook_mail/arcade_outlook_mail/enums.py new file mode 100644 index 00000000..ce68162f --- /dev/null +++ b/toolkits/outlook_mail/arcade_outlook_mail/enums.py @@ -0,0 +1,65 @@ +from enum import Enum + + +class WellKnownFolderNames(str, Enum): + """Well-known folder names that are created for users by default. + Instead of using the ID of these folders, you can use the well-known folder names. + For a list of all well-known folder names, see: https://learn.microsoft.com/en-us/graph/api/resources/mailfolder?view=graph-rest-1.0 + """ + + DELETED_ITEMS = "deleteditems" + DRAFTS = "drafts" + INBOX = "inbox" + JUNK_EMAIL = "junkemail" + SENT_ITEMS = "sentitems" + STARRED = "starred" + TODO = "tasks" + + +class ReplyType(str, Enum): + """The type of reply to send to an email.""" + + REPLY = "reply" + REPLY_ALL = "reply_all" + + +class EmailFilterProperty(str, Enum): + """The property to filter the emails by.""" + + # Basic properties + SUBJECT = "subject" + CONVERSATION_ID = "conversationId" + RECEIVED_DATE_TIME = "receivedDateTime" + SENDER = "sender/emailAddress/address" + + +class FilterOperator(str, Enum): + """The operator to use for the filter. + + For a full list of possible operators, see: https://learn.microsoft.com/en-us/graph/filter-query-parameter?tabs=http#operators-and-functions-supported-in-filter-expressions + """ + + # Equality operators + EQUAL = "eq" # example: $filter=conversationId eq 'hello' + NOT_EQUAL = "ne" # example: $filter=subject ne 'hello' + + # Relational operators + GREATER_THAN = "gt" # example: $filter=receivedDateTime gt 2024-01-01 + GREATER_THAN_OR_EQUAL_TO = "ge" # example: $filter=receivedDateTime ge 2024-01-01 + LESS_THAN = "lt" # example: $filter=receivedDateTime lt 2024-01-01 + LESS_THAN_OR_EQUAL_TO = "le" # example: $filter=receivedDateTime le 2024-01-01 + + # Functions + STARTS_WITH = "startsWith" # example: $filter=startsWith(subject, 'hello') + ENDS_WITH = "endsWith" # example: $filter=endsWith(subject, 'hello') + CONTAINS = "contains" # example: $filter=contains(subject, 'hello') + + def is_operator(self) -> bool: + """Check if the operator is a comparison operator.""" + operators = [self.EQUAL, self.NOT_EQUAL] + return self in operators + + def is_function(self) -> bool: + """Check if the operator is a function.""" + functions = [self.STARTS_WITH, self.ENDS_WITH, self.CONTAINS] + return self in functions diff --git a/toolkits/outlook_mail/arcade_outlook_mail/message.py b/toolkits/outlook_mail/arcade_outlook_mail/message.py new file mode 100644 index 00000000..b7e2ec7f --- /dev/null +++ b/toolkits/outlook_mail/arcade_outlook_mail/message.py @@ -0,0 +1,218 @@ +import re +from dataclasses import dataclass, field +from typing import Any + +from bs4 import BeautifulSoup +from msgraph.generated.models.body_type import BodyType +from msgraph.generated.models.email_address import EmailAddress +from msgraph.generated.models.item_body import ItemBody +from msgraph.generated.models.message import Message as GraphMessage +from msgraph.generated.models.recipient import Recipient as GraphRecipient + + +@dataclass +class Recipient: + """A recipient of an email message.""" + + email_address: str = "" + name: str = "" + + @classmethod + def from_sdk(cls, recipient: GraphRecipient) -> "Recipient": + """Convert a Microsoft Graph SDK Recipient object to a Recipient dataclass.""" + address = ( + recipient.email_address.address + if recipient and recipient.email_address and recipient.email_address.address + else "" + ) + name = ( + recipient.email_address.name + if recipient and recipient.email_address and recipient.email_address.name + else "" + ) + return cls(email_address=address, name=name) + + def to_dict(self) -> dict[str, str]: + return {"email_address": self.email_address, "name": self.name} + + def to_sdk(self) -> GraphRecipient: + """Converts the Recipient dataclass to a Microsoft Graph SDK Recipient object.""" + recipient = GraphRecipient() + email_address = EmailAddress() + email_address.address = self.email_address + email_address.name = self.name + recipient.email_address = email_address + return recipient + + +@dataclass +class Message: + """An email message in Outlook.""" + + bcc_recipients: list[Recipient] = field(default_factory=list) + cc_recipients: list[Recipient] = field(default_factory=list) + reply_to: list[Recipient] = field(default_factory=list) + to_recipients: list[Recipient] = field(default_factory=list) + from_: Recipient = field(default_factory=Recipient) + subject: str = "" + body: str = "" + conversation_id: str = "" + conversation_index: str = "" + flag: dict[str, str] = field(default_factory=dict) + has_attachments: bool = False + importance: str = "" + is_read: bool = False + received_date_time: str = "" + web_link: str = "" + is_draft: bool = True + message_id: str = "" # The unique identifier of the email message. Read-only. + + @staticmethod + def _safe_str(value: Any) -> str: + if not value: + return "" + if isinstance(value, bytes | bytearray): + return value.decode("utf-8", errors="ignore") + return str(value) + + @staticmethod + def _safe_bool(value: Any) -> bool: + return bool(value) + + @staticmethod + def _parse_body(mime: str) -> str: + if not mime: + return "" + soup = BeautifulSoup(mime, "html.parser") + text = soup.get_text(separator=" ") + # Replace multiple newlines with a single newline + text = re.sub(r"\n+", "\n", text) + # Replace multiple spaces with a single space + text = re.sub(r"\s+", " ", text) + # Remove leading/trailing whitespace from each line + text = "\n".join(line.strip() for line in text.split("\n")) + + return text + + @staticmethod + def _parse_importance(value: Any) -> str: + return value.value if getattr(value, "value", None) else "" + + @staticmethod + def _parse_flag(flag: Any) -> dict[str, str]: + if not flag: + return {"flag_status": "", "due_date_time": ""} + status = flag.flag_status.value if getattr(flag, "flag_status", None) else "" + due = "" + if getattr(flag, "due_date_time", None) and getattr(flag.due_date_time, "date_time", None): + due = Message._safe_str(flag.due_date_time.date_time) + return {"flag_status": status, "due_date_time": due} + + @classmethod + def from_sdk(cls, msg: GraphMessage) -> "Message": + """Convert a Microsoft Graph SDK Message object to a Message dataclass.""" + text = cls._parse_body(msg.body.content if msg.body and msg.body.content else "") + return cls( + bcc_recipients=[ + Recipient.from_sdk(recipient) for recipient in msg.bcc_recipients or [] + ], + cc_recipients=[Recipient.from_sdk(recipient) for recipient in msg.cc_recipients or []], + reply_to=[Recipient.from_sdk(recipient) for recipient in msg.reply_to or []], + to_recipients=[Recipient.from_sdk(recipient) for recipient in msg.to_recipients or []], + from_=Recipient.from_sdk(msg.from_) if msg.from_ else Recipient(), + subject=cls._safe_str(msg.subject), + body=text, + conversation_id=cls._safe_str(msg.conversation_id), + conversation_index=( + msg.conversation_index.decode("utf-8", errors="ignore") + if isinstance(msg.conversation_index, bytes | bytearray) + else cls._safe_str(msg.conversation_index) + ), + flag=cls._parse_flag(msg.flag), + has_attachments=cls._safe_bool(msg.has_attachments), + importance=cls._parse_importance(msg.importance), + is_read=cls._safe_bool(msg.is_read), + received_date_time=( + msg.received_date_time.isoformat() if msg.received_date_time else "" + ), + web_link=cls._safe_str(msg.web_link), + is_draft=cls._safe_bool(msg.is_draft), + message_id=cls._safe_str(msg.id), + ) + + def to_sdk(self) -> GraphMessage: + """Converts the Message dataclass to a Microsoft Graph SDK Message object.""" + sdk_msg = GraphMessage() + sdk_msg.subject = self.subject + body_obj = ItemBody() + body_obj.content = self.body + body_obj.content_type = BodyType.Text + sdk_msg.body = body_obj + sdk_msg.is_draft = self.is_draft + sdk_msg.to_recipients = [r.to_sdk() for r in self.to_recipients] + sdk_msg.cc_recipients = [r.to_sdk() for r in self.cc_recipients] + sdk_msg.bcc_recipients = [r.to_sdk() for r in self.bcc_recipients] + sdk_msg.reply_to = [r.to_sdk() for r in self.reply_to] + + return sdk_msg + + def to_dict(self) -> dict[str, Any]: + """Converts the Message dataclass to a dictionary.""" + return { + "bcc_recipients": [recipient.to_dict() for recipient in self.bcc_recipients], + "cc_recipients": [recipient.to_dict() for recipient in self.cc_recipients], + "reply_to": [recipient.to_dict() for recipient in self.reply_to], + "to_recipients": [recipient.to_dict() for recipient in self.to_recipients], + "from": self.from_.to_dict(), + "subject": self.subject, + "body": self.body, + "conversation_id": self.conversation_id, + "conversation_index": self.conversation_index, + "flag": self.flag, + "has_attachments": self.has_attachments, + "importance": self.importance, + "is_read": self.is_read, + "received_date_time": self.received_date_time, + "web_link": self.web_link, + "is_draft": self.is_draft, + "message_id": self.message_id, + } + + def update_recipient_lists( + self, + to_add: list[str] | None = None, + to_remove: list[str] | None = None, + cc_add: list[str] | None = None, + cc_remove: list[str] | None = None, + bcc_add: list[str] | None = None, + bcc_remove: list[str] | None = None, + ) -> None: + """Update each recipient list of the message. + + This function updates the recipient lists of the message by first adding new recipients + and then removing existing recipients. Therefore, if an email address is both + added and removed, then it will not be included in the returned list. + """ + for attr, add_emails_input, remove_emails_input in ( + ("to_recipients", to_add, to_remove), + ("cc_recipients", cc_add, cc_remove), + ("bcc_recipients", bcc_add, bcc_remove), + ): + current_recipients = getattr(self, attr) or [] + # Add recipients + existing_emails = {r.email_address.lower() for r in current_recipients} + new_additions = [ + Recipient(email_address=email) + for email in (add_emails_input or []) + if email.lower() not in existing_emails + ] + # Remove recipients + updated_list = current_recipients + new_additions + remove_emails = {email.lower() for email in (remove_emails_input or [])} + updated_list = [ + recipient + for recipient in updated_list + if recipient.email_address.lower() not in remove_emails + ] + # Update the message's attribute with the new list + setattr(self, attr, updated_list) diff --git a/toolkits/outlook_mail/arcade_outlook_mail/tools/__init__.py b/toolkits/outlook_mail/arcade_outlook_mail/tools/__init__.py new file mode 100644 index 00000000..0ae8d773 --- /dev/null +++ b/toolkits/outlook_mail/arcade_outlook_mail/tools/__init__.py @@ -0,0 +1,28 @@ +from arcade_outlook_mail.tools.read import ( + list_emails, + list_emails_by_property, + list_emails_in_folder, +) +from arcade_outlook_mail.tools.send import ( + create_and_send_email, + reply_to_email, + send_draft_email, +) +from arcade_outlook_mail.tools.write import ( + create_draft_email, + update_draft_email, +) + +__all__ = [ + # Read + "list_emails", + "list_emails_by_property", + "list_emails_in_folder", + # Send + "create_and_send_email", + "reply_to_email", + "send_draft_email", + # Write + "create_draft_email", + "update_draft_email", +] diff --git a/toolkits/outlook_mail/arcade_outlook_mail/tools/read.py b/toolkits/outlook_mail/arcade_outlook_mail/tools/read.py new file mode 100644 index 00000000..e63c52e9 --- /dev/null +++ b/toolkits/outlook_mail/arcade_outlook_mail/tools/read.py @@ -0,0 +1,122 @@ +from typing import Annotated + +from arcade_tdk import ToolContext, tool +from arcade_tdk.auth import Microsoft +from arcade_tdk.errors import ToolExecutionError + +from arcade_outlook_mail._utils import ( + fetch_emails, + prepare_list_emails_request_config, + remove_none_values, +) +from arcade_outlook_mail.client import get_client +from arcade_outlook_mail.enums import ( + EmailFilterProperty, + FilterOperator, + WellKnownFolderNames, +) +from arcade_outlook_mail.message import Message + + +@tool(requires_auth=Microsoft(scopes=["Mail.Read"])) +async def list_emails( + context: ToolContext, + limit: Annotated[int, "The number of messages to return. Max is 100. Defaults to 5."] = 5, + pagination_token: Annotated[ + str | None, "The pagination token to continue a previous request" + ] = None, +) -> Annotated[dict, "A dictionary containing a list of emails"]: + """List emails in the user's mailbox across all folders. + + Since this tool lists email across all folders, it may return sent items, drafts, + and other items that are not in the inbox. + """ + client = get_client(context.get_auth_token_or_empty()) + request_config = prepare_list_emails_request_config(limit) + message_builder = client.me.messages + + response = await fetch_emails(message_builder, pagination_token, request_config) + messages = [Message.from_sdk(msg).to_dict() for msg in response.value or []] + pagination_token = response.odata_next_link + + result = { + "messages": messages, + "num_messages": len(messages), + "pagination_token": pagination_token, + } + result = remove_none_values(result) + return result + + +@tool(requires_auth=Microsoft(scopes=["Mail.Read"])) +async def list_emails_in_folder( + context: ToolContext, + well_known_folder_name: Annotated[ + WellKnownFolderNames | None, + "The name of the folder to list emails from. Defaults to None.", + ] = None, + folder_id: Annotated[ + str | None, + "The ID of the folder to list emails from if the folder is not a well-known folder. " + "Defaults to None.", + ] = None, + limit: Annotated[int, "The number of messages to return. Max is 100. Defaults to 5."] = 5, + pagination_token: Annotated[ + str | None, "The pagination token to continue a previous request" + ] = None, +) -> Annotated[ + dict, "A dictionary containing a list of emails and a pagination token, if applicable" +]: + """List the user's emails in the specified folder. + + Exactly one of `well_known_folder_name` or `folder_id` MUST be provided. + """ + if not (bool(well_known_folder_name) ^ bool(folder_id)): + raise ToolExecutionError( + message="Exactly one of `well_known_folder_name` or `folder_id` must be provided." + ) + folder_name = well_known_folder_name.value if well_known_folder_name else folder_id + client = get_client(context.get_auth_token_or_empty()) + request_config = prepare_list_emails_request_config(limit) + message_builder = client.me.mail_folders.by_mail_folder_id(folder_name).messages # type: ignore [arg-type] + + response = await fetch_emails(message_builder, pagination_token, request_config) + messages = [Message.from_sdk(msg).to_dict() for msg in response.value or []] + pagination_token = response.odata_next_link + + result = { + "messages": messages, + "num_messages": len(messages), + "pagination_token": pagination_token, + } + result = remove_none_values(result) + return result + + +@tool(requires_auth=Microsoft(scopes=["Mail.Read"])) +async def list_emails_by_property( + context: ToolContext, + property: Annotated[EmailFilterProperty, "The property to filter the emails by."], # noqa: A002 + operator: Annotated[FilterOperator, "The operator to use for the filter."], + value: Annotated[str, "The value to filter the emails by"], + limit: Annotated[int, "The number of messages to return. Max is 100. Defaults to 5."] = 5, + pagination_token: Annotated[ + str | None, "The pagination token to continue a previous request" + ] = None, +) -> Annotated[dict, "A dictionary containing a list of emails"]: + """List emails in the user's mailbox across all folders filtering by a property.""" + client = get_client(context.get_auth_token_or_empty()) + request_config = prepare_list_emails_request_config(limit, property, operator, value) + message_builder = client.me.messages + + response = await fetch_emails(message_builder, pagination_token, request_config) + messages = [Message.from_sdk(msg).to_dict() for msg in response.value or []] + pagination_token = response.odata_next_link + + result = { + "messages": messages, + "num_messages": len(messages), + "pagination_token": pagination_token, + } + result = remove_none_values(result) + return result diff --git a/toolkits/outlook_mail/arcade_outlook_mail/tools/send.py b/toolkits/outlook_mail/arcade_outlook_mail/tools/send.py new file mode 100644 index 00000000..22dacd78 --- /dev/null +++ b/toolkits/outlook_mail/arcade_outlook_mail/tools/send.py @@ -0,0 +1,94 @@ +from typing import Annotated + +from arcade_tdk import ToolContext, tool +from arcade_tdk.auth import Microsoft +from msgraph.generated.users.item.send_mail.send_mail_post_request_body import ( + SendMailPostRequestBody, +) + +from arcade_outlook_mail._utils import send_reply_email +from arcade_outlook_mail.client import get_client +from arcade_outlook_mail.enums import ReplyType +from arcade_outlook_mail.message import Message, Recipient + + +@tool(requires_auth=Microsoft(scopes=["Mail.Send"])) +async def create_and_send_email( + context: ToolContext, + subject: Annotated[str, "The subject of the email to create"], + body: Annotated[str, "The body of the email to create"], + to_recipients: Annotated[ + list[str], "The email addresses that will be the recipients of the email" + ], + cc_recipients: Annotated[ + list[str] | None, "The email addresses that will be the CC recipients of the email." + ] = None, + bcc_recipients: Annotated[ + list[str] | None, + "The email addresses that will be the BCC recipients of the email.", + ] = None, +) -> Annotated[dict, "A dictionary containing the created email details"]: + """Create and immediately send a new email in Outlook to the specified recipients""" + client = get_client(context.get_auth_token_or_empty()) + message = Message( + subject=subject, + body=body, + to_recipients=[Recipient(email_address=email) for email in to_recipients], + cc_recipients=[Recipient(email_address=email) for email in cc_recipients or []], + bcc_recipients=[Recipient(email_address=email) for email in bcc_recipients or []], + ).to_sdk() + + send_mail_request_body = SendMailPostRequestBody( + message=message, + save_to_sent_items=True, + ) + + await client.me.send_mail.post(send_mail_request_body) + + return { + "success": True, + "message": "Email sent successfully", + } + + +@tool(requires_auth=Microsoft(scopes=["Mail.Send"])) +async def send_draft_email( + context: ToolContext, + message_id: Annotated[str, "The ID of the draft email to send"], +) -> Annotated[dict, "A dictionary containing the sent email details"]: + """Send an existing draft email in Outlook + + This tool can send any un-sent email: + - draft + - reply-draft + - reply-all draft + - forward draft + """ + client = get_client(context.get_auth_token_or_empty()) + + await client.me.messages.by_message_id(message_id).send.post() + + return { + "success": True, + "message": "Email sent successfully", + } + + +@tool(requires_auth=Microsoft(scopes=["Mail.Send"])) +async def reply_to_email( + context: ToolContext, + message_id: Annotated[str, "The ID of the email to reply to"], + body: Annotated[str, "The body of the reply to the email"], + reply_type: Annotated[ + ReplyType, + f"Specify {ReplyType.REPLY} to reply only to the sender or " + f"{ReplyType.REPLY_ALL} to reply to all recipients. " + f"Defaults to {ReplyType.REPLY}.", + ] = ReplyType.REPLY, +) -> Annotated[dict, "A dictionary containing the sent email details"]: + """Reply to an existing email in Outlook. + + Use this tool to reply to the sender or all recipients of the email. + Specify the reply_type to determine the scope of the reply. + """ + return await send_reply_email(context, message_id, body, reply_type) diff --git a/toolkits/outlook_mail/arcade_outlook_mail/tools/write.py b/toolkits/outlook_mail/arcade_outlook_mail/tools/write.py new file mode 100644 index 00000000..f80e8f60 --- /dev/null +++ b/toolkits/outlook_mail/arcade_outlook_mail/tools/write.py @@ -0,0 +1,115 @@ +from typing import Annotated + +from arcade_tdk import ToolContext, tool +from arcade_tdk.auth import Microsoft +from arcade_tdk.errors import ToolExecutionError + +from arcade_outlook_mail.client import get_client +from arcade_outlook_mail.message import Message, Recipient + + +@tool(requires_auth=Microsoft(scopes=["Mail.ReadWrite"])) +async def create_draft_email( + context: ToolContext, + subject: Annotated[str, "The subject of the draft email to create"], + body: Annotated[str, "The body of the draft email to create"], + to_recipients: Annotated[ + list[str], "The email addresses that will be the recipients of the draft email" + ], + cc_recipients: Annotated[ + list[str] | None, + "The email addresses that will be the CC recipients of the draft email.", + ] = None, + bcc_recipients: Annotated[ + list[str] | None, + "The email addresses that will be the BCC recipients of the draft email.", + ] = None, +) -> Annotated[dict, "A dictionary containing the created email details"]: + """Compose a new draft email in Outlook""" + client = get_client(context.get_auth_token_or_empty()) + + message = Message( + subject=subject, + body=body, + to_recipients=[Recipient(email_address=email) for email in to_recipients], + cc_recipients=[Recipient(email_address=email) for email in cc_recipients or []], + bcc_recipients=[Recipient(email_address=email) for email in bcc_recipients or []], + is_draft=True, + ).to_sdk() + + response = await client.me.messages.post(message) + draft_message = Message.from_sdk(response).to_dict() # type: ignore [arg-type] + + return draft_message + + +@tool(requires_auth=Microsoft(scopes=["Mail.ReadWrite"])) +async def update_draft_email( + context: ToolContext, + message_id: Annotated[str, "The ID of the draft email to update"], + subject: Annotated[ + str | None, + "The new subject of the draft email. If provided, the existing subject will be overwritten", + ] = None, + body: Annotated[ + str | None, + "The new body of the draft email. If provided, the existing body will be overwritten", + ] = None, + to_add: Annotated[list[str] | None, "Email addresses to add as 'To' recipients."] = None, + to_remove: Annotated[ + list[str] | None, + "Email addresses to remove from the current 'To' recipients.", + ] = None, + cc_add: Annotated[ + list[str] | None, + "Email addresses to add as 'CC' recipients.", + ] = None, + cc_remove: Annotated[ + list[str] | None, + "Email addresses to remove from the current 'CC' recipients.", + ] = None, + bcc_add: Annotated[ + list[str] | None, + "Email addresses to add as 'BCC' recipients.", + ] = None, + bcc_remove: Annotated[ + list[str] | None, + "Email addresses to remove from the current 'BCC' recipients.", + ] = None, +) -> Annotated[dict, "A dictionary containing the updated email details"]: + """Update an existing draft email in Outlook. + + This tool overwrites the subject and body of a draft email (if provided), + and modifies its recipient lists by selectively adding or removing email addresses. + + This tool can update any un-sent email: + - draft + - reply-draft + - reply-all draft + - forward draft + """ + client = get_client(context.get_auth_token_or_empty()) + + # Get the draft email + draft_email_sdk = await client.me.messages.by_message_id(message_id).get() + + if draft_email_sdk is None: + raise ToolExecutionError(message=f"The draft email with ID {message_id} was not found.") + + # Update the draft email + draft_email = Message.from_sdk(draft_email_sdk) + draft_email.subject = subject if subject else draft_email.subject + draft_email.body = body if body else draft_email.body or "" + draft_email.update_recipient_lists( + to_add=to_add, + to_remove=to_remove, + cc_add=cc_add, + cc_remove=cc_remove, + bcc_add=bcc_add, + bcc_remove=bcc_remove, + ) + updated_draft_email = await client.me.messages.by_message_id(message_id).patch( + draft_email.to_sdk() + ) + + return Message.from_sdk(updated_draft_email).to_dict() # type: ignore [arg-type] diff --git a/toolkits/outlook_mail/evals/additional_messages.py b/toolkits/outlook_mail/evals/additional_messages.py new file mode 100644 index 00000000..809a21ec --- /dev/null +++ b/toolkits/outlook_mail/evals/additional_messages.py @@ -0,0 +1,83 @@ +update_draft_email_additional_messages = [ + {"role": "system", "content": "Today is 2025-04-22, Tuesday."}, + { + "role": "user", + "content": '"create a new draft email with subject \'Hello friends\' and body "\n"\'I\'ve gathered you all here to celebrate the launch of the new Arcade platform."\n "address it to e@arcade.dev and z@arcade.dev. also carbon copy to j@arcade.dev, "\n"f@arcade.dev, k@arcade.dev and finally to m@arcade.dev. also bcc to r@arcade.dev"', + }, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_lKw4S01FGe03oZeuW25roepy", + "type": "function", + "function": { + "name": "Microsoft_CreateDraftEmail", + "arguments": '{"subject":"Hello friends","body":"I\'ve gathered you all here to celebrate the launch of the new Arcade platform.","to_recipients":["e@arcade.dev","z@arcade.dev"],"cc_recipients":["j@arcade.dev","f@arcade.dev","k@arcade.dev","m@arcade.dev"],"bcc_recipients":["r@arcade.dev"]}', + }, + } + ], + }, + { + "role": "tool", + "content": '{"bcc_recipients":[{"email_address":"r@arcade.dev","name":"r@arcade.dev"}],"body":"I\'ve gathered you all here to celebrate the launch of the new Arcade platform.","cc_recipients":[{"email_address":"j@arcade.dev","name":"j@arcade.dev"},{"email_address":"f@arcade.dev","name":"f@arcade.dev"},{"email_address":"k@arcade.dev","name":"k@arcade.dev"},{"email_address":"m@arcade.dev","name":"m@arcade.dev"}],"conversation_id":"AQQkADAwATM0MDAAMi04Y2Y1LTQ3MTEALTAwAi0wMAoAEAAskq2oM-moTbt3gDT_yK0e","conversation_index":"AQHbs6c0LJKtqDP5qE27d4A0/sitHg==","flag":{"due_date_time":"","flag_status":"notFlagged"},"from":{"email_address":"","name":""},"has_attachments":false,"importance":"normal","is_draft":true,"is_read":true,"message_id":"AQMkADAwATM0MDAAMi04Y2Y1LTQ3MTEALTAwAi0wMAoARgAAAyXxSd3UxTpCkDpGouEg0JMHAFuxokOLZRtDncM4_x_WeUwAAAIBDwAAAFuxokOLZRtDncM4_x_WeUwAAAAC-dpvAAAA","received_date_time":"2025-04-22T16:54:25+00:00","reply_to":[],"subject":"Hello friends","to_recipients":[{"email_address":"e@arcade.dev","name":"e@arcade.dev"},{"email_address":"z@arcade.dev","name":"z@arcade.dev"}],"web_link":"https://outlook.live.com/owa/?ItemID=AQMkADAwATM0MDAAMi04Y2Y1LTQ3MTEALTAwAi0wMAoARgAAAyXxSd3UxTpCkDpGouEg0JMHAFuxokOLZRtDncM4%2Bx%2BWeUwAAAIBDwAAAFuxokOLZRtDncM4%2Bx%2BWeUwAAAAC%2FdpvAAAA\\u0026exvsurl=1\\u0026viewmodel=ReadMessageItem"}', + "tool_call_id": "call_lKw4S01FGe03oZeuW25roepy", + "name": "Microsoft_CreateDraftEmail", + }, + { + "role": "assistant", + "content": 'I have created a draft email with the subject "Hello friends" addressed to the specified recipients. You can view and edit the draft through [this link](https://outlook.live.com/owa/?ItemID=AQMkADAwATM0MDAAMi04Y2Y1LTQ3MTEALTAwAi0wMAoARgAAAyXxSd3UxTpCkDpGouEg0JMHAFuxokOLZRtDncM4%2Bx%2BWeUwAAAIBDwAAAFuxokOLZRtDncM4%2Bx%2BWeUwAAAAC%2FdpvAAAA&exvsurl=1&viewmodel=ReadMessageItem).', + }, +] + +list_emails_with_pagination_token_additional_messages = [ + {"role": "system", "content": "Today is 2025-04-21, Monday."}, + {"role": "user", "content": "get one email"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_jACvc3Gl1WHkqWgI8gdsIt0G", + "type": "function", + "function": {"name": "Microsoft_ListEmails", "arguments": '{"limit":1}'}, + } + ], + }, + { + "role": "tool", + "content": '{"messages":[{"bcc_recipients":[],"body":"Microsoft account New app(s) have access to your data Arcade","cc_recipients":["e@arcade.dev],"conversation_id":"AQQkADAwATM0MDAAMi04Y2Y1LTQ3MTEALTAwAi0wMAoAEABOD15A17tWSaVHkmjhko1R","conversation_index":"AQHbsuDd Tg9eQNe7VkmlR5Jo4ZKNUQ==","flag":{"due_date_time":"","flag_status":"notFlagged"},"from":{"email_address":"account-security@accountprotect ion.microsoft.com","name":"Microsoft account team"},"has_attachments":false,"importance":"normal","is_draft":false,"is_read":true,"message_id":"AQMkADAwATM0MDAAMi04Y2Y1LTQ3MTEALTAwAi0wMAoARgAAAyXxSd3UxTpCkDpGouEg0JMHAFuxokOLZRtDncM4_x_WeUwAAAIBDAAAAFuxokOLZRtDncM4_x_WeUwAAAABc_ezAAAA","received_date_time":"2025-04-21T17:14:39+00:00", "reply_to":[],"subject":"New app(s) connected to your Microsoft account","to_recipients":[{"email_address":"ericarcade@outlook.com","name":"ericarcade@outlook.com"}],"web_link":"https://outlook.live.com/owa/?ItemID=AQMkADAwATM0MDAAMi04Y2Y1LTQ3MTEALTAwAi0wMAoARgAAAyXxSd3UxTpCkDpGouEg0JMHAFuxokOLZRtDncM4%2Bx%2BWeUwAAAIBDAAAAFuxokOLZRtDncM4%2Bx%2BWeUwAAAABc%2BezAAAA\\u0026exvsurl=1\\u0026viewmodel=ReadMessageItem"}],"num_messages":1,"pagination_token":"https://graph.microsoft.com/v1.0/me/messages?%24count=true&%24orderby=receivedDateTime+DESC&%24select=bccRecipients%2cbody%2cccRecipients%2cconversationId%2cconversationIndex%2cflag%2cfrom%2chasAttachments%2cimportance%2cisDraft%2cisRead%2creceivedDateTime%2creplyTo%2csubject%2ctoRecipients%2cwebLink&%24top=1&%24skip=1"}', + "tool_call_id": "call_jACvc3Gl1WHkqWgI8gdsIt0G", + "name": "Microsoft_ListEmails", + }, + { + "role": "assistant", + "content": "Here is the most recent email you received:\n\n- **From:** Microsoft account team (account-security@accountprotection.microsoft.com)\n- **To:** e@outlook.com\n- **Subject:** New app(s) connected to your Microsoft account\n- **Received Date:** April 21, 2025\n- **Body:**\n ```\n Microsoft account\n\n New app(s) have access to your data Arcade connected to the Microsoft account *@outlook.com.```\n- **Link to email:** [Read in Outlook](https://outlook.live.com/owa/?ItemID=AQMkADAwATM0MDAAMi04Y2Y1LTQ3MTEALTAwAi0wMAoARgAAAyXxSd3UxTpCkDpGouEg0JMHAFuxokOLZRtDncM4%2Bx%2BWeUwAAAIBDAAAAFuxokOLZRtDncM4%2Bx%2BWeUwAAAABc%2BezAAAA&exvsurl=1&viewmodel=ReadMessageItem)", + }, +] + +list_emails_with_pagination_token_additional_messages = [ + {"role": "system", "content": "Today is 2025-04-21, Monday."}, + {"role": "user", "content": "get one email"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_jACvc3Gl1WHkqWgI8gdsIt0G", + "type": "function", + "function": {"name": "Microsoft_ListEmails", "arguments": '{"limit":1}'}, + } + ], + }, + { + "role": "tool", + "content": '{"messages":[{"bcc_recipients":[],"body":"Microsoft account New app(s) have access to your data Arcade","cc_recipients":[],"conversation_id":"AQQkADAwATM0MDAAMi04Y2Y1LTQ3MTEALTAwAi0wMAoAEABOD15A17tWSaVHkmjhko1R","conversation_index":"AQHbsuDdTg9eQNe7VkmlR5Jo4ZKNUQ==","flag":{"due_date_time":"","flag_status":"notFlagged"},"from":{"email_address":"account-security-noreply@accountprotection.microsoft.com","name":"Microsoft account team"},"has_attachments":false,"importance":"normal","is_draft":false,"is_read":true,"message_id":"AQMkADAwATM0MDAAMi04Y2Y1LTQ3MTEALTAwAi0wMAoARgAAAyXxSd3UxTpCkDpGouEg0JMHAFuxokOLZRtDncM4_x_WeUwAAAIBDAAAAFuxokOLZRtDncM4_x_WeUwAAAABc_ezAAAA","received_date_time":"2025-04-21T17:14:39+00:00", "reply_to":[],"subject":"New app(s) connected to your Microsoft account","to_recipients":[{"email_address":"ericarcade@outlook.com","name":"ericarcade@outlook.com"}],"web_link":"https://outlook.live.com/owa/?I temID=AQMkADAwATM0MDAAMi04Y2Y1LTQ3MTEALTAwAi0wMAoARgAAAyXxSd3UxTpCkDpGouEg0JMHAFuxokOLZRtDncM4%2Bx%2BWeUwAAAIBDAAAAFuxokOLZRtDncM4%2Bx%2BWeUwAAAA Bc%2BezAAAA\\u0026exvsurl=1\\u0026viewmodel=ReadMessageItem"}],"num_messages":1,"pagination_token":"https://graph.microsoft.com/v1.0/me/messages?%24count=true&%24orderby=receivedDateTime+DESC&%24select=bccRecipients%2cbody%2cccRecipients%2cconversationId%2cconversationIndex%2cflag%2cfrom%2chasAttachments%2cimportance%2cisDraft%2cisRead%2creceivedDateTime%2creplyTo%2csubject%2ctoRecipients%2cwebLink&%24top=1&%24skip=1"}', + "tool_call_id": "call_jACvc3Gl1WHkqWgI8gdsIt0G", + "name": "Microsoft_ListEmails", + }, + { + "role": "assistant", + "content": "Here is the most recent email you received:\n\n- **From:** Microsoft account team (account-security-noreply@accountprotection.microsoft.com)\n- **To:** e@outlook.com\n- **Subject:** New app(s) connected to your Microsoft account\n- **Received Date:** April 21, 2025\n- **Body:**\n ```\n Microsoft account\n\n New app(s) have access to your data Arcade connected to the Microsoft account *@outlook.com.```\n- **Link to email:** [Read in Outlook](https://outlook.live.com/owa/?ItemID=AQMkADAwATM0MDAAMi04Y2Y1LTQ3MTEALTAwAi0wMAoARgAAAyXxSd3UxTpCkDpGouEg0JMHAFuxokOLZRtDncM4%2Bx%2BWeUwAAAIBDAAAAFuxokOLZRtDncM4%2Bx%2BWeUwAAAABc%2BezAAAA&exvsurl=1&viewmodel=ReadMessageItem)", + }, +] diff --git a/toolkits/outlook_mail/evals/eval_read.py b/toolkits/outlook_mail/evals/eval_read.py new file mode 100644 index 00000000..1864f7ce --- /dev/null +++ b/toolkits/outlook_mail/evals/eval_read.py @@ -0,0 +1,210 @@ +from datetime import timedelta + +from arcade_evals import ( + BinaryCritic, + DatetimeCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + tool_eval, +) +from arcade_tdk import ToolCatalog + +from arcade_outlook_mail import ( + list_emails, + list_emails_by_property, + list_emails_in_folder, +) +from arcade_outlook_mail.enums import WellKnownFolderNames +from evals.additional_messages import ( + list_emails_with_pagination_token_additional_messages, +) + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.9, + warn_threshold=0.95, +) + + +catalog = ToolCatalog() +catalog.add_tool(list_emails, "OutlookMail") +catalog.add_tool(list_emails_in_folder, "OutlookMail") +catalog.add_tool(list_emails_by_property, "OutlookMail") + + +@tool_eval() +def outlook_mail_read_eval_suite() -> EvalSuite: + """Create an evaluation suite for Outlook Mail tools.""" + suite = EvalSuite( + name="Outlook Mail Tools Evaluation", + system_message=("You are an AI that has access to tools to send, read, and write emails."), + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="List emails in mailbox", + user_message="get my five most recent emails", + expected_tool_calls=[ + ExpectedToolCall( + func=list_emails, + args={"limit": 5}, + ) + ], + critics=[ + BinaryCritic(critic_field="limit", weight=1.0), + ], + ) + + suite.add_case( + name="List emails in mailbox with pagination token", + user_message="get the next 3", + expected_tool_calls=[ + ExpectedToolCall( + func=list_emails, + args={ + "limit": 3, + "pagination_token": "https://graph.microsoft.com/v1.0/me/messages?%24count=true&%24orderby=receivedDateTime+DESC&%24select=bccRecipients%2cbody%2cccRecipients%2cconversationId%2cconversationIndex%2cflag%2cfrom%2chasAttachments%2cimportance%2cisDraft%2cisRead%2creceivedDateTime%2creplyTo%2csubject%2ctoRecipients%2cwebLink&%24top=1&%24skip=1", + }, + ) + ], + critics=[ + BinaryCritic(critic_field="limit", weight=0.2), + BinaryCritic(critic_field="pagination_token", weight=0.8), + ], + additional_messages=list_emails_with_pagination_token_additional_messages, + ) + + suite.add_case( + name="List emails in well-known folder", + user_message="summarize my inbox", + expected_tool_calls=[ + ExpectedToolCall( + func=list_emails_in_folder, + args={ + "well_known_folder_name": WellKnownFolderNames.INBOX, + "folder_id": None, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="well_known_folder_name", weight=0.5), + BinaryCritic(critic_field="folder_id", weight=0.5), + ], + ) + + suite.add_case( + name="List emails in folder by id", + user_message="get 5 from folder AQMkADAwATM0MDAAMi04Y2Y1LTQ3MTEALTAwAi0wMAoALgAAAyXxSd3UxTpCkDpGouEg0JMBAFuxokOLZRtDncM4", + expected_tool_calls=[ + ExpectedToolCall( + func=list_emails_in_folder, + args={ + "well_known_folder_name": None, + "folder_id": "AQMkADAwATM0MDAAMi04Y2Y1LTQ3MTEALTAwAi0wMAoALgAAAyXxSd3UxTpCkDpGouEg0JMBAFuxokOLZRtDncM4", + "limit": 5, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="well_known_folder_name", weight=0.4), + BinaryCritic(critic_field="folder_id", weight=0.4), + BinaryCritic(critic_field="limit", weight=0.2), + ], + ) + + return suite + + +@tool_eval() +def outlook_mail_list_emails_by_property_eval_suite() -> EvalSuite: + """Create an evaluation suite for Outlook Mail tools.""" + suite = EvalSuite( + name="Outlook Mail Tools Evaluation", + system_message=("You are an AI that has access to tools to send, read, and write emails."), + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="List emails by subject", + user_message="get all emails that talk about The Green Bottle", + expected_tool_calls=[ + ExpectedToolCall( + func=list_emails_by_property, + args={ + "property": "subject", + "operator": "contains", + "value": "The Green Bottle", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="property", weight=1 / 3), + BinaryCritic(critic_field="operator", weight=1 / 3), + BinaryCritic(critic_field="value", weight=1 / 3), + ], + ) + + suite.extend_case( + name="List emails by thread", + user_message="get all emails in my thread 1k2jh324h92f24krjb34mtb43kj4bk3tmn34b3k4nnm3tb34mntb34mntb3m4bt3mn4bt3mn4btmnb34tmnb3t4mnb==34tkjh", + expected_tool_calls=[ + ExpectedToolCall( + func=list_emails_by_property, + args={ + "property": "conversationId", + "operator": "eq", + "value": "1k2jh324h92f24krjb34mtb43kj4bk3tmn34b3k4nnm3tb34mntb34mntb3m4bt3mn4bt3mn4btmnb34tmnb3t4mnb==34tkjh", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="property", weight=1 / 3), + BinaryCritic(critic_field="operator", weight=1 / 3), + BinaryCritic(critic_field="value", weight=1 / 3), + ], + ) + + suite.extend_case( + name="List emails by date", + user_message="Today is May 1st 2025. Get all emails that are a year old or older", + expected_tool_calls=[ + ExpectedToolCall( + func=list_emails_by_property, + args={ + "property": "receivedDateTime", + "operator": "le", + "value": "2024-05-01T00:00:00Z", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="property", weight=1 / 3), + BinaryCritic(critic_field="operator", weight=1 / 3), + DatetimeCritic(critic_field="value", weight=1 / 3, tolerance=timedelta(days=1)), + ], + ) + + suite.extend_case( + name="List emails by sender", + user_message="get all of my correspondence with the folks over at arcade.dev", + expected_tool_calls=[ + ExpectedToolCall( + func=list_emails_by_property, + args={ + "property": "sender/emailAddress/address", + "operator": "contains", + "value": "arcade.dev", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="property", weight=1 / 3), + BinaryCritic(critic_field="operator", weight=1 / 3), + BinaryCritic(critic_field="value", weight=1 / 3), + ], + ) + + return suite diff --git a/toolkits/outlook_mail/evals/eval_send.py b/toolkits/outlook_mail/evals/eval_send.py new file mode 100644 index 00000000..c8eba7d7 --- /dev/null +++ b/toolkits/outlook_mail/evals/eval_send.py @@ -0,0 +1,127 @@ +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + SimilarityCritic, + tool_eval, +) +from arcade_tdk import ToolCatalog + +from arcade_outlook_mail import ( + create_and_send_email, + reply_to_email, + send_draft_email, +) +from arcade_outlook_mail.enums import ReplyType +from evals.additional_messages import ( + list_emails_with_pagination_token_additional_messages, +) + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.9, + warn_threshold=0.95, +) + + +catalog = ToolCatalog() +catalog.add_tool(create_and_send_email, "OutlookMail") +catalog.add_tool(send_draft_email, "OutlookMail") +catalog.add_tool(reply_to_email, "OutlookMail") + + +@tool_eval() +def outlook_mail_send_eval_suite() -> EvalSuite: + """Create an evaluation suite for Outlook Mail tools.""" + suite = EvalSuite( + name="Outlook Mail Send Evaluation", + system_message=("You are an AI that has access to tools to send, read, and write emails."), + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="Create draft email", + user_message=( + "send an email to j@arcade.dev and e@arcade.dev. Title it 'Hello friends' and have it " + "say 'I've gathered you all here to celebrate the launch of the new Arcade platform.'" + ), + expected_tool_calls=[ + ExpectedToolCall( + func=create_and_send_email, + args={ + "subject": "Hello friends", + "body": "I've gathered you all here to celebrate the launch of the new Arcade platform.", + "to_recipients": ["j@arcade.dev", "e@arcade.dev"], + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="subject", weight=0.3), + SimilarityCritic(critic_field="body", weight=0.3), + BinaryCritic(critic_field="to_recipients", weight=0.4), + ], + ) + + suite.add_case( + name="Update draft email", + user_message=( + "forward the draft AQMkADAwATM0MDAAMi04Y2Y1LTQ3MTEALTAwAi0wMAoARgAAAyXxSd3UxTpCkDpGouEg0JMHAFuxokOLZRtDncM4_x_WeUwAAAIBDwAAAFuxokOLZRtDncM4_x_WeUwAAAAC-dpvAAAA " + ), + expected_tool_calls=[ + ExpectedToolCall( + func=send_draft_email, + args={ + "message_id": "AQMkADAwATM0MDAAMi04Y2Y1LTQ3MTEALTAwAi0wMAoARgAAAyXxSd3UxTpCkDpGouEg0JMHAFuxokOLZRtDncM4_x_WeUwAAAIBDwAAAFuxokOLZRtDncM4_x_WeUwAAAAC-dpvAAAA", + }, + ) + ], + critics=[ + BinaryCritic(critic_field="message_id", weight=1), + ], + ) + + suite.add_case( + name="Reply all to email", + user_message=("Reply to everyone - 'sounds good to me'"), + expected_tool_calls=[ + ExpectedToolCall( + func=reply_to_email, + args={ + "message_id": "AQMkADAwATM0MDAAMi04Y2Y1LTQ3MTEALTAwAi0wMAoARgAAAyXxSd3UxTpCkDpGouEg0JMHAFuxokOLZRtDncM4_x_WeUwAAAIBDAAAAFuxokOLZRtDncM4_x_WeUwAAAABc_ezAAAA", + "body": "sounds good to me", + "reply_type": ReplyType.REPLY_ALL, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="message_id", weight=1 / 3), + SimilarityCritic(critic_field="body", weight=1 / 3), + BinaryCritic(critic_field="reply_type", weight=1 / 3), + ], + additional_messages=list_emails_with_pagination_token_additional_messages, + ) + + suite.add_case( + name="Reply to email", + user_message=("Reply to the account security team - 'sounds good to me'"), + expected_tool_calls=[ + ExpectedToolCall( + func=reply_to_email, + args={ + "message_id": "AQMkADAwATM0MDAAMi04Y2Y1LTQ3MTEALTAwAi0wMAoARgAAAyXxSd3UxTpCkDpGouEg0JMHAFuxokOLZRtDncM4_x_WeUwAAAIBDAAAAFuxokOLZRtDncM4_x_WeUwAAAABc_ezAAAA", + "body": "sounds good to me", + "reply_type": ReplyType.REPLY, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="message_id", weight=1 / 3), + SimilarityCritic(critic_field="body", weight=1 / 3), + BinaryCritic(critic_field="reply_type", weight=1 / 3), + ], + additional_messages=list_emails_with_pagination_token_additional_messages, + ) + + return suite diff --git a/toolkits/outlook_mail/evals/eval_write.py b/toolkits/outlook_mail/evals/eval_write.py new file mode 100644 index 00000000..d341f669 --- /dev/null +++ b/toolkits/outlook_mail/evals/eval_write.py @@ -0,0 +1,104 @@ +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + SimilarityCritic, + tool_eval, +) +from arcade_tdk import ToolCatalog + +from arcade_outlook_mail import create_draft_email, update_draft_email +from evals.additional_messages import ( + update_draft_email_additional_messages, +) + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.9, + warn_threshold=0.95, +) + + +catalog = ToolCatalog() +catalog.add_tool(create_draft_email, "OutlookMail") +catalog.add_tool(update_draft_email, "OutlookMail") + + +@tool_eval() +def outlook_mail_write_eval_suite() -> EvalSuite: + """Create an evaluation suite for Outlook Mail tools.""" + suite = EvalSuite( + name="Outlook Mail Write Evaluation", + system_message=("You are an AI that has access to tools to send, read, and write emails."), + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="Create draft email", + user_message=( + "create a new draft email with subject 'Hello friends' and body " + "'I've gathered you all here to celebrate the launch of the new Arcade platform." + "address it to e@arcade.dev and z@arcade.dev. also carbon copy to j@arcade.dev, " + "f@arcade.dev, k@arcade.dev and finally to m@arcade.dev. also bcc to r@arcade.dev" + ), + expected_tool_calls=[ + ExpectedToolCall( + func=create_draft_email, + args={ + "subject": "Hello friends", + "body": "I've gathered you all here to celebrate the launch of the new Arcade platform.", + "to_recipients": ["e@arcade.dev", "z@arcade.dev"], + "cc_recipients": [ + "j@arcade.dev", + "f@arcade.dev", + "k@arcade.dev", + "m@arcade.dev", + ], + "bcc_recipients": ["r@arcade.dev"], + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="subject", weight=0.2), + SimilarityCritic(critic_field="body", weight=0.2), + BinaryCritic(critic_field="to_recipients", weight=0.2), + BinaryCritic(critic_field="cc_recipients", weight=0.2), + BinaryCritic(critic_field="bcc_recipients", weight=0.2), + ], + ) + + suite.add_case( + name="Update draft email", + user_message=( + "oh wait i think i messed up on some emails. I meant 'z', not 'e'. " + "Also, I forgot to bcc y@arcade.dev. Also, replace the period with an " + "exclamation point since I want to convey excitement. Oh I almost forgot, " + "Don't cc anyone." + ), + expected_tool_calls=[ + ExpectedToolCall( + func=update_draft_email, + args={ + "message_id": "AQMkADAwATM0MDAAMi04Y2Y1LTQ3MTEALTAwAi0wMAoARgAAAyXxSd3UxTpCkDpGouEg0JMHAFuxokOLZRtDncM4_x_WeUwAAAIBDwAAAFuxokOLZRtDncM4_x_WeUwAAAAC-dpvAAAA", + "body": "I've gathered you all here to celebrate the launch of the new Arcade platform!", + "to_add": ["z@arcade.dev"], + "to_remove": ["e@arcade.dev"], + "cc_remove": ["j@arcade.dev", "f@arcade.dev", "k@arcade.dev", "m@arcade.dev"], + "bcc_add": ["y@arcade.dev"], + }, + ) + ], + critics=[ + BinaryCritic(critic_field="message_id", weight=1 / 6), + BinaryCritic(critic_field="body", weight=1 / 6), + BinaryCritic(critic_field="to_add", weight=1 / 6), + BinaryCritic(critic_field="to_remove", weight=1 / 6), + BinaryCritic(critic_field="cc_remove", weight=1 / 6), + BinaryCritic(critic_field="bcc_add", weight=1 / 6), + ], + additional_messages=update_draft_email_additional_messages, + ) + + return suite diff --git a/toolkits/outlook_mail/pyproject.toml b/toolkits/outlook_mail/pyproject.toml new file mode 100644 index 00000000..765458b6 --- /dev/null +++ b/toolkits/outlook_mail/pyproject.toml @@ -0,0 +1,60 @@ +[build-system] +requires = [ "hatchling",] +build-backend = "hatchling.build" + +[project] +name = "arcade_outlook_mail" +version = "1.0.0" +description = "Arcade.dev LLM tools for Outlook Mail" +requires-python = ">=3.10" +dependencies = [ + "arcade-tdk>=2.0.0,<3.0.0", + "msgraph-sdk>=1.28.0,<2.0.0", + "beautifulsoup4>=4.10.0,<5.0.0", +] +[[project.authors]] +name = "Arcade" +email = "dev@arcade.dev" + + +[project.optional-dependencies] +dev = [ + "arcade-ai[evals]>=2.0.0,<3.0.0", + "arcade-serve>=2.0.0,<3.0.0", + "pytest>=8.3.0,<9.0.0", + "pytest-cov>=4.0.0,<5.0.0", + "pytest-mock>=3.11.1,<4.0.0", + "pytest-asyncio>=0.24.0,<1.0.0", + "mypy>=1.5.1,<2.0.0", + "pre-commit>=3.4.0,<4.0.0", + "tox>=4.11.1,<5.0.0", + "ruff>=0.7.4,<1.0.0", +] + +# Use local path sources for arcade libs when working locally +[tool.uv.sources] +arcade-ai = { path = "../../", editable = true } +arcade-serve = { path = "../../libs/arcade-serve/", editable = true } +arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } + + +[tool.mypy] +files = [ "arcade_outlook_mail/**/*.py",] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes = "True" +ignore_missing_imports = "True" + +[tool.pytest.ini_options] +testpaths = [ "tests",] + +[tool.coverage.report] +skip_empty = true + +[tool.hatch.build.targets.wheel] +packages = [ "arcade_outlook_mail",] diff --git a/toolkits/outlook_mail/tests/__init__.py b/toolkits/outlook_mail/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/toolkits/outlook_mail/tests/test_message.py b/toolkits/outlook_mail/tests/test_message.py new file mode 100644 index 00000000..1ac53e09 --- /dev/null +++ b/toolkits/outlook_mail/tests/test_message.py @@ -0,0 +1,249 @@ +import pytest +from msgraph.generated.models.email_address import EmailAddress as GraphEmailAddress +from msgraph.generated.models.message import Message as GraphMessage +from msgraph.generated.models.recipient import Recipient as GraphRecipient + +from arcade_outlook_mail.message import Message, Recipient + + +# Dummy classes to simulate SDK objects +class DummyBody: + def __init__(self, content): + self.content = content + + +class DummyFlagStatus: + def __init__(self, value): + self.value = value + + +class DummyImportance: + def __init__(self, value): + self.value = value + + +class DummyDueDateTime: + def __init__(self, date_time): + self.date_time = date_time + + +class DummyFlag: + def __init__(self, flag_status, due_date_time): + self.flag_status = DummyFlagStatus(flag_status) + self.due_date_time = DummyDueDateTime(due_date_time) + + +class DummyDateTime: + def __init__(self, date_str): + self.date_str = date_str + + def isoformat(self): + return self.date_str + + +def make_graph_recipient(rec_data): + recipient = GraphRecipient() + recipient.email_address = GraphEmailAddress() + recipient.email_address.address = rec_data["email_address"] + recipient.email_address.name = rec_data.get("name", "") + return recipient + + +@pytest.mark.parametrize( + "input_data, expected", + [ + ( + { + "body_content": "

Hello world

", + "subject": "Test subject", + "conversation_id": "conv-1", + "conversation_index": "conv-index", + "flag_status": "flagged", + "due_date_time": "2021-01-01T10:00:00", + "has_attachments": False, + "importance": "high", + "is_read": True, + "received_date_time": "2021-01-02T00:00:00", + "web_link": "http://example.com", + "is_draft": False, + "message_id": "1234", + "to_recipients": [{"email_address": "to@example.com", "name": "ToName"}], + "cc_recipients": [{"email_address": "cc@example.com", "name": "CcName"}], + "bcc_recipients": [{"email_address": "bcc@example.com", "name": "BccName"}], + "reply_to": [{"email_address": "reply@example.com", "name": "ReplyName"}], + "from_": {"email_address": "from@example.com", "name": "FromName"}, + "conversation_index_bytes": False, + }, + { + "body": "Hello world", + "subject": "Test subject", + "conversation_id": "conv-1", + "conversation_index": "conv-index", + "flag": {"flag_status": "flagged", "due_date_time": "2021-01-01T10:00:00"}, + "has_attachments": False, + "importance": "high", + "is_read": True, + "received_date_time": "2021-01-02T00:00:00", + "web_link": "http://example.com", + "is_draft": False, + "message_id": "1234", + "to_recipients": [{"email_address": "to@example.com", "name": "ToName"}], + "cc_recipients": [{"email_address": "cc@example.com", "name": "CcName"}], + "bcc_recipients": [{"email_address": "bcc@example.com", "name": "BccName"}], + "reply_to": [{"email_address": "reply@example.com", "name": "ReplyName"}], + "from_": {"email_address": "from@example.com", "name": "FromName"}, + }, + ), + ( + { + "body_content": "

Sample email message

", + "subject": "Another subject", + "conversation_id": "conv-2", + "conversation_index": b"byte-index", + "flag_status": "notFlaged", + "due_date_time": "", + "has_attachments": False, + "importance": "low", + "is_read": False, + "received_date_time": "", + "web_link": "", + "is_draft": True, + "message_id": "5678", + "to_recipients": [{"email_address": "user1@example.com", "name": "User1"}], + "cc_recipients": [], + "bcc_recipients": [], + "reply_to": [], + "from_": {"email_address": "sender@example.com", "name": "Sender"}, + "conversation_index_bytes": True, + }, + { + "body": "Sample email message", + "subject": "Another subject", + "conversation_id": "conv-2", + "conversation_index": "byte-index", + "flag": {"flag_status": "notFlaged", "due_date_time": ""}, + "has_attachments": False, + "importance": "low", + "is_read": False, + "received_date_time": "", + "web_link": "", + "is_draft": True, + "message_id": "5678", + "to_recipients": [{"email_address": "user1@example.com", "name": "User1"}], + "cc_recipients": [], + "bcc_recipients": [], + "reply_to": [], + "from_": {"email_address": "sender@example.com", "name": "Sender"}, + }, + ), + ], +) +def test_message_conversion(input_data, expected): + # Set up sdk message + sdk_message = GraphMessage() + sdk_message.body = ( + DummyBody(input_data["body_content"]) if "body_content" in input_data else None + ) + sdk_message.subject = input_data["subject"] + sdk_message.conversation_id = input_data["conversation_id"] + sdk_message.conversation_index = input_data["conversation_index"] + sdk_message.flag = ( + DummyFlag(input_data["flag_status"], input_data["due_date_time"]) + if "flag_status" in input_data + else None + ) + sdk_message.has_attachments = input_data["has_attachments"] + sdk_message.importance = DummyImportance(input_data["importance"]) + sdk_message.is_read = input_data["is_read"] + sdk_message.received_date_time = ( + DummyDateTime(input_data["received_date_time"]) + if input_data["received_date_time"] + else None + ) + sdk_message.web_link = input_data["web_link"] + sdk_message.is_draft = input_data["is_draft"] + sdk_message.id = input_data["message_id"] + sdk_message.to_recipients = [make_graph_recipient(r) for r in input_data["to_recipients"]] + sdk_message.cc_recipients = [make_graph_recipient(r) for r in input_data["cc_recipients"]] + sdk_message.bcc_recipients = [make_graph_recipient(r) for r in input_data["bcc_recipients"]] + sdk_message.reply_to = [make_graph_recipient(r) for r in input_data["reply_to"]] + sdk_message.from_ = make_graph_recipient(input_data["from_"]) + + # Convert to Arcade Message type + message = Message.from_sdk(sdk_message) + + # Ensure conversion is correct + assert message.body == expected["body"], "Body conversion mismatch" + assert message.subject == expected["subject"] + assert message.conversation_id == expected["conversation_id"] + assert message.conversation_index == expected["conversation_index"] + assert message.flag == expected["flag"] + assert message.has_attachments == expected["has_attachments"] + assert message.importance == expected["importance"] + assert message.is_read == expected["is_read"] + assert message.received_date_time == expected["received_date_time"] + assert message.web_link == expected["web_link"] + assert message.is_draft == expected["is_draft"] + assert message.message_id == expected["message_id"] + assert message.from_.email_address == expected["from_"]["email_address"] + assert message.from_.name == expected["from_"]["name"] + + def check_recipient_list(actual, exp_list): + assert len(actual) == len(exp_list) + for rec, exp in zip(actual, exp_list, strict=False): + assert rec.email_address == exp["email_address"] + assert rec.name == exp["name"] + + check_recipient_list(message.to_recipients, expected["to_recipients"]) + check_recipient_list(message.cc_recipients, expected["cc_recipients"]) + check_recipient_list(message.bcc_recipients, expected["bcc_recipients"]) + check_recipient_list(message.reply_to, expected["reply_to"]) + + +@pytest.mark.parametrize( + "initial, add_params, expected_to_recipients", + [ + # Add a "To" recipient + ( + {"to_recipients": []}, + {"to_add": ["new@example.com"]}, + [{"email_address": "new@example.com", "name": ""}], + ), + # Add a "To" recipient that already exists + ( + {"to_recipients": [{"email_address": "dup@example.com", "name": ""}]}, + {"to_add": ["dup@example.com"]}, + [ + {"email_address": "dup@example.com", "name": ""}, + ], + ), + # Remove a "To" recipient + ( + { + "to_recipients": [ + {"email_address": "a@example.com", "name": "A"}, + {"email_address": "b@example.com", "name": "B"}, + ] + }, + {"to_remove": ["a@example.com"]}, + [{"email_address": "b@example.com", "name": "B"}], + ), + # Add and remove a "To" recipient + ( + {"to_recipients": [{"email_address": "c@example.com", "name": "C"}]}, + {"to_add": ["d@example.com", "c@example.com"], "to_remove": ["c@example.com"]}, + [{"email_address": "d@example.com", "name": ""}], + ), + ], +) +def test_update_recipient_lists(initial, add_params, expected_to_recipients): + msg = Message() + msg.to_recipients = [ + Recipient(email_address=r["email_address"], name=r.get("name", "")) + for r in initial.get("to_recipients", []) + ] + msg.update_recipient_lists( + to_add=add_params.get("to_add"), to_remove=add_params.get("to_remove") + ) + result = [r.to_dict() for r in msg.to_recipients] + assert result == expected_to_recipients, f"Expected {expected_to_recipients}, got {result}" diff --git a/toolkits/outlook_mail/tests/test_recipient.py b/toolkits/outlook_mail/tests/test_recipient.py new file mode 100644 index 00000000..a039064f --- /dev/null +++ b/toolkits/outlook_mail/tests/test_recipient.py @@ -0,0 +1,43 @@ +import pytest +from msgraph.generated.models.email_address import EmailAddress as GraphEmailAddress +from msgraph.generated.models.recipient import Recipient as GraphRecipient + +from arcade_outlook_mail.message import Recipient + + +@pytest.mark.parametrize( + "input_sdk_recipient, expected_email, expected_name", + [ + ( + GraphRecipient(email_address=GraphEmailAddress(address="dev@arcade.dev", name="Dev")), + "dev@arcade.dev", + "Dev", + ), + ( + GraphRecipient(email_address=GraphEmailAddress(address="dev@arcade.dev")), + "dev@arcade.dev", + "", + ), + (GraphRecipient(email_address=GraphEmailAddress(name="Dev")), "", "Dev"), + (GraphRecipient(email_address=GraphEmailAddress()), "", ""), + (GraphRecipient(), "", ""), + ], +) +def test_recipient(input_sdk_recipient, expected_email, expected_name): + recipient = Recipient.from_sdk(input_sdk_recipient) + assert ( + recipient.email_address == expected_email + ), "SDK conversion didn't set email_address correctly" + assert recipient.name == expected_name, "SDK conversion didn't set name correctly" + + recipient_dict = recipient.to_dict() + expected_dict = {"email_address": expected_email, "name": expected_name} + assert recipient_dict == expected_dict, "to_dict conversion did not produce expected dictionary" + + actual_sdk_recipient = recipient.to_sdk() + assert ( + actual_sdk_recipient.email_address.address == expected_email + ), "to_sdk conversion produced wrong email address" + assert ( + actual_sdk_recipient.email_address.name == expected_name + ), "to_sdk conversion produced wrong name" diff --git a/toolkits/outlook_mail/tests/test_utils.py b/toolkits/outlook_mail/tests/test_utils.py new file mode 100644 index 00000000..563adad8 --- /dev/null +++ b/toolkits/outlook_mail/tests/test_utils.py @@ -0,0 +1,55 @@ +import pytest + +from arcade_outlook_mail._utils import _create_filter_expression +from arcade_outlook_mail.enums import EmailFilterProperty, FilterOperator + + +@pytest.mark.parametrize( + "property_, operator, value, expected_filter_expr", + [ + ( + EmailFilterProperty.SUBJECT, + FilterOperator.EQUAL, + "Hello", + "receivedDateTime ge 1900-01-01T00:00:00Z and subject eq 'Hello'", + ), + ( + EmailFilterProperty.SUBJECT, + FilterOperator.STARTS_WITH, + "He", + "receivedDateTime ge 1900-01-01T00:00:00Z and startsWith(subject, 'He')", + ), + ( + EmailFilterProperty.CONVERSATION_ID, + FilterOperator.EQUAL, + "12345askdfjh=wef67890", + "receivedDateTime ge 1900-01-01T00:00:00Z and conversationId eq '12345askdfjh=wef67890'", + ), + ( + EmailFilterProperty.CONVERSATION_ID, + FilterOperator.NOT_EQUAL, + "67890", + "receivedDateTime ge 1900-01-01T00:00:00Z and conversationId ne 67890", + ), + ( + EmailFilterProperty.RECEIVED_DATE_TIME, + FilterOperator.GREATER_THAN, + "2024-01-01", + "receivedDateTime gt '2024-01-01'", + ), + ( + EmailFilterProperty.SENDER, + FilterOperator.EQUAL, + "a@ex.com", + "receivedDateTime ge 1900-01-01T00:00:00Z and sender/emailAddress/address eq 'a@ex.com'", + ), + ( + EmailFilterProperty.SENDER, + FilterOperator.CONTAINS, + "joe", + "receivedDateTime ge 1900-01-01T00:00:00Z and contains(sender/emailAddress/address, 'joe')", + ), + ], +) +def test_create_filter_expression(property_, operator, value, expected_filter_expr): + assert _create_filter_expression(property_, operator, value) == expected_filter_expr diff --git a/toolkits/walmart/.pre-commit-config.yaml b/toolkits/walmart/.pre-commit-config.yaml new file mode 100644 index 00000000..8508a588 --- /dev/null +++ b/toolkits/walmart/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +files: ^.*/walmart/.* +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.4.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/toolkits/walmart/.ruff.toml b/toolkits/walmart/.ruff.toml new file mode 100644 index 00000000..f1aed90f --- /dev/null +++ b/toolkits/walmart/.ruff.toml @@ -0,0 +1,46 @@ +target-version = "py310" +line-length = 100 +fix = true + +[lint] +select = [ + # flake8-2020 + "YTT", + # flake8-bandit + "S", + # flake8-bugbear + "B", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-simplify + "SIM", + # isort + "I", + # mccabe + "C90", + # pycodestyle + "E", "W", + # pyflakes + "F", + # pygrep-hooks + "PGH", + # pyupgrade + "UP", + # ruff + "RUF", + # tryceratops + "TRY", +] + +[lint.per-file-ignores] +"*" = ["TRY003", "B904"] +"**/tests/*" = ["S101", "E501"] +"**/evals/*" = ["S101", "E501"] + +[format] +preview = true +skip-magic-trailing-comma = false diff --git a/toolkits/walmart/LICENSE b/toolkits/walmart/LICENSE new file mode 100644 index 00000000..45f53e20 --- /dev/null +++ b/toolkits/walmart/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Arcade + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/toolkits/walmart/Makefile b/toolkits/walmart/Makefile new file mode 100644 index 00000000..0a8969be --- /dev/null +++ b/toolkits/walmart/Makefile @@ -0,0 +1,55 @@ +.PHONY: help + +help: + @echo "🛠️ github Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the uv environment and install all packages with dependencies + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras --no-sources + @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: install-local +install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras + @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + uv build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + @uv run --no-sources coverage report + @echo "Generating coverage report" + @uv run --no-sources coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file by a patch version + @echo "🚀 Bumping version in pyproject.toml" + uv version --no-sources --bump patch + +.PHONY: check +check: ## Run code quality tools. + @if [ -f .pre-commit-config.yaml ]; then\ + echo "🚀 Linting code: Running pre-commit";\ + uv run --no-sources pre-commit run -a;\ + fi + @echo "🚀 Static type checking: Running mypy" + @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/walmart/arcade_walmart/__init__.py b/toolkits/walmart/arcade_walmart/__init__.py new file mode 100644 index 00000000..4049a592 --- /dev/null +++ b/toolkits/walmart/arcade_walmart/__init__.py @@ -0,0 +1,3 @@ +from arcade_walmart.tools import search_products + +__all__ = ["search_products"] diff --git a/toolkits/walmart/arcade_walmart/enums.py b/toolkits/walmart/arcade_walmart/enums.py new file mode 100644 index 00000000..c1d2b600 --- /dev/null +++ b/toolkits/walmart/arcade_walmart/enums.py @@ -0,0 +1,21 @@ +from enum import Enum + + +class WalmartSortBy(Enum): + RELEVANCE = "relevance_according_to_keywords_searched" + PRICE_LOW_TO_HIGH = "lowest_price_first" + PRICE_HIGH_TO_LOW = "highest_price_first" + BEST_SELLING = "best_selling_products_first" + RATING_HIGH = "highest_rating_first" + NEW_ARRIVALS = "new_arrivals_first" + + def to_api_value(self: "WalmartSortBy") -> str | None: + _map = { + str(self.RELEVANCE): None, + str(self.PRICE_LOW_TO_HIGH): "price_low", + str(self.PRICE_HIGH_TO_LOW): "price_high", + str(self.BEST_SELLING): "best_seller", + str(self.RATING_HIGH): "rating_high", + str(self.NEW_ARRIVALS): "new", + } + return _map[str(self)] diff --git a/toolkits/walmart/arcade_walmart/tools/__init__.py b/toolkits/walmart/arcade_walmart/tools/__init__.py new file mode 100644 index 00000000..1eeefa91 --- /dev/null +++ b/toolkits/walmart/arcade_walmart/tools/__init__.py @@ -0,0 +1,3 @@ +from arcade_walmart.tools.walmart import search_products + +__all__ = ["search_products"] diff --git a/toolkits/walmart/arcade_walmart/tools/walmart.py b/toolkits/walmart/arcade_walmart/tools/walmart.py new file mode 100644 index 00000000..59e4ad0f --- /dev/null +++ b/toolkits/walmart/arcade_walmart/tools/walmart.py @@ -0,0 +1,95 @@ +from typing import Annotated, Any + +from arcade_tdk import ToolContext +from arcade_tdk.errors import ToolExecutionError +from arcade_tdk.tool import tool + +from arcade_walmart.enums import WalmartSortBy +from arcade_walmart.utils import ( + call_serpapi, + extract_walmart_product_details, + extract_walmart_results, + get_walmart_last_page_integer, + prepare_params, +) + + +@tool(requires_secrets=["SERP_API_KEY"]) +async def search_products( + context: ToolContext, + keywords: Annotated[str, "Keywords to search for. E.g. 'apple iphone' or 'samsung galaxy'"], + sort_by: Annotated[ + WalmartSortBy, + "Sort the results by the specified criteria. " + f"Defaults to '{WalmartSortBy.RELEVANCE.value}'.", + ] = WalmartSortBy.RELEVANCE, + min_price: Annotated[ + float | None, + "Minimum price to filter the results by. E.g. 100.00", + ] = None, + max_price: Annotated[ + float | None, + "Maximum price to filter the results by. E.g. 100.00", + ] = None, + next_day_delivery: Annotated[ + bool, + "Filters products that are eligible for next day delivery. " + "Defaults to False (returns all products, regardless of delivery status).", + ] = False, + page: Annotated[ + int, + "Page number to fetch. Defaults to 1 (first page of results). " + "The maximum page value is 100.", + ] = 1, +) -> Annotated[dict[str, Any], "List of Walmart products matching the search query."]: + """Search Walmart products using SerpAPI.""" + if page > 100: + raise ToolExecutionError(f"The maximum page value for Walmart search is 100, got {page}.") + + sort_by_value = sort_by.to_api_value() + + params = prepare_params( + "walmart", + query=keywords, + sort=sort_by_value, + # When the user selects a sorting option, we have to disable the relevance sorting + # using the soft_sort parameter. + soft_sort=not sort_by_value, + min_price=min_price, + max_price=max_price, + nd_en=next_day_delivery, + page=page, + include_filters=False, + ) + + response = call_serpapi(context, params) + + return { + "products": extract_walmart_results(response.get("organic_results", [])), + "current_page": page, + "last_available_page": get_walmart_last_page_integer(response), + } + + +@tool(requires_secrets=["SERP_API_KEY"]) +async def get_product_details( + context: ToolContext, + item_id: Annotated[ + str, + "Item ID. E.g. '414600577'. This can be retrieved from the search results of the " + f"{search_products.__tool_name__} tool.", + ], +) -> Annotated[dict[str, Any], "Product details"]: + """Get product details from Walmart.""" + params = prepare_params("walmart_product", product_id=item_id) + response = call_serpapi(context, params) + + product_result = response.get("product_result") + + if not product_result: + return { + "product_details": None, + "message": f"No product details found for item ID '{item_id}'.", + } + + return {"product_details": extract_walmart_product_details(product_result)} diff --git a/toolkits/walmart/arcade_walmart/utils.py b/toolkits/walmart/arcade_walmart/utils.py new file mode 100644 index 00000000..252e80fc --- /dev/null +++ b/toolkits/walmart/arcade_walmart/utils.py @@ -0,0 +1,120 @@ +import re +from typing import Any, cast + +from arcade_tdk import ToolContext +from arcade_tdk.errors import ToolExecutionError +from serpapi import Client as SerpClient + + +def prepare_params(engine: str, **kwargs: Any) -> dict[str, Any]: + """ + Prepares a parameters dictionary for the SerpAPI call. + + Parameters: + engine: The engine name (e.g., "google", "google_finance"). + kwargs: Any additional parameters to include. + + Returns: + A dictionary containing the base parameters plus any extras, + excluding any parameters whose value is None. + """ + params = {"engine": engine} + params.update({k: v for k, v in kwargs.items() if v is not None}) + return params + + +def call_serpapi(context: ToolContext, params: dict) -> dict: + """ + Execute a search query using the SerpAPI client and return the results as a dictionary. + + Args: + context: The tool context containing required secrets. + params: A dictionary of parameters for the SerpAPI search. + + Returns: + The search results as a dictionary. + """ + api_key = context.get_secret("SERP_API_KEY") + client = SerpClient(api_key=api_key) + try: + search = client.search(params) + return cast(dict[str, Any], search.as_dict()) + except Exception as e: + # SerpAPI error messages sometimes contain the API key, so we need to sanitize it + sanitized_e = re.sub(r"(api_key=)[^ &]+", r"\1***", str(e)) + raise ToolExecutionError( + message="Failed to fetch search results", + developer_message=sanitized_e, + ) + + +def extract_walmart_results(results: list[dict[str, Any]]) -> list[dict[str, Any]]: + return [ + { + "item_id": result.get("us_item_id"), + "title": result.get("title"), + "description": result.get("description"), + "rating": result.get("rating"), + "reviews_count": result.get("reviews"), + "seller": { + "id": result.get("seller_id"), + "name": result.get("seller_name"), + }, + "price": { + "value": result.get("primary_offer", {}).get("offer_price"), + "currency": result.get("primary_offer", {}).get("offer_currency"), + }, + "link": result.get("product_page_url"), + } + for result in results + ] + + +def get_walmart_last_page_integer(results: dict[str, Any]) -> int: + try: + return int(list(results["pagination"]["other_pages"].keys())[-1]) + except (KeyError, IndexError, ValueError): + return 1 + + +def extract_walmart_product_details(product: dict[str, Any]) -> dict[str, Any]: + return { + "item_id": product.get("us_item_id"), + "product_type": product.get("product_type"), + "title": product.get("title"), + "description_html": product.get("short_description_html"), + "rating": product.get("rating"), + "reviews_count": product.get("reviews"), + "seller": { + "id": product.get("seller_id"), + "name": product.get("seller_name"), + }, + "manufacturer_name": product.get("manufacturer"), + "price": { + "value": product.get("price_map", {}).get("price"), + "currency": product.get("price_map", {}).get("currency"), + "previous_price": product.get("price_map", {}).get("was_price", {}).get("price"), + }, + "link": product.get("product_page_url"), + "variant_options": extract_walmart_variant_options(product.get("variant_swatches", [])), + } + + +def extract_walmart_variant_options(variant_swatches: list[dict[str, Any]]) -> list[dict[str, Any]]: + variants = [] + + for variant_swatch in variant_swatches: + variant_name = variant_swatch.get("name") + if not variant_name: + continue + + options = [] + + for selection in variant_swatch.get("available_selections", []): + selection_name = selection.get("name") + if selection_name and selection_name not in options: + options.append(selection_name) + + variants.append({variant_name: options}) + + return variants diff --git a/toolkits/walmart/pyproject.toml b/toolkits/walmart/pyproject.toml new file mode 100644 index 00000000..b391d9ea --- /dev/null +++ b/toolkits/walmart/pyproject.toml @@ -0,0 +1,59 @@ +[build-system] +requires = [ "hatchling",] +build-backend = "hatchling.build" + +[project] +name = "arcade_walmart" +version = "2.0.0" +description = "Arcade.dev LLM tools for searching for products sold by Walmart" +requires-python = ">=3.10" +dependencies = [ + "arcade-tdk>=2.0.0,<3.0.0", + "serpapi>=0.1.5,<1.0.0", +] +[[project.authors]] +name = "Arcade" +email = "dev@arcade.dev" + + +[project.optional-dependencies] +dev = [ + "arcade-ai[evals]>=2.0.0,<3.0.0", + "arcade-serve>=2.0.0,<3.0.0", + "pytest>=8.3.0,<8.4.0", + "pytest-cov>=4.0.0,<4.1.0", + "pytest-mock>=3.11.1,<3.12.0", + "pytest-asyncio>=0.24.0,<0.25.0", + "mypy>=1.5.1,<1.6.0", + "pre-commit>=3.4.0,<3.5.0", + "tox>=4.11.1,<4.12.0", + "ruff>=0.7.4,<0.8.0", +] + +# Use local path sources for arcade libs when working locally +[tool.uv.sources] +arcade-ai = { path = "../../", editable = true } +arcade-serve = { path = "../../libs/arcade-serve/", editable = true } +arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } + + +[tool.mypy] +files = [ "arcade_walmart/**/*.py",] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes = "True" +ignore_missing_imports = "True" + +[tool.pytest.ini_options] +testpaths = [ "tests",] + +[tool.coverage.report] +skip_empty = true + +[tool.hatch.build.targets.wheel] +packages = [ "arcade_walmart",] diff --git a/toolkits/youtube/.pre-commit-config.yaml b/toolkits/youtube/.pre-commit-config.yaml new file mode 100644 index 00000000..52c11167 --- /dev/null +++ b/toolkits/youtube/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +files: ^.*/youtube/.* +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.4.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/toolkits/youtube/.ruff.toml b/toolkits/youtube/.ruff.toml new file mode 100644 index 00000000..f1aed90f --- /dev/null +++ b/toolkits/youtube/.ruff.toml @@ -0,0 +1,46 @@ +target-version = "py310" +line-length = 100 +fix = true + +[lint] +select = [ + # flake8-2020 + "YTT", + # flake8-bandit + "S", + # flake8-bugbear + "B", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-simplify + "SIM", + # isort + "I", + # mccabe + "C90", + # pycodestyle + "E", "W", + # pyflakes + "F", + # pygrep-hooks + "PGH", + # pyupgrade + "UP", + # ruff + "RUF", + # tryceratops + "TRY", +] + +[lint.per-file-ignores] +"*" = ["TRY003", "B904"] +"**/tests/*" = ["S101", "E501"] +"**/evals/*" = ["S101", "E501"] + +[format] +preview = true +skip-magic-trailing-comma = false diff --git a/toolkits/youtube/LICENSE b/toolkits/youtube/LICENSE new file mode 100644 index 00000000..45f53e20 --- /dev/null +++ b/toolkits/youtube/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Arcade + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/toolkits/youtube/Makefile b/toolkits/youtube/Makefile new file mode 100644 index 00000000..0a8969be --- /dev/null +++ b/toolkits/youtube/Makefile @@ -0,0 +1,55 @@ +.PHONY: help + +help: + @echo "🛠️ github Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the uv environment and install all packages with dependencies + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras --no-sources + @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: install-local +install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras + @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + uv build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + @uv run --no-sources coverage report + @echo "Generating coverage report" + @uv run --no-sources coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file by a patch version + @echo "🚀 Bumping version in pyproject.toml" + uv version --no-sources --bump patch + +.PHONY: check +check: ## Run code quality tools. + @if [ -f .pre-commit-config.yaml ]; then\ + echo "🚀 Linting code: Running pre-commit";\ + uv run --no-sources pre-commit run -a;\ + fi + @echo "🚀 Static type checking: Running mypy" + @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/youtube/arcade_youtube/__init__.py b/toolkits/youtube/arcade_youtube/__init__.py new file mode 100644 index 00000000..5ef6eeeb --- /dev/null +++ b/toolkits/youtube/arcade_youtube/__init__.py @@ -0,0 +1,3 @@ +from arcade_youtube.tools import get_youtube_video_details, search_for_videos + +__all__ = ["get_youtube_video_details", "search_for_videos"] diff --git a/toolkits/youtube/arcade_youtube/constants.py b/toolkits/youtube/arcade_youtube/constants.py new file mode 100644 index 00000000..db8b6670 --- /dev/null +++ b/toolkits/youtube/arcade_youtube/constants.py @@ -0,0 +1,7 @@ +import os + +YOUTUBE_MAX_DESCRIPTION_LENGTH = 500 +DEFAULT_YOUTUBE_SEARCH_LANGUAGE = os.getenv("ARCADE_YOUTUBE_SEARCH_LANGUAGE") +DEFAULT_YOUTUBE_SEARCH_COUNTRY = os.getenv("ARCADE_YOUTUBE_SEARCH_COUNTRY") +DEFAULT_GOOGLE_LANGUAGE = os.getenv("ARCADE_GOOGLE_LANGUAGE", "en") +DEFAULT_GOOGLE_COUNTRY = os.getenv("ARCADE_GOOGLE_COUNTRY") diff --git a/toolkits/youtube/arcade_youtube/exceptions.py b/toolkits/youtube/arcade_youtube/exceptions.py new file mode 100644 index 00000000..51d3f809 --- /dev/null +++ b/toolkits/youtube/arcade_youtube/exceptions.py @@ -0,0 +1,25 @@ +import json + +from arcade_tdk.errors import RetryableToolError + +from arcade_youtube.google_data import COUNTRY_CODES, LANGUAGE_CODES + + +class GoogleRetryableError(RetryableToolError): + pass + + +class CountryNotFoundError(GoogleRetryableError): + def __init__(self, country: str | None) -> None: + valid_countries = json.dumps(COUNTRY_CODES, default=str) + message = f"Country not found: '{country}'." + additional_message = f"Valid countries are: {valid_countries}" + super().__init__(message, additional_prompt_content=additional_message) + + +class LanguageNotFoundError(GoogleRetryableError): + def __init__(self, language: str | None) -> None: + valid_languages = json.dumps(LANGUAGE_CODES, default=str) + message = f"Language not found: '{language}'." + additional_message = f"Valid languages are: {valid_languages}" + super().__init__(message, additional_prompt_content=additional_message) diff --git a/toolkits/youtube/arcade_youtube/google_data.py b/toolkits/youtube/arcade_youtube/google_data.py new file mode 100644 index 00000000..789e3183 --- /dev/null +++ b/toolkits/youtube/arcade_youtube/google_data.py @@ -0,0 +1,281 @@ +COUNTRY_CODES = { + "af": "Afghanistan", + "al": "Albania", + "dz": "Algeria", + "as": "American Samoa", + "ad": "Andorra", + "ao": "Angola", + "ai": "Anguilla", + "aq": "Antarctica", + "ag": "Antigua and Barbuda", + "ar": "Argentina", + "am": "Armenia", + "aw": "Aruba", + "au": "Australia", + "at": "Austria", + "az": "Azerbaijan", + "bs": "Bahamas", + "bh": "Bahrain", + "bd": "Bangladesh", + "bb": "Barbados", + "by": "Belarus", + "be": "Belgium", + "bz": "Belize", + "bj": "Benin", + "bm": "Bermuda", + "bt": "Bhutan", + "bo": "Bolivia", + "ba": "Bosnia and Herzegovina", + "bw": "Botswana", + "bv": "Bouvet Island", + "br": "Brazil", + "io": "British Indian Ocean Territory", + "bn": "Brunei Darussalam", + "bg": "Bulgaria", + "bf": "Burkina Faso", + "bi": "Burundi", + "kh": "Cambodia", + "cm": "Cameroon", + "ca": "Canada", + "cv": "Cape Verde", + "ky": "Cayman Islands", + "cf": "Central African Republic", + "td": "Chad", + "cl": "Chile", + "cn": "China", + "cx": "Christmas Island", + "cc": "Cocos (Keeling) Islands", + "co": "Colombia", + "km": "Comoros", + "cg": "Congo", + "cd": "Congo, the Democratic Republic of the", + "ck": "Cook Islands", + "cr": "Costa Rica", + "ci": "Cote D'ivoire", + "hr": "Croatia", + "cu": "Cuba", + "cy": "Cyprus", + "cz": "Czech Republic", + "dk": "Denmark", + "dj": "Djibouti", + "dm": "Dominica", + "do": "Dominican Republic", + "ec": "Ecuador", + "eg": "Egypt", + "sv": "El Salvador", + "gq": "Equatorial Guinea", + "er": "Eritrea", + "ee": "Estonia", + "et": "Ethiopia", + "fk": "Falkland Islands (Malvinas)", + "fo": "Faroe Islands", + "fj": "Fiji", + "fi": "Finland", + "fr": "France", + "gf": "French Guiana", + "pf": "French Polynesia", + "tf": "French Southern Territories", + "ga": "Gabon", + "gm": "Gambia", + "ge": "Georgia", + "de": "Germany", + "gh": "Ghana", + "gi": "Gibraltar", + "gr": "Greece", + "gl": "Greenland", + "gd": "Grenada", + "gp": "Guadeloupe", + "gu": "Guam", + "gt": "Guatemala", + "gg": "Guernsey", + "gn": "Guinea", + "gw": "Guinea-Bissau", + "gy": "Guyana", + "ht": "Haiti", + "hm": "Heard Island and Mcdonald Islands", + "va": "Holy See (Vatican City State)", + "hn": "Honduras", + "hk": "Hong Kong", + "hu": "Hungary", + "is": "Iceland", + "in": "India", + "id": "Indonesia", + "ir": "Iran, Islamic Republic of", + "iq": "Iraq", + "ie": "Ireland", + "im": "Isle of Man", + "il": "Israel", + "it": "Italy", + "je": "Jersey", + "jm": "Jamaica", + "jp": "Japan", + "jo": "Jordan", + "kz": "Kazakhstan", + "ke": "Kenya", + "ki": "Kiribati", + "kp": "Korea, Democratic People's Republic of", + "kr": "Korea, Republic of", + "kw": "Kuwait", + "kg": "Kyrgyzstan", + "la": "Lao People's Democratic Republic", + "lv": "Latvia", + "lb": "Lebanon", + "ls": "Lesotho", + "lr": "Liberia", + "ly": "Libyan Arab Jamahiriya", + "li": "Liechtenstein", + "lt": "Lithuania", + "lu": "Luxembourg", + "mo": "Macao", + "mk": "Macedonia, the Former Yugosalv Republic of", + "mg": "Madagascar", + "mw": "Malawi", + "my": "Malaysia", + "mv": "Maldives", + "ml": "Mali", + "mt": "Malta", + "mh": "Marshall Islands", + "mq": "Martinique", + "mr": "Mauritania", + "mu": "Mauritius", + "yt": "Mayotte", + "mx": "Mexico", + "fm": "Micronesia, Federated States of", + "md": "Moldova, Republic of", + "mc": "Monaco", + "mn": "Mongolia", + "me": "Montenegro", + "ms": "Montserrat", + "ma": "Morocco", + "mz": "Mozambique", + "mm": "Myanmar", + "na": "Namibia", + "nr": "Nauru", + "np": "Nepal", + "nl": "Netherlands", + "an": "Netherlands Antilles", + "nc": "New Caledonia", + "nz": "New Zealand", + "ni": "Nicaragua", + "ne": "Niger", + "ng": "Nigeria", + "nu": "Niue", + "nf": "Norfolk Island", + "mp": "Northern Mariana Islands", + "no": "Norway", + "om": "Oman", + "pk": "Pakistan", + "pw": "Palau", + "ps": "Palestinian Territory, Occupied", + "pa": "Panama", + "pg": "Papua New Guinea", + "py": "Paraguay", + "pe": "Peru", + "ph": "Philippines", + "pn": "Pitcairn", + "pl": "Poland", + "pt": "Portugal", + "pr": "Puerto Rico", + "qa": "Qatar", + "re": "Reunion", + "ro": "Romania", + "ru": "Russian Federation", + "rw": "Rwanda", + "sh": "Saint Helena", + "kn": "Saint Kitts and Nevis", + "lc": "Saint Lucia", + "pm": "Saint Pierre and Miquelon", + "vc": "Saint Vincent and the Grenadines", + "ws": "Samoa", + "sm": "San Marino", + "st": "Sao Tome and Principe", + "sa": "Saudi Arabia", + "sn": "Senegal", + "rs": "Serbia", + "sc": "Seychelles", + "sl": "Sierra Leone", + "sg": "Singapore", + "sk": "Slovakia", + "si": "Slovenia", + "sb": "Solomon Islands", + "so": "Somalia", + "za": "South Africa", + "gs": "South Georgia and the South Sandwich Islands", + "es": "Spain", + "lk": "Sri Lanka", + "sd": "Sudan", + "sr": "Suriname", + "sj": "Svalbard and Jan Mayen", + "sz": "Swaziland", + "se": "Sweden", + "ch": "Switzerland", + "sy": "Syrian Arab Republic", + "tw": "Taiwan, Province of China", + "tj": "Tajikistan", + "tz": "Tanzania, United Republic of", + "th": "Thailand", + "tl": "Timor-Leste", + "tg": "Togo", + "tk": "Tokelau", + "to": "Tonga", + "tt": "Trinidad and Tobago", + "tn": "Tunisia", + "tr": "Turkiye", + "tm": "Turkmenistan", + "tc": "Turks and Caicos Islands", + "tv": "Tuvalu", + "ug": "Uganda", + "ua": "Ukraine", + "ae": "United Arab Emirates", + "uk": "United Kingdom", + "gb": "United Kingdom", + "us": "United States", + "um": "United States Minor Outlying Islands", + "uy": "Uruguay", + "uz": "Uzbekistan", + "vu": "Vanuatu", + "ve": "Venezuela", + "vn": "Viet Nam", + "vg": "Virgin Islands, British", + "vi": "Virgin Islands, U.S.", + "wf": "Wallis and Futuna", + "eh": "Western Sahara", + "ye": "Yemen", + "zm": "Zambia", + "zw": "Zimbabwe", +} + + +LANGUAGE_CODES = { + "ar": "Arabic", + "bn": "Bengali", + "da": "Danish", + "de": "German", + "el": "Greek", + "en": "English", + "es": "Spanish", + "fi": "Finnish", + "fr": "French", + "hi": "Hindi", + "hu": "Hungarian", + "id": "Indonesian", + "it": "Italian", + "ja": "Japanese", + "ko": "Korean", + "nl": "Dutch", + "ms": "Malay", + "no": "Norwegian", + "pcm": "Nigerian Pidgin", + "pl": "Polish", + "pt": "Portuguese", + "pt-br": "Portuguese (Brazil)", + "pt-pt": "Portuguese (Portugal)", + "ru": "Russian", + "sv": "Swedish", + "tl": "Filipino", + "tr": "Turkish", + "uk": "Ukrainian", + "zh": "Chinese", + "zh-cn": "Chinese (Simplified)", + "zh-tw": "Chinese (Traditional)", +} diff --git a/toolkits/youtube/arcade_youtube/tools/__init__.py b/toolkits/youtube/arcade_youtube/tools/__init__.py new file mode 100644 index 00000000..d4ce19c4 --- /dev/null +++ b/toolkits/youtube/arcade_youtube/tools/__init__.py @@ -0,0 +1,3 @@ +from arcade_youtube.tools.youtube import get_youtube_video_details, search_for_videos + +__all__ = ["get_youtube_video_details", "search_for_videos"] diff --git a/toolkits/youtube/arcade_youtube/tools/youtube.py b/toolkits/youtube/arcade_youtube/tools/youtube.py new file mode 100644 index 00000000..5f049790 --- /dev/null +++ b/toolkits/youtube/arcade_youtube/tools/youtube.py @@ -0,0 +1,101 @@ +from typing import Annotated, Any, cast + +from arcade_tdk import ToolContext, tool +from arcade_tdk.errors import ToolExecutionError + +from arcade_youtube.constants import DEFAULT_YOUTUBE_SEARCH_COUNTRY, DEFAULT_YOUTUBE_SEARCH_LANGUAGE +from arcade_youtube.utils import ( + call_serpapi, + default_country_code, + default_language_code, + extract_video_details, + extract_video_results, + prepare_params, + resolve_country_code, + resolve_language_code, +) + + +@tool(requires_secrets=["SERP_API_KEY"]) +async def search_for_videos( + context: ToolContext, + keywords: Annotated[ + str, + "The keywords to search for. E.g. 'Python tutorial'.", + ], + language_code: Annotated[ + str | None, + "2-character language code to search for. E.g. 'en' for English. " + f"Defaults to '{default_language_code(DEFAULT_YOUTUBE_SEARCH_LANGUAGE)}'.", + ] = None, + country_code: Annotated[ + str | None, + "2-character country code to search for. E.g. 'us' for United States. " + f"Defaults to '{default_country_code(DEFAULT_YOUTUBE_SEARCH_COUNTRY)}'.", + ] = None, + next_page_token: Annotated[ + str | None, + "The next page token to use for pagination. " + "Defaults to `None` (start from the first page).", + ] = None, +) -> Annotated[dict[str, Any], "List of YouTube videos related to the query."]: + """Search for YouTube videos related to the query.""" + language_code = resolve_language_code(language_code, DEFAULT_YOUTUBE_SEARCH_LANGUAGE) + country_code = resolve_country_code(country_code, DEFAULT_YOUTUBE_SEARCH_COUNTRY) + + params = prepare_params( + "youtube", + search_query=keywords, + hl=language_code, + gl=country_code, + sp=next_page_token, + ) + results = call_serpapi(context, params) + + if results.get("error"): + error_msg = cast(str, results.get("error")) + raise ToolExecutionError(error_msg) + + return { + "videos": extract_video_results(results), + "next_page_token": results.get("serpapi_pagination", {}).get("next_page_token"), + } + + +@tool(requires_secrets=["SERP_API_KEY"]) +async def get_youtube_video_details( + context: ToolContext, + video_id: Annotated[ + str, + "The ID of the YouTube video to get details about. E.g. 'dQw4w9WgXcQ'.", + ], + language_code: Annotated[ + str | None, + "2-character language code to search for. E.g. 'en' for English. " + f"Defaults to '{default_language_code(DEFAULT_YOUTUBE_SEARCH_LANGUAGE)}'.", + ] = None, + country_code: Annotated[ + str | None, + "2-character country code to search for. E.g. 'us' for United States. " + f"Defaults to '{default_country_code(DEFAULT_YOUTUBE_SEARCH_COUNTRY)}'.", + ] = None, +) -> Annotated[dict[str, Any], "Details about a YouTube video."]: + """Get details about a YouTube video.""" + language_code = resolve_language_code(language_code, DEFAULT_YOUTUBE_SEARCH_LANGUAGE) + country_code = resolve_country_code(country_code, DEFAULT_YOUTUBE_SEARCH_COUNTRY) + + params = prepare_params( + "youtube_video", + v=video_id, + hl=language_code, + gl=country_code, + ) + results = call_serpapi(context, params) + + if results.get("error"): + error_msg = cast(str, results.get("error")) + raise ToolExecutionError(error_msg) + + return { + "video": extract_video_details(results), + } diff --git a/toolkits/youtube/arcade_youtube/utils.py b/toolkits/youtube/arcade_youtube/utils.py new file mode 100644 index 00000000..a16c52cc --- /dev/null +++ b/toolkits/youtube/arcade_youtube/utils.py @@ -0,0 +1,169 @@ +import re +from typing import Any, cast +from urllib.parse import parse_qs, urlparse + +from arcade_tdk import ToolContext +from arcade_tdk.errors import ToolExecutionError +from serpapi import Client as SerpClient + +from arcade_youtube.constants import ( + DEFAULT_GOOGLE_COUNTRY, + DEFAULT_GOOGLE_LANGUAGE, + YOUTUBE_MAX_DESCRIPTION_LENGTH, +) +from arcade_youtube.exceptions import CountryNotFoundError, LanguageNotFoundError +from arcade_youtube.google_data import COUNTRY_CODES, LANGUAGE_CODES + + +def prepare_params(engine: str, **kwargs: Any) -> dict[str, Any]: + """ + Prepares a parameters dictionary for the SerpAPI call. + + Parameters: + engine: The engine name (e.g., "google", "google_finance"). + kwargs: Any additional parameters to include. + + Returns: + A dictionary containing the base parameters plus any extras, + excluding any parameters whose value is None. + """ + params = {"engine": engine} + params.update({k: v for k, v in kwargs.items() if v is not None}) + return params + + +def call_serpapi(context: ToolContext, params: dict) -> dict: + """ + Execute a search query using the SerpAPI client and return the results as a dictionary. + + Args: + context: The tool context containing required secrets. + params: A dictionary of parameters for the SerpAPI search. + + Returns: + The search results as a dictionary. + """ + api_key = context.get_secret("SERP_API_KEY") + client = SerpClient(api_key=api_key) + try: + search = client.search(params) + return cast(dict[str, Any], search.as_dict()) + except Exception as e: + # SerpAPI error messages sometimes contain the API key, so we need to sanitize it + sanitized_e = re.sub(r"(api_key=)[^ &]+", r"\1***", str(e)) + raise ToolExecutionError( + message="Failed to fetch search results", + developer_message=sanitized_e, + ) + + +def default_language_code(default_service_language_code: str | None = None) -> str | None: + if isinstance(default_service_language_code, str): + return default_service_language_code.lower() + elif isinstance(DEFAULT_GOOGLE_LANGUAGE, str): + return DEFAULT_GOOGLE_LANGUAGE.lower() + return None + + +def default_country_code(default_service_country_code: str | None = None) -> str | None: + if isinstance(default_service_country_code, str): + return default_service_country_code.lower() + elif isinstance(DEFAULT_GOOGLE_COUNTRY, str): + return DEFAULT_GOOGLE_COUNTRY.lower() + return None + + +def resolve_language_code( + language_code: str | None = None, + default_service_language_code: str | None = None, +) -> str | None: + language_code = language_code or default_language_code(default_service_language_code) + + if isinstance(language_code, str): + language_code = language_code.lower() + if language_code not in LANGUAGE_CODES: + raise LanguageNotFoundError(language_code) + + return language_code + + +def resolve_country_code( + country_code: str | None = None, + default_service_country_code: str | None = None, +) -> str | None: + country_code = country_code or default_country_code(default_service_country_code) + + if isinstance(country_code, str): + country_code = country_code.lower() + if country_code not in COUNTRY_CODES: + raise CountryNotFoundError(country_code) + + return country_code + + +def extract_video_id_from_link(link: str | None) -> str | None: + if not isinstance(link, str): + return None + + parsed_url = urlparse(link) + query_params = parse_qs(parsed_url.query) + return query_params.get("v", [""])[0] + + +def extract_video_description( + video: dict[str, Any], + max_description_length: int = YOUTUBE_MAX_DESCRIPTION_LENGTH, +) -> str | None: + description = video.get("description", "") + + if isinstance(description, dict): + description = description.get("content", "") + + if isinstance(description, str): + too_long = len(description) > max_description_length + if too_long: + description = description[:max_description_length] + " [truncated]" + + if description is not None: + description = str(description).strip() + + return cast(str | None, description) + + +def extract_video_results( + results: dict[str, Any], + max_description_length: int = YOUTUBE_MAX_DESCRIPTION_LENGTH, +) -> list[dict[str, Any]]: + videos = [] + + for video in results.get("video_results", []): + videos.append({ + "id": extract_video_id_from_link(video.get("link")), + "title": video.get("title"), + "description": extract_video_description(video, max_description_length), + "link": video.get("link"), + "published_date": video.get("published_date"), + "duration": video.get("duration"), + "channel": { + "name": video.get("channel", {}).get("name"), + "link": video.get("channel", {}).get("link"), + }, + }) + + return videos + + +def extract_video_details(video: dict[str, Any]) -> dict[str, Any]: + return { + "id": extract_video_id_from_link(video.get("link")), + "title": video.get("title"), + "description": extract_video_description(video, YOUTUBE_MAX_DESCRIPTION_LENGTH), + "published_date": video.get("published_date"), + "channel": { + "name": video.get("channel", {}).get("name"), + "link": video.get("channel", {}).get("link"), + }, + "like_count": video.get("extracted_likes"), + "view_count": video.get("extracted_views"), + "live": video.get("live", False), + } diff --git a/toolkits/youtube/pyproject.toml b/toolkits/youtube/pyproject.toml new file mode 100644 index 00000000..fb552c8d --- /dev/null +++ b/toolkits/youtube/pyproject.toml @@ -0,0 +1,59 @@ +[build-system] +requires = [ "hatchling",] +build-backend = "hatchling.build" + +[project] +name = "arcade_youtube" +version = "2.0.0" +description = "Arcade.dev LLM tools for searching for YouTube videos"" +requires-python = ">=3.10" +dependencies = [ + "arcade-tdk>=2.0.0,<3.0.0", + "serpapi>=0.1.5,<1.0.0", +] +[[project.authors]] +name = "Arcade" +email = "dev@arcade.dev" + + +[project.optional-dependencies] +dev = [ + "arcade-ai[evals]>=2.0.0,<3.0.0", + "arcade-serve>=2.0.0,<3.0.0", + "pytest>=8.3.0,<8.4.0", + "pytest-cov>=4.0.0,<4.1.0", + "pytest-mock>=3.11.1,<3.12.0", + "pytest-asyncio>=0.24.0,<0.25.0", + "mypy>=1.5.1,<1.6.0", + "pre-commit>=3.4.0,<3.5.0", + "tox>=4.11.1,<4.12.0", + "ruff>=0.7.4,<0.8.0", +] + +# Use local path sources for arcade libs when working locally +[tool.uv.sources] +arcade-ai = { path = "../../", editable = true } +arcade-serve = { path = "../../libs/arcade-serve/", editable = true } +arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } + + +[tool.mypy] +files = [ "arcade_youtube/**/*.py",] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes = "True" +ignore_missing_imports = "True" + +[tool.pytest.ini_options] +testpaths = [ "tests",] + +[tool.coverage.report] +skip_empty = true + +[tool.hatch.build.targets.wheel] +packages = [ "arcade_youtube",]