Walmart shopping search tools (#320)

This commit is contained in:
Renato Byrro 2025-03-21 21:02:45 -03:00 committed by GitHub
parent 227f02d2fd
commit f6765bed67
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 194 additions and 0 deletions

View file

@ -1,4 +1,5 @@
from enum import Enum
from typing import Optional
# ------------------------------------------------------------------------------------------------
@ -124,3 +125,26 @@ class GoogleMapsDistanceUnit(Enum):
str(self.MILES): 1,
}
return _map[str(self)]
# ------------------------------------------------------------------------------------------------
# Walmart enumerations
# ------------------------------------------------------------------------------------------------
class WalmartSortBy(Enum):
RELEVANCE = "relevance_according_to_keywords_searched"
PRICE_LOW_TO_HIGH = "lowest_price_first"
PRICE_HIGH_TO_LOW = "highest_price_first"
BEST_SELLING = "best_selling_products_first"
RATING_HIGH = "highest_rating_first"
NEW_ARRIVALS = "new_arrivals_first"
def to_api_value(self: "WalmartSortBy") -> Optional[str]:
_map = {
str(self.RELEVANCE): None,
str(self.PRICE_LOW_TO_HIGH): "price_low",
str(self.PRICE_HIGH_TO_LOW): "price_high",
str(self.BEST_SELLING): "best_seller",
str(self.RATING_HIGH): "rating_high",
str(self.NEW_ARRIVALS): "new",
}
return _map[str(self)]

View file

@ -0,0 +1,95 @@
from typing import Annotated, Any, Optional
from arcade.sdk import ToolContext
from arcade.sdk.errors import ToolExecutionError
from arcade.sdk.tool import tool
from arcade_search.enums import WalmartSortBy
from arcade_search.utils import (
call_serpapi,
extract_walmart_product_details,
extract_walmart_results,
get_walmart_last_page_integer,
prepare_params,
)
@tool(requires_secrets=["SERP_API_KEY"])
async def search_walmart_products(
context: ToolContext,
keywords: Annotated[str, "Keywords to search for. E.g. 'apple iphone' or 'samsung galaxy'"],
sort_by: Annotated[
WalmartSortBy,
"Sort the results by the specified criteria. "
f"Defaults to '{WalmartSortBy.RELEVANCE.value}'.",
] = WalmartSortBy.RELEVANCE,
min_price: Annotated[
Optional[float],
"Minimum price to filter the results by. E.g. 100.00",
] = None,
max_price: Annotated[
Optional[float],
"Maximum price to filter the results by. E.g. 100.00",
] = None,
next_day_delivery: Annotated[
bool,
"Filters products that are eligible for next day delivery. "
"Defaults to False (returns all products, regardless of delivery status).",
] = False,
page: Annotated[
int,
"Page number to fetch. Defaults to 1 (first page of results). "
"The maximum page value is 100.",
] = 1,
) -> Annotated[dict[str, Any], "List of Walmart products matching the search query."]:
"""Search Walmart products using SerpAPI."""
if page > 100:
raise ToolExecutionError(f"The maximum page value for Walmart search is 100, got {page}.")
sort_by_value = sort_by.to_api_value()
params = prepare_params(
"walmart",
query=keywords,
sort=sort_by_value,
# When the user selects a sorting option, we have to disable the relevance sorting
# using the soft_sort parameter.
soft_sort=not sort_by_value,
min_price=min_price,
max_price=max_price,
nd_en=next_day_delivery,
page=page,
include_filters=False,
)
response = call_serpapi(context, params)
return {
"products": extract_walmart_results(response.get("organic_results", [])),
"current_page": page,
"last_available_page": get_walmart_last_page_integer(response),
}
@tool(requires_secrets=["SERP_API_KEY"])
async def get_walmart_product_details(
context: ToolContext,
item_id: Annotated[
str,
"Item ID. E.g. '414600577'. This can be retrieved from the search results of the "
f"{search_walmart_products.__tool_name__} tool.",
],
) -> Annotated[dict[str, Any], "Product details"]:
"""Get product details from Walmart."""
params = prepare_params("walmart_product", product_id=item_id)
response = call_serpapi(context, params)
product_result = response.get("product_result")
if not product_result:
return {
"product_details": None,
"message": f"No product details found for item ID '{item_id}'.",
}
return {"product_details": extract_walmart_product_details(product_result)}

View file

@ -276,6 +276,81 @@ def extract_news_results(
return news_results
# ------------------------------------------------------------------------------------------------
# Walmart utils
# ------------------------------------------------------------------------------------------------
def extract_walmart_results(results: list[dict[str, Any]]) -> list[dict[str, Any]]:
return [
{
"item_id": result.get("us_item_id"),
"title": result.get("title"),
"description": result.get("description"),
"rating": result.get("rating"),
"reviews_count": result.get("reviews"),
"seller": {
"id": result.get("seller_id"),
"name": result.get("seller_name"),
},
"price": {
"value": result.get("primary_offer", {}).get("offer_price"),
"currency": result.get("primary_offer", {}).get("offer_currency"),
},
"link": result.get("product_page_url"),
}
for result in results
]
def get_walmart_last_page_integer(results: dict[str, Any]) -> int:
try:
return int(list(results["pagination"]["other_pages"].keys())[-1])
except (KeyError, IndexError, ValueError):
return 1
def extract_walmart_product_details(product: dict[str, Any]) -> dict[str, Any]:
return {
"item_id": product.get("us_item_id"),
"product_type": product.get("product_type"),
"title": product.get("title"),
"description_html": product.get("short_description_html"),
"rating": product.get("rating"),
"reviews_count": product.get("reviews"),
"seller": {
"id": product.get("seller_id"),
"name": product.get("seller_name"),
},
"manufacturer_name": product.get("manufacturer"),
"price": {
"value": product.get("price_map", {}).get("price"),
"currency": product.get("price_map", {}).get("currency"),
"previous_price": product.get("price_map", {}).get("was_price", {}).get("price"),
},
"link": product.get("product_page_url"),
"variant_options": extract_walmart_variant_options(product.get("variant_swatches", [])),
}
def extract_walmart_variant_options(variant_swatches: list[dict[str, Any]]) -> list[dict[str, Any]]:
variants = []
for variant_swatch in variant_swatches:
variant_name = variant_swatch.get("name")
if not variant_name:
continue
options = []
for selection in variant_swatch.get("available_selections", []):
selection_name = selection.get("name")
if selection_name and selection_name not in options:
options.append(selection_name)
variants.append({variant_name: options})
return variants
# ------------------------------------------------------------------------------------------------
# YouTube utils
# ------------------------------------------------------------------------------------------------