Various Search Toolkits (#285)

1. Add the following tools:
* Google Finance
    - get_stock_summary
    - get_stock_historical_data
* Google Flights
    - search_roundtrip_flights
    - search_one_way_flights
* Google Hotels
    - search_hotels
    
2. Add some common helper functions for serpAPI tools.
This commit is contained in:
Eric Gustin 2025-03-14 13:23:14 -08:00 committed by GitHub
parent ef5b19b4a2
commit 99ff11d30e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 497 additions and 11 deletions

View file

@ -0,0 +1,80 @@
from enum import Enum
class GoogleFinanceWindow(Enum):
ONE_DAY = "1D"
FIVE_DAYS = "5D"
ONE_MONTH = "1M"
SIX_MONTHS = "6M"
YEAR_TO_DATE = "YTD"
ONE_YEAR = "1Y"
FIVE_YEARS = "5Y"
MAX = "MAX"
class GoogleFlightsTravelClass(Enum):
ECONOMY = "ECONOMY"
PREMIUM_ECONOMY = "PREMIUM_ECONOMY"
BUSINESS = "BUSINESS"
FIRST = "FIRST"
def to_api_value(self) -> int:
_map = {
"ECONOMY": 1,
"PREMIUM_ECONOMY": 2,
"BUSINESS": 3,
"FIRST": 4,
}
return _map[self.value]
class GoogleFlightsMaxStops(Enum):
ANY = "ANY"
NONSTOP = "NONSTOP"
ONE = "ONE"
TWO = "TWO"
def to_api_value(self) -> int:
_map = {
"ANY": 0,
"NONSTOP": 1,
"ONE": 2,
"TWO": 3,
}
return _map[self.value]
class GoogleFlightsSortBy(Enum):
TOP_FLIGHTS = "TOP_FLIGHTS"
PRICE = "PRICE"
DEPARTURE_TIME = "DEPARTURE_TIME"
ARRIVAL_TIME = "ARRIVAL_TIME"
DURATION = "DURATION"
EMISSIONS = "EMISSIONS"
def to_api_value(self) -> int:
_map = {
"TOP_FLIGHTS": 1,
"PRICE": 2,
"DEPARTURE_TIME": 3,
"ARRIVAL_TIME": 4,
"DURATION": 5,
"EMISSIONS": 6,
}
return _map[self.value]
class GoogleHotelsSortBy(Enum):
RELEVANCE = "RELEVANCE"
LOWEST_PRICE = "LOWEST_PRICE"
HIGHEST_RATING = "HIGHEST_RATING"
MOST_REVIEWED = "MOST_REVIEWED"
def to_api_value(self) -> int | None:
_map = {
"RELEVANCE": None,
"LOWEST_PRICE": 3,
"HIGHEST_RATING": 8,
"MOST_REVIEWED": 13,
}
return _map[self.value]

View file

@ -0,0 +1,13 @@
from arcade_search.tools.google_finance import get_stock_historical_data, get_stock_summary
from arcade_search.tools.google_flights import search_one_way_flights, search_roundtrip_flights
from arcade_search.tools.google_hotels import search_hotels
from arcade_search.tools.google_search import search_google
__all__ = [
"search_google", # Google Search
"get_stock_summary", # Google Finance
"get_stock_historical_data", # Google Finance
"search_one_way_flights", # Google Flights
"search_roundtrip_flights", # Google Flights
"search_hotels", # Google Hotels
]

View file

@ -0,0 +1,86 @@
from typing import Annotated, Any
from arcade.sdk import ToolContext, tool
from arcade_search.enums import GoogleFinanceWindow
from arcade_search.utils import call_serpapi, prepare_params
@tool(requires_secrets=["SERP_API_KEY"])
async def get_stock_summary(
context: ToolContext,
ticker_symbol: Annotated[
str,
"The stock ticker to get summary for. For example, 'GOOG' is the ticker symbol for Google",
],
exchange_identifier: Annotated[
str,
"The exchange identifier. This part indicates the market where the "
"stock is traded. For example, 'NASDAQ', 'NYSE', 'TSE', 'LSE', etc.",
],
) -> Annotated[dict[str, Any], "Summary of the stock's recent performance"]:
"""Retrieve the summary information for a given stock ticker using the Google Finance API.
Gets the stock's current price as well as price movement from the most recent trading day.
"""
# Prepare the request
query = (
f"{ticker_symbol.upper()}:{exchange_identifier.upper()}"
if exchange_identifier
else ticker_symbol.upper()
)
params = prepare_params("google_finance", q=query)
# Execute the request
results = call_serpapi(context, params)
# Parse the results
summary: dict = results.get("summary", {})
return summary
@tool(requires_secrets=["SERP_API_KEY"])
async def get_stock_historical_data(
context: ToolContext,
ticker_symbol: Annotated[
str,
"The stock ticker to get summary for. For example, 'GOOG' is the ticker symbol for Google",
],
exchange_identifier: Annotated[
str,
"The exchange identifier. This part indicates the market where the "
"stock is traded. For example, 'NASDAQ', 'NYSE', 'TSE', 'LSE', etc.",
],
window: Annotated[
GoogleFinanceWindow, "Time window for the graph data. Defaults to 1 month"
] = GoogleFinanceWindow.ONE_MONTH,
) -> Annotated[
dict[str, Any],
"A stock's price and volume data at a specific time interval over a specified time window",
]:
"""Fetch historical stock price data over a specified time window
Returns a stock's price and volume data over a specified time window
"""
# Prepare the request
query = (
f"{ticker_symbol.upper()}:{exchange_identifier.upper()}"
if exchange_identifier
else ticker_symbol.upper()
)
params = prepare_params("google_finance", q=query, window=window.value)
# Execute the request
results = call_serpapi(context, params)
# Parse the results
data = {
"summary": results.get("summary", {}),
"graph": results.get("graph", []),
}
key_events = results.get("key_events")
if key_events:
data["key_events"] = key_events
return data

View file

@ -0,0 +1,110 @@
from typing import Annotated, Any, Optional
from arcade.sdk import ToolContext, tool
from arcade_search.enums import GoogleFlightsMaxStops, GoogleFlightsSortBy, GoogleFlightsTravelClass
from arcade_search.utils import call_serpapi, parse_flight_results, prepare_params
@tool(requires_secrets=["SERP_API_KEY"])
async def search_roundtrip_flights(
context: ToolContext,
departure_airport_code: Annotated[
str, "The departure airport code. An uppercase 3-letter code"
],
arrival_airport_code: Annotated[str, "The arrival airport code. An uppercase 3-letter code"],
outbound_date: Annotated[str, "Flight outbound date in YYYY-MM-DD format"],
return_date: Annotated[Optional[str], "Flight return date in YYYY-MM-DD format"],
currency_code: Annotated[
Optional[str], "Currency of the returned prices. Defaults to 'USD'"
] = "USD",
travel_class: Annotated[
GoogleFlightsTravelClass,
"Travel class of the flight. Defaults to 'ECONOMY'",
] = GoogleFlightsTravelClass.ECONOMY,
num_adults: Annotated[Optional[int], "Number of adult passengers. Defaults to 1"] = 1,
num_children: Annotated[Optional[int], "Number of child passengers. Defaults to 0"] = 0,
max_stops: Annotated[
GoogleFlightsMaxStops,
"Maximum number of stops (layovers) for the flight. Defaults to any number of stops",
] = GoogleFlightsMaxStops.ANY,
sort_by: Annotated[
GoogleFlightsSortBy,
"The sorting order of the results. Defaults to TOP_FLIGHTS.",
] = GoogleFlightsSortBy.TOP_FLIGHTS,
) -> Annotated[dict[str, Any], "Flight search results from the Google Flights API"]:
"""Retrieve flight search results using Google Flights"""
# Prepare the request
params = prepare_params(
"google_flights",
departure_id=departure_airport_code,
arrival_id=arrival_airport_code,
outbound_date=outbound_date,
return_date=return_date,
currency=currency_code,
travel_class=travel_class.to_api_value(),
adults=num_adults,
children=num_children,
stops=max_stops.to_api_value(),
sort_by=sort_by.to_api_value(),
deep_search=True, # Same search depth of the Google Flights page in the browser
)
# Execute the request
results = call_serpapi(context, params)
# Parse the results
flights = parse_flight_results(results)
return flights
@tool(requires_secrets=["SERP_API_KEY"])
async def search_one_way_flights(
context: ToolContext,
departure_airport_code: Annotated[
str, "The departure airport code. An uppercase 3-letter code"
],
arrival_airport_code: Annotated[str, "The arrival airport code. An uppercase 3-letter code"],
outbound_date: Annotated[str, "Flight departure date in YYYY-MM-DD format"],
currency_code: Annotated[
Optional[str], "Currency of the returned prices. Defaults to 'USD'"
] = "USD",
travel_class: Annotated[
GoogleFlightsTravelClass,
"Travel class of the flight. Defaults to 'ECONOMY'",
] = GoogleFlightsTravelClass.ECONOMY,
num_adults: Annotated[Optional[int], "Number of adult passengers. Defaults to 1"] = 1,
num_children: Annotated[Optional[int], "Number of child passengers. Defaults to 0"] = 0,
max_stops: Annotated[
GoogleFlightsMaxStops,
"Maximum number of stops (layovers) for the flight. Defaults to any number of stops",
] = GoogleFlightsMaxStops.ANY,
sort_by: Annotated[
GoogleFlightsSortBy,
"The sorting order of the results. Defaults to TOP_FLIGHTS.",
] = GoogleFlightsSortBy.TOP_FLIGHTS,
) -> Annotated[dict[str, Any], "Flight search results from the Google Flights API"]:
"""Retrieve flight search results for a one-way flight using Google Flights"""
params = prepare_params(
"google_flights",
departure_id=departure_airport_code,
arrival_id=arrival_airport_code,
outbound_date=outbound_date,
currency=currency_code,
travel_class=travel_class.to_api_value(),
adults=num_adults,
children=num_children,
stops=max_stops.to_api_value(),
sort_by=sort_by.to_api_value(),
type=2, # indicates one-way
deep_search=True, # Same search depth as the Google Flights page in the browser
)
# Execute the request
results = call_serpapi(context, params)
# Parse the results
flights = parse_flight_results(results)
return flights

View file

@ -0,0 +1,58 @@
from typing import Annotated, Any, Optional
from arcade.sdk import ToolContext, tool
from arcade_search.enums import GoogleHotelsSortBy
from arcade_search.utils import call_serpapi, prepare_params
@tool(requires_secrets=["SERP_API_KEY"])
async def search_hotels(
context: ToolContext,
location: Annotated[str, "Location to search for hotels, e.g., a city name, a state, etc."],
check_in_date: Annotated[str, "Check-in date in YYYY-MM-DD format"],
check_out_date: Annotated[str, "Check-out date in YYYY-MM-DD format"],
query: Annotated[
Optional[str], "Anything that would be used in a regular Google Hotels search"
] = None,
currency: Annotated[Optional[str], "Currency code for prices. Defaults to 'USD'"] = "USD",
min_price: Annotated[Optional[int], "Minimum price per night. Defaults to no minimum"] = None,
max_price: Annotated[Optional[int], "Maximum price per night. Defaults to no maximum"] = None,
num_adults: Annotated[Optional[int], "Number of adults per room. Defaults to 2"] = 2,
num_children: Annotated[Optional[int], "Number of children per room. Defaults to 0"] = 0,
sort_by: Annotated[
GoogleHotelsSortBy, "The sorting order of the results. Defaults to RELEVANCE"
] = GoogleHotelsSortBy.RELEVANCE,
num_results: Annotated[
Optional[int], "Maximum number of results to return. Defaults to 5. Max 20"
] = 5,
) -> Annotated[dict[str, Any], "Hotel search results from the Google Hotels API"]:
"""Retrieve hotel search results using the Google Hotels API."""
# Prepare the request
params = prepare_params(
"google_hotels",
q=f"{query}, {location}" if query else location,
check_in_date=check_in_date,
check_out_date=check_out_date,
currency=currency,
min_price=min_price,
max_price=max_price,
adults=num_adults,
children=num_children,
sort_by=sort_by.to_api_value(),
)
# Execute the request
results = call_serpapi(context, params)
# Parse the results
properties = results.get("properties", [])[:num_results]
# Remove unwanted fields from each property
for hotel in properties:
hotel.pop("images", None)
hotel.pop("extracted_hotel_class", None)
hotel.pop("reviews_breakdown", None)
hotel.pop("serpapi_property_details_link", None)
return {"properties": properties}

View file

@ -1,9 +1,10 @@
import json
from typing import Annotated
import serpapi
from arcade.sdk import ToolContext, tool
from arcade_search.utils import call_serpapi, prepare_params
@tool(requires_secrets=["SERP_API_KEY"])
async def search_google(
@ -13,13 +14,8 @@ async def search_google(
) -> str:
"""Search Google using SerpAPI and return organic search results."""
api_key = context.get_secret("SERP_API_KEY")
client = serpapi.Client(api_key=api_key)
params = {"engine": "google", "q": query}
search = client.search(params)
results = search.as_dict()
params = prepare_params("google", q=query)
results = call_serpapi(context, params)
organic_results = results.get("organic_results", [])
return json.dumps(organic_results[:n_results])

View file

@ -0,0 +1,74 @@
import re
from typing import Any
import serpapi
from arcade.sdk import ToolContext
from arcade.sdk.errors import ToolExecutionError
# ------------------------------------------------------------------------------------------------
# General SerpAPI utils
# ------------------------------------------------------------------------------------------------
def prepare_params(engine: str, **kwargs: Any) -> dict[str, Any]:
"""
Prepares a parameters dictionary for the SerpAPI call.
Parameters:
engine: The engine name (e.g., "google", "google_finance").
kwargs: Any additional parameters to include.
Returns:
A dictionary containing the base parameters plus any extras,
excluding any parameters whose value is None.
"""
params = {"engine": engine}
params.update({k: v for k, v in kwargs.items() if v is not None})
return params
def call_serpapi(context: ToolContext, params: dict) -> dict:
"""
Execute a search query using the SerpAPI client and return the results as a dictionary.
Args:
context: The tool context containing required secrets.
params: A dictionary of parameters for the SerpAPI search.
Returns:
The search results as a dictionary.
"""
api_key = context.get_secret("SERP_API_KEY")
client = serpapi.Client(api_key=api_key)
try:
search = client.search(params)
return search.as_dict() # type: ignore[no-any-return]
except Exception as e:
# SerpAPI error messages sometimes contain the API key, so we need to sanitize it
sanitized_e = re.sub(r"(api_key=)[^ &]+", r"\1***", str(e))
raise ToolExecutionError(
message="Failed to fetch search results",
developer_message=sanitized_e,
)
# ------------------------------------------------------------------------------------------------
# Google Flights utils
# ------------------------------------------------------------------------------------------------
def parse_flight_results(results: dict[str, Any]) -> dict[str, Any]:
"""Parse the flight results from the Google Flights API
Note: Best flights is not always returned from the API.
"""
flight_data = {}
flights = []
if "best_flights" in results:
flights.extend(results["best_flights"])
if "other_flights" in results:
flights.extend(results["other_flights"])
if "price_insights" in results:
flight_data["price_insights"] = results["price_insights"]
flight_data["flights"] = flights
return flight_data

View file

@ -0,0 +1,13 @@
import pytest
class DummyContext:
def get_secret(self, key: str) -> str | None:
if key.lower() == "serp_api_key":
return "dummy_key"
return None
@pytest.fixture
def dummy_context():
return DummyContext()

View file

@ -9,7 +9,7 @@ from arcade.sdk.eval import (
)
import arcade_search
from arcade_search.tools.google import search_google
from arcade_search.tools import search_google
# Evaluation rubric
rubric = EvalRubric(

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "arcade_search"
version = "1.0.0"
version = "1.1.0"
description = "Tools for searching the web"
authors = ["Arcade <dev@arcade.dev>"]

View file

@ -5,7 +5,7 @@ import pytest
from arcade.core.schema import ToolSecretItem
from arcade.sdk import ToolContext
from arcade_search.tools.google import search_google
from arcade_search.tools import search_google
@pytest.fixture

View file

@ -0,0 +1,56 @@
import pytest
import serpapi
from arcade.sdk.errors import ToolExecutionError
from arcade_search.utils import call_serpapi, prepare_params
@pytest.mark.asyncio
@pytest.mark.parametrize(
"engine, kwargs, expected",
[
("google", {}, {"engine": "google"}),
(
"google",
{"q": "test", "window": 10, "time": "00:12:12"},
{
"engine": "google",
"q": "test",
"window": 10,
"time": "00:12:12",
},
),
],
)
async def test_prepare_params(engine, kwargs, expected):
params = prepare_params(engine, **kwargs)
assert params == expected
@pytest.mark.parametrize(
"error_message, sanitized_message",
[
(
"You hit your rate limit",
"You hit your rate limit",
),
(
"Bad Request for url: https://serpapi.com/search?engine=google_hotels&api_key=ABC123456",
"Bad Request for url: https://serpapi.com/search?engine=google_hotels&api_key=***",
),
(
"Bad Request for url: https://serpapi.com/search?engine=google_hotels&api_key=ABC123456 make sure the api key is correct",
"Bad Request for url: https://serpapi.com/search?engine=google_hotels&api_key=*** make sure the api key is correct",
),
],
)
def test_call_serpapi_failure(monkeypatch, dummy_context, error_message, sanitized_message):
def fake_serpapi_search(self, params: dict) -> dict:
raise Exception(error_message) # noqa: TRY002
monkeypatch.setattr(serpapi.Client, "search", fake_serpapi_search)
with pytest.raises(ToolExecutionError) as excinfo:
call_serpapi(dummy_context, {})
assert excinfo.value.developer_message == sanitized_message