From f6765bed6792e73522e28e27b3e2ce5e991e03cf Mon Sep 17 00:00:00 2001 From: Renato Byrro Date: Fri, 21 Mar 2025 21:02:45 -0300 Subject: [PATCH] Walmart shopping search tools (#320) --- toolkits/search/arcade_search/enums.py | 24 +++++ .../search/arcade_search/tools/walmart.py | 95 +++++++++++++++++++ toolkits/search/arcade_search/utils.py | 75 +++++++++++++++ 3 files changed, 194 insertions(+) create mode 100644 toolkits/search/arcade_search/tools/walmart.py diff --git a/toolkits/search/arcade_search/enums.py b/toolkits/search/arcade_search/enums.py index 8c9fee94..876b8de4 100644 --- a/toolkits/search/arcade_search/enums.py +++ b/toolkits/search/arcade_search/enums.py @@ -1,4 +1,5 @@ from enum import Enum +from typing import Optional # ------------------------------------------------------------------------------------------------ @@ -124,3 +125,26 @@ class GoogleMapsDistanceUnit(Enum): str(self.MILES): 1, } return _map[str(self)] + + +# ------------------------------------------------------------------------------------------------ +# Walmart enumerations +# ------------------------------------------------------------------------------------------------ +class WalmartSortBy(Enum): + RELEVANCE = "relevance_according_to_keywords_searched" + PRICE_LOW_TO_HIGH = "lowest_price_first" + PRICE_HIGH_TO_LOW = "highest_price_first" + BEST_SELLING = "best_selling_products_first" + RATING_HIGH = "highest_rating_first" + NEW_ARRIVALS = "new_arrivals_first" + + def to_api_value(self: "WalmartSortBy") -> Optional[str]: + _map = { + str(self.RELEVANCE): None, + str(self.PRICE_LOW_TO_HIGH): "price_low", + str(self.PRICE_HIGH_TO_LOW): "price_high", + str(self.BEST_SELLING): "best_seller", + str(self.RATING_HIGH): "rating_high", + str(self.NEW_ARRIVALS): "new", + } + return _map[str(self)] diff --git a/toolkits/search/arcade_search/tools/walmart.py b/toolkits/search/arcade_search/tools/walmart.py new file mode 100644 index 00000000..3c0234f8 --- /dev/null +++ b/toolkits/search/arcade_search/tools/walmart.py @@ -0,0 +1,95 @@ +from typing import Annotated, Any, Optional + +from arcade.sdk import ToolContext +from arcade.sdk.errors import ToolExecutionError +from arcade.sdk.tool import tool + +from arcade_search.enums import WalmartSortBy +from arcade_search.utils import ( + call_serpapi, + extract_walmart_product_details, + extract_walmart_results, + get_walmart_last_page_integer, + prepare_params, +) + + +@tool(requires_secrets=["SERP_API_KEY"]) +async def search_walmart_products( + context: ToolContext, + keywords: Annotated[str, "Keywords to search for. E.g. 'apple iphone' or 'samsung galaxy'"], + sort_by: Annotated[ + WalmartSortBy, + "Sort the results by the specified criteria. " + f"Defaults to '{WalmartSortBy.RELEVANCE.value}'.", + ] = WalmartSortBy.RELEVANCE, + min_price: Annotated[ + Optional[float], + "Minimum price to filter the results by. E.g. 100.00", + ] = None, + max_price: Annotated[ + Optional[float], + "Maximum price to filter the results by. E.g. 100.00", + ] = None, + next_day_delivery: Annotated[ + bool, + "Filters products that are eligible for next day delivery. " + "Defaults to False (returns all products, regardless of delivery status).", + ] = False, + page: Annotated[ + int, + "Page number to fetch. Defaults to 1 (first page of results). " + "The maximum page value is 100.", + ] = 1, +) -> Annotated[dict[str, Any], "List of Walmart products matching the search query."]: + """Search Walmart products using SerpAPI.""" + if page > 100: + raise ToolExecutionError(f"The maximum page value for Walmart search is 100, got {page}.") + + sort_by_value = sort_by.to_api_value() + + params = prepare_params( + "walmart", + query=keywords, + sort=sort_by_value, + # When the user selects a sorting option, we have to disable the relevance sorting + # using the soft_sort parameter. + soft_sort=not sort_by_value, + min_price=min_price, + max_price=max_price, + nd_en=next_day_delivery, + page=page, + include_filters=False, + ) + + response = call_serpapi(context, params) + + return { + "products": extract_walmart_results(response.get("organic_results", [])), + "current_page": page, + "last_available_page": get_walmart_last_page_integer(response), + } + + +@tool(requires_secrets=["SERP_API_KEY"]) +async def get_walmart_product_details( + context: ToolContext, + item_id: Annotated[ + str, + "Item ID. E.g. '414600577'. This can be retrieved from the search results of the " + f"{search_walmart_products.__tool_name__} tool.", + ], +) -> Annotated[dict[str, Any], "Product details"]: + """Get product details from Walmart.""" + params = prepare_params("walmart_product", product_id=item_id) + response = call_serpapi(context, params) + + product_result = response.get("product_result") + + if not product_result: + return { + "product_details": None, + "message": f"No product details found for item ID '{item_id}'.", + } + + return {"product_details": extract_walmart_product_details(product_result)} diff --git a/toolkits/search/arcade_search/utils.py b/toolkits/search/arcade_search/utils.py index 1bdc860f..41901fc9 100644 --- a/toolkits/search/arcade_search/utils.py +++ b/toolkits/search/arcade_search/utils.py @@ -276,6 +276,81 @@ def extract_news_results( return news_results +# ------------------------------------------------------------------------------------------------ +# Walmart utils +# ------------------------------------------------------------------------------------------------ +def extract_walmart_results(results: list[dict[str, Any]]) -> list[dict[str, Any]]: + return [ + { + "item_id": result.get("us_item_id"), + "title": result.get("title"), + "description": result.get("description"), + "rating": result.get("rating"), + "reviews_count": result.get("reviews"), + "seller": { + "id": result.get("seller_id"), + "name": result.get("seller_name"), + }, + "price": { + "value": result.get("primary_offer", {}).get("offer_price"), + "currency": result.get("primary_offer", {}).get("offer_currency"), + }, + "link": result.get("product_page_url"), + } + for result in results + ] + + +def get_walmart_last_page_integer(results: dict[str, Any]) -> int: + try: + return int(list(results["pagination"]["other_pages"].keys())[-1]) + except (KeyError, IndexError, ValueError): + return 1 + + +def extract_walmart_product_details(product: dict[str, Any]) -> dict[str, Any]: + return { + "item_id": product.get("us_item_id"), + "product_type": product.get("product_type"), + "title": product.get("title"), + "description_html": product.get("short_description_html"), + "rating": product.get("rating"), + "reviews_count": product.get("reviews"), + "seller": { + "id": product.get("seller_id"), + "name": product.get("seller_name"), + }, + "manufacturer_name": product.get("manufacturer"), + "price": { + "value": product.get("price_map", {}).get("price"), + "currency": product.get("price_map", {}).get("currency"), + "previous_price": product.get("price_map", {}).get("was_price", {}).get("price"), + }, + "link": product.get("product_page_url"), + "variant_options": extract_walmart_variant_options(product.get("variant_swatches", [])), + } + + +def extract_walmart_variant_options(variant_swatches: list[dict[str, Any]]) -> list[dict[str, Any]]: + variants = [] + + for variant_swatch in variant_swatches: + variant_name = variant_swatch.get("name") + if not variant_name: + continue + + options = [] + + for selection in variant_swatch.get("available_selections", []): + selection_name = selection.get("name") + if selection_name and selection_name not in options: + options.append(selection_name) + + variants.append({variant_name: options}) + + return variants + + # ------------------------------------------------------------------------------------------------ # YouTube utils # ------------------------------------------------------------------------------------------------