diff --git a/toolkits/search/arcade_search/enums.py b/toolkits/search/arcade_search/enums.py
new file mode 100644
index 00000000..8e43f80e
--- /dev/null
+++ b/toolkits/search/arcade_search/enums.py
@@ -0,0 +1,98 @@
+from enum import Enum
+
+
+class GoogleFinanceWindow(Enum):
+    """Time spans that can be requested for Google Finance graph data."""
+
+    ONE_DAY = "1D"
+    FIVE_DAYS = "5D"
+    ONE_MONTH = "1M"
+    SIX_MONTHS = "6M"
+    YEAR_TO_DATE = "YTD"
+    ONE_YEAR = "1Y"
+    FIVE_YEARS = "5Y"
+    MAX = "MAX"
+
+
+class GoogleFlightsTravelClass(Enum):
+    """Cabin classes selectable in a Google Flights search."""
+
+    ECONOMY = "ECONOMY"
+    PREMIUM_ECONOMY = "PREMIUM_ECONOMY"
+    BUSINESS = "BUSINESS"
+    FIRST = "FIRST"
+
+    def to_api_value(self) -> int:
+        """Return the integer code that represents this member in API requests."""
+        cls = type(self)
+        codes = {
+            cls.ECONOMY: 1,
+            cls.PREMIUM_ECONOMY: 2,
+            cls.BUSINESS: 3,
+            cls.FIRST: 4,
+        }
+        return codes[self]
+
+
+class GoogleFlightsMaxStops(Enum):
+    """Layover limits selectable in a Google Flights search."""
+
+    ANY = "ANY"
+    NONSTOP = "NONSTOP"
+    ONE = "ONE"
+    TWO = "TWO"
+
+    def to_api_value(self) -> int:
+        """Return the integer code that represents this member in API requests."""
+        cls = type(self)
+        codes = {
+            cls.ANY: 0,
+            cls.NONSTOP: 1,
+            cls.ONE: 2,
+            cls.TWO: 3,
+        }
+        return codes[self]
+
+
+class GoogleFlightsSortBy(Enum):
+    """Result orderings selectable in a Google Flights search."""
+
+    TOP_FLIGHTS = "TOP_FLIGHTS"
+    PRICE = "PRICE"
+    DEPARTURE_TIME = "DEPARTURE_TIME"
+    ARRIVAL_TIME = "ARRIVAL_TIME"
+    DURATION = "DURATION"
+    EMISSIONS = "EMISSIONS"
+
+    def to_api_value(self) -> int:
+        """Return the integer code that represents this member in API requests."""
+        cls = type(self)
+        codes = {
+            cls.TOP_FLIGHTS: 1,
+            cls.PRICE: 2,
+            cls.DEPARTURE_TIME: 3,
+            cls.ARRIVAL_TIME: 4,
+            cls.DURATION: 5,
+            cls.EMISSIONS: 6,
+        }
+        return codes[self]
+
+
+class GoogleHotelsSortBy(Enum):
+    """Result orderings selectable in a Google Hotels search."""
+
+    RELEVANCE = "RELEVANCE"
+    LOWEST_PRICE = "LOWEST_PRICE"
+    HIGHEST_RATING = "HIGHEST_RATING"
+    MOST_REVIEWED = "MOST_REVIEWED"
+
+    def to_api_value(self) -> int | None:
+        """Return the integer sort code for this member (None for RELEVANCE)."""
+        cls = type(self)
+        codes = {
+            cls.RELEVANCE: None,
+            cls.LOWEST_PRICE: 3,
+            cls.HIGHEST_RATING: 8,
+            cls.MOST_REVIEWED: 13,
+        }
+        return codes[self]
diff --git a/toolkits/search/arcade_search/tools/__init__.py b/toolkits/search/arcade_search/tools/__init__.py
index e69de29b..c99428f1 100644
--- a/toolkits/search/arcade_search/tools/__init__.py
+++ b/toolkits/search/arcade_search/tools/__init__.py
@@ -0,0 +1,13 @@
+from arcade_search.tools.google_finance import
get_stock_historical_data, get_stock_summary
+from arcade_search.tools.google_flights import search_one_way_flights, search_roundtrip_flights
+from arcade_search.tools.google_hotels import search_hotels
+from arcade_search.tools.google_search import search_google
+
+__all__ = [
+    "search_google",  # Google Search
+    "get_stock_summary",  # Google Finance
+    "get_stock_historical_data",  # Google Finance
+    "search_one_way_flights",  # Google Flights
+    "search_roundtrip_flights",  # Google Flights
+    "search_hotels",  # Google Hotels
+]
diff --git a/toolkits/search/arcade_search/tools/google_finance.py b/toolkits/search/arcade_search/tools/google_finance.py
new file mode 100644
index 00000000..f2576fcf
--- /dev/null
+++ b/toolkits/search/arcade_search/tools/google_finance.py
@@ -0,0 +1,92 @@
+from typing import Annotated, Any
+
+from arcade.sdk import ToolContext, tool
+
+from arcade_search.enums import GoogleFinanceWindow
+from arcade_search.utils import call_serpapi, prepare_params
+
+
+def _build_query(ticker_symbol: str, exchange_identifier: str) -> str:
+    """Build the Google Finance `q` value: 'TICKER:EXCHANGE', or 'TICKER' without an exchange.
+
+    Both parts are stripped and upper-cased so that stray whitespace or lower-case input
+    cannot produce a malformed query such as 'goog: nasdaq'.
+    """
+    ticker = ticker_symbol.strip().upper()
+    # A falsy or whitespace-only exchange means "no exchange given"
+    exchange = exchange_identifier.strip().upper() if exchange_identifier else ""
+    return f"{ticker}:{exchange}" if exchange else ticker
+
+
+@tool(requires_secrets=["SERP_API_KEY"])
+async def get_stock_summary(
+    context: ToolContext,
+    ticker_symbol: Annotated[
+        str,
+        "The stock ticker to get summary for. For example, 'GOOG' is the ticker symbol for Google",
+    ],
+    exchange_identifier: Annotated[
+        str,
+        "The exchange identifier. This part indicates the market where the "
+        "stock is traded. For example, 'NASDAQ', 'NYSE', 'TSE', 'LSE', etc.",
+    ],
+) -> Annotated[dict[str, Any], "Summary of the stock's recent performance"]:
+    """Retrieve the summary information for a given stock ticker using the Google Finance API.
+
+    Gets the stock's current price as well as price movement from the most recent trading day.
+    """
+    # Prepare the request
+    query = _build_query(ticker_symbol, exchange_identifier)
+    params = prepare_params("google_finance", q=query)
+
+    # Execute the request
+    results = call_serpapi(context, params)
+
+    # Parse the results; an empty dict is returned when the response has no "summary" section
+    summary: dict[str, Any] = results.get("summary", {})
+
+    return summary
+
+
+@tool(requires_secrets=["SERP_API_KEY"])
+async def get_stock_historical_data(
+    context: ToolContext,
+    ticker_symbol: Annotated[
+        str,
+        "The stock ticker to get historical data for. "
+        "For example, 'GOOG' is the ticker symbol for Google",
+    ],
+    exchange_identifier: Annotated[
+        str,
+        "The exchange identifier. This part indicates the market where the "
+        "stock is traded. For example, 'NASDAQ', 'NYSE', 'TSE', 'LSE', etc.",
+    ],
+    window: Annotated[
+        GoogleFinanceWindow, "Time window for the graph data. Defaults to 1 month"
+    ] = GoogleFinanceWindow.ONE_MONTH,
+) -> Annotated[
+    dict[str, Any],
+    "A stock's price and volume data at a specific time interval over a specified time window",
+]:
+    """Fetch historical stock price data over a specified time window
+
+    Returns a stock's price and volume data over a specified time window
+    """
+    # Prepare the request
+    query = _build_query(ticker_symbol, exchange_identifier)
+    params = prepare_params("google_finance", q=query, window=window.value)
+
+    # Execute the request
+    results = call_serpapi(context, params)
+
+    # Parse the results
+    data = {
+        "summary": results.get("summary", {}),
+        "graph": results.get("graph", []),
+    }
+    # "key_events" is only included when the response carries a non-empty value for it
+    key_events = results.get("key_events")
+    if key_events:
+        data["key_events"] = key_events
+
+    return data
diff --git a/toolkits/search/arcade_search/tools/google_flights.py b/toolkits/search/arcade_search/tools/google_flights.py
new file mode 100644
index 00000000..e989746a
--- /dev/null
+++ b/toolkits/search/arcade_search/tools/google_flights.py
@@ -0,0 +1,110 @@
+from typing import Annotated, Any, Optional
+
+from arcade.sdk import ToolContext, tool
+
+from arcade_search.enums import GoogleFlightsMaxStops, GoogleFlightsSortBy, GoogleFlightsTravelClass
+from arcade_search.utils import call_serpapi, parse_flight_results, prepare_params
+
+
+@tool(requires_secrets=["SERP_API_KEY"])
+async def search_roundtrip_flights(
+    context: ToolContext,
+    departure_airport_code: Annotated[
+        str, "The departure airport code. An uppercase 3-letter code"
+    ],
+    arrival_airport_code: Annotated[str, "The arrival airport code. An uppercase 3-letter code"],
+    outbound_date: Annotated[str, "Flight outbound date in YYYY-MM-DD format"],
+    return_date: Annotated[Optional[str], "Flight return date in YYYY-MM-DD format"],
+    currency_code: Annotated[
+        Optional[str], "Currency of the returned prices. Defaults to 'USD'"
+    ] = "USD",
+    travel_class: Annotated[
+        GoogleFlightsTravelClass,
+        "Travel class of the flight. Defaults to 'ECONOMY'",
+    ] = GoogleFlightsTravelClass.ECONOMY,
+    num_adults: Annotated[Optional[int], "Number of adult passengers. Defaults to 1"] = 1,
+    num_children: Annotated[Optional[int], "Number of child passengers. Defaults to 0"] = 0,
+    max_stops: Annotated[
+        GoogleFlightsMaxStops,
+        "Maximum number of stops (layovers) for the flight. Defaults to any number of stops",
+    ] = GoogleFlightsMaxStops.ANY,
+    sort_by: Annotated[
+        GoogleFlightsSortBy,
+        "The sorting order of the results. Defaults to TOP_FLIGHTS.",
+    ] = GoogleFlightsSortBy.TOP_FLIGHTS,
+) -> Annotated[dict[str, Any], "Flight search results from the Google Flights API"]:
+    """Retrieve flight search results for a round-trip flight using Google Flights"""
+    # Prepare the request
+    params = prepare_params(
+        "google_flights",
+        departure_id=departure_airport_code,
+        arrival_id=arrival_airport_code,
+        outbound_date=outbound_date,
+        # NOTE(review): prepare_params drops a None return_date; confirm the API accepts that
+        return_date=return_date,
+        currency=currency_code,
+        travel_class=travel_class.to_api_value(),
+        adults=num_adults,
+        children=num_children,
+        stops=max_stops.to_api_value(),
+        sort_by=sort_by.to_api_value(),
+        # type=1 selects a round trip; passed explicitly to mirror search_one_way_flights
+        type=1,
+        deep_search=True,  # Same search depth as the Google Flights page in the browser
+    )
+
+    # Execute the request
+    results = call_serpapi(context, params)
+
+    # Parse the results
+    return parse_flight_results(results)
+
+
+@tool(requires_secrets=["SERP_API_KEY"])
+async def search_one_way_flights(
+    context: ToolContext,
+    departure_airport_code: Annotated[
+        str, "The departure airport code. An uppercase 3-letter code"
+    ],
+    arrival_airport_code: Annotated[str, "The arrival airport code. An uppercase 3-letter code"],
+    outbound_date: Annotated[str, "Flight departure date in YYYY-MM-DD format"],
+    currency_code: Annotated[
+        Optional[str], "Currency of the returned prices. Defaults to 'USD'"
+    ] = "USD",
+    travel_class: Annotated[
+        GoogleFlightsTravelClass,
+        "Travel class of the flight. Defaults to 'ECONOMY'",
+    ] = GoogleFlightsTravelClass.ECONOMY,
+    num_adults: Annotated[Optional[int], "Number of adult passengers. Defaults to 1"] = 1,
+    num_children: Annotated[Optional[int], "Number of child passengers. Defaults to 0"] = 0,
+    max_stops: Annotated[
+        GoogleFlightsMaxStops,
+        "Maximum number of stops (layovers) for the flight. Defaults to any number of stops",
+    ] = GoogleFlightsMaxStops.ANY,
+    sort_by: Annotated[
+        GoogleFlightsSortBy,
+        "The sorting order of the results. Defaults to TOP_FLIGHTS.",
+    ] = GoogleFlightsSortBy.TOP_FLIGHTS,
+) -> Annotated[dict[str, Any], "Flight search results from the Google Flights API"]:
+    """Retrieve flight search results for a one-way flight using Google Flights"""
+    # Prepare the request
+    params = prepare_params(
+        "google_flights",
+        departure_id=departure_airport_code,
+        arrival_id=arrival_airport_code,
+        outbound_date=outbound_date,
+        currency=currency_code,
+        travel_class=travel_class.to_api_value(),
+        adults=num_adults,
+        children=num_children,
+        stops=max_stops.to_api_value(),
+        sort_by=sort_by.to_api_value(),
+        type=2,  # indicates one-way
+        deep_search=True,  # Same search depth as the Google Flights page in the browser
+    )
+
+    # Execute the request
+    results = call_serpapi(context, params)
+
+    # Parse the results
+    return parse_flight_results(results)
diff --git a/toolkits/search/arcade_search/tools/google_hotels.py b/toolkits/search/arcade_search/tools/google_hotels.py
new file mode 100644
index 00000000..f943c8af
--- /dev/null
+++ b/toolkits/search/arcade_search/tools/google_hotels.py
@@ -0,0 +1,69 @@
+from typing import Annotated, Any, Optional
+
+from arcade.sdk import ToolContext, tool
+
+from arcade_search.enums import GoogleHotelsSortBy
+from arcade_search.utils import call_serpapi, prepare_params
+
+# Bounds advertised in the `num_results` parameter description of `search_hotels`
+_DEFAULT_NUM_RESULTS = 5
+_MAX_NUM_RESULTS = 20
+
+
+@tool(requires_secrets=["SERP_API_KEY"])
+async def search_hotels(
+    context: ToolContext,
+    location: Annotated[str, "Location to search for hotels, e.g., a city name, a state, etc."],
+    check_in_date: Annotated[str, "Check-in date in YYYY-MM-DD format"],
+    check_out_date: Annotated[str, "Check-out date in YYYY-MM-DD format"],
+    query: Annotated[
+        Optional[str], "Anything that would be used in a regular Google Hotels search"
+    ] = None,
+    currency: Annotated[Optional[str], "Currency code for prices. Defaults to 'USD'"] = "USD",
+    min_price: Annotated[Optional[int], "Minimum price per night. Defaults to no minimum"] = None,
+    max_price: Annotated[Optional[int], "Maximum price per night. Defaults to no maximum"] = None,
+    num_adults: Annotated[Optional[int], "Number of adults per room. Defaults to 2"] = 2,
+    num_children: Annotated[Optional[int], "Number of children per room. Defaults to 0"] = 0,
+    sort_by: Annotated[
+        GoogleHotelsSortBy, "The sorting order of the results. Defaults to RELEVANCE"
+    ] = GoogleHotelsSortBy.RELEVANCE,
+    num_results: Annotated[
+        Optional[int], "Maximum number of results to return. Defaults to 5. Max 20"
+    ] = 5,
+) -> Annotated[dict[str, Any], "Hotel search results from the Google Hotels API"]:
+    """Retrieve hotel search results using the Google Hotels API.
+
+    Returns a dict whose "properties" list holds at most `num_results` hotels (capped at 20),
+    each with its images, hotel-class, reviews-breakdown and SerpAPI details-link fields removed.
+    """
+    # Prepare the request
+    params = prepare_params(
+        "google_hotels",
+        q=f"{query}, {location}" if query else location,
+        check_in_date=check_in_date,
+        check_out_date=check_out_date,
+        currency=currency,
+        min_price=min_price,
+        max_price=max_price,
+        adults=num_adults,
+        children=num_children,
+        sort_by=sort_by.to_api_value(),
+    )
+
+    # Execute the request
+    results = call_serpapi(context, params)
+
+    # Parse the results. Clamp the limit to the documented range: None falls back to the
+    # default, and a negative value must not become a negative slice index (which would
+    # drop hotels from the end of the list rather than limit the count).
+    limit = _DEFAULT_NUM_RESULTS if num_results is None else max(0, min(num_results, _MAX_NUM_RESULTS))
+    properties = results.get("properties", [])[:limit]
+
+    # Remove unwanted fields from each property
+    for hotel in properties:
+        hotel.pop("images", None)
+        hotel.pop("extracted_hotel_class", None)
+        hotel.pop("reviews_breakdown", None)
+        hotel.pop("serpapi_property_details_link", None)
+
+    return {"properties": properties}
diff --git a/toolkits/search/arcade_search/tools/google.py b/toolkits/search/arcade_search/tools/google_search.py
similarity index 68%
rename from toolkits/search/arcade_search/tools/google.py
rename to toolkits/search/arcade_search/tools/google_search.py
index 6ff35207..395355cb 100644
--- a/toolkits/search/arcade_search/tools/google.py
+++ b/toolkits/search/arcade_search/tools/google_search.py
@@ -1,9 +1,10 @@
 import json
 from typing import Annotated
 
-import serpapi
 from arcade.sdk import ToolContext, tool
 
+from
arcade_search.utils import call_serpapi, prepare_params
+
 
 @tool(requires_secrets=["SERP_API_KEY"])
 async def search_google(
@@ -13,13 +14,8 @@ async def search_google(
 ) -> str:
     """Search Google using SerpAPI and return organic search results."""
 
-    api_key = context.get_secret("SERP_API_KEY")
-
-    client = serpapi.Client(api_key=api_key)
-    params = {"engine": "google", "q": query}
-
-    search = client.search(params)
-    results = search.as_dict()
+    params = prepare_params("google", q=query)
+    results = call_serpapi(context, params)
     organic_results = results.get("organic_results", [])
 
     return json.dumps(organic_results[:n_results])
diff --git a/toolkits/search/arcade_search/utils.py b/toolkits/search/arcade_search/utils.py
new file mode 100644
index 00000000..7a559346
--- /dev/null
+++ b/toolkits/search/arcade_search/utils.py
@@ -0,0 +1,86 @@
+import re
+from typing import Any
+
+import serpapi
+from arcade.sdk import ToolContext
+from arcade.sdk.errors import ToolExecutionError
+
+# Matches the value of an `api_key=` query parameter so it can be masked in error messages
+_API_KEY_PATTERN = re.compile(r"(api_key=)[^ &]+")
+
+
+# ------------------------------------------------------------------------------------------------
+# General SerpAPI utils
+# ------------------------------------------------------------------------------------------------
+def prepare_params(engine: str, **kwargs: Any) -> dict[str, Any]:
+    """
+    Prepares a parameters dictionary for the SerpAPI call.
+
+    Parameters:
+        engine: The engine name (e.g., "google", "google_finance").
+        kwargs: Any additional parameters to include.
+
+    Returns:
+        A dictionary containing the base parameters plus any extras,
+        excluding any parameters whose value is None.
+    """
+    params = {"engine": engine}
+    params.update({k: v for k, v in kwargs.items() if v is not None})
+    return params
+
+
+def call_serpapi(context: ToolContext, params: dict) -> dict:
+    """
+    Execute a search query using the SerpAPI client and return the results as a dictionary.
+
+    Args:
+        context: The tool context containing required secrets.
+        params: A dictionary of parameters for the SerpAPI search.
+
+    Returns:
+        The search results as a dictionary.
+
+    Raises:
+        ToolExecutionError: If the request fails for any reason. The developer message holds
+            the original error text with any `api_key=...` value masked.
+    """
+    api_key = context.get_secret("SERP_API_KEY")
+    client = serpapi.Client(api_key=api_key)
+    try:
+        search = client.search(params)
+        return search.as_dict()  # type: ignore[no-any-return]
+    except Exception as e:
+        # SerpAPI error messages sometimes contain the API key, so we need to sanitize it
+        sanitized_e = _API_KEY_PATTERN.sub(r"\1***", str(e))
+        # `from None` keeps the unsanitized original exception out of the reported traceback;
+        # leaving it chained (implicitly or via `from e`) could print the key masked above.
+        raise ToolExecutionError(
+            message="Failed to fetch search results",
+            developer_message=sanitized_e,
+        ) from None
+
+
+# ------------------------------------------------------------------------------------------------
+# Google Flights utils
+# ------------------------------------------------------------------------------------------------
+def parse_flight_results(results: dict[str, Any]) -> dict[str, Any]:
+    """Parse the flight results from the Google Flights API
+
+    "best_flights" and "other_flights" are concatenated, in that order, under "flights";
+    "price_insights" is copied through when present.
+
+    Note: Best flights is not always returned from the API.
+    """
+    flight_data = {}
+    flights = []
+
+    if "best_flights" in results:
+        flights.extend(results["best_flights"])
+    if "other_flights" in results:
+        flights.extend(results["other_flights"])
+    if "price_insights" in results:
+        flight_data["price_insights"] = results["price_insights"]
+
+    flight_data["flights"] = flights
+
+    return flight_data
diff --git a/toolkits/search/conftest.py b/toolkits/search/conftest.py
new file mode 100644
index 00000000..2903cc30
--- /dev/null
+++ b/toolkits/search/conftest.py
@@ -0,0 +1,17 @@
+import pytest
+
+
+class DummyContext:
+    """Minimal stand-in for a tool context that only serves the SerpAPI secret."""
+
+    def get_secret(self, key: str) -> str | None:
+        """Return "dummy_key" for SERP_API_KEY (matched case-insensitively), else None."""
+        if key.lower() == "serp_api_key":
+            return "dummy_key"
+        return None
+
+
+@pytest.fixture
+def dummy_context():
+    """Provide a fresh DummyContext to each test that requests it."""
+    return DummyContext()
diff --git a/toolkits/search/evals/eval_google_search.py b/toolkits/search/evals/eval_google_search.py
index 0e95799e..2ccc3750 100644
--- a/toolkits/search/evals/eval_google_search.py
+++ b/toolkits/search/evals/eval_google_search.py
@@ -9,7
+9,7 @@ from arcade.sdk.eval import (
 )
 
 import arcade_search
-from arcade_search.tools.google import search_google
+from arcade_search.tools import search_google
 
 # Evaluation rubric
 rubric = EvalRubric(
diff --git a/toolkits/search/pyproject.toml b/toolkits/search/pyproject.toml
index 0503c213..2aa50cc1 100644
--- a/toolkits/search/pyproject.toml
+++ b/toolkits/search/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "arcade_search"
-version = "1.0.0"
+version = "1.1.0"
 description = "Tools for searching the web"
 authors = ["Arcade "]
 
diff --git a/toolkits/search/tests/test_google.py b/toolkits/search/tests/test_google_search.py
similarity index 96%
rename from toolkits/search/tests/test_google.py
rename to toolkits/search/tests/test_google_search.py
index 619be7dd..95ce94a7 100644
--- a/toolkits/search/tests/test_google.py
+++ b/toolkits/search/tests/test_google_search.py
@@ -5,7 +5,7 @@ import pytest
 from arcade.core.schema import ToolSecretItem
 from arcade.sdk import ToolContext
 
-from arcade_search.tools.google import search_google
+from arcade_search.tools import search_google
 
 
 @pytest.fixture
diff --git a/toolkits/search/tests/test_utils.py b/toolkits/search/tests/test_utils.py
new file mode 100644
index 00000000..8ffb68a6
--- /dev/null
+++ b/toolkits/search/tests/test_utils.py
@@ -0,0 +1,69 @@
+import pytest
+import serpapi
+from arcade.sdk.errors import ToolExecutionError
+
+from arcade_search.utils import call_serpapi, prepare_params
+
+
+@pytest.mark.parametrize(
+    "engine, kwargs, expected",
+    [
+        ("google", {}, {"engine": "google"}),
+        (
+            "google",
+            {"q": "test", "window": 10, "time": "00:12:12"},
+            {
+                "engine": "google",
+                "q": "test",
+                "window": 10,
+                "time": "00:12:12",
+            },
+        ),
+        # None values are dropped, while falsy-but-meaningful values such as 0 are kept
+        (
+            "google_hotels",
+            {"q": "test", "min_price": None, "children": 0},
+            {"engine": "google_hotels", "q": "test", "children": 0},
+        ),
+    ],
+)
+def test_prepare_params(engine, kwargs, expected):
+    """prepare_params adds the engine and keeps every kwarg whose value is not None."""
+    params = prepare_params(engine, **kwargs)
+    assert params == expected
+
+
+@pytest.mark.parametrize(
+    "error_message, sanitized_message",
+    [
+        (
+            "You hit your rate limit",
+            "You hit your rate limit",
+        ),
+        (
+            "Bad Request for url: https://serpapi.com/search?engine=google_hotels&api_key=ABC123456",
+            "Bad Request for url: https://serpapi.com/search?engine=google_hotels&api_key=***",
+        ),
+        (
+            "Bad Request for url: https://serpapi.com/search?engine=google_hotels&api_key=ABC123456 make sure the api key is correct",
+            "Bad Request for url: https://serpapi.com/search?engine=google_hotels&api_key=*** make sure the api key is correct",
+        ),
+        # The key is also masked when it is not the last query parameter
+        (
+            "Bad Request for url: https://serpapi.com/search?api_key=ABC123456&engine=google",
+            "Bad Request for url: https://serpapi.com/search?api_key=***&engine=google",
+        ),
+    ],
+)
+def test_call_serpapi_failure(monkeypatch, dummy_context, error_message, sanitized_message):
+    """call_serpapi wraps a client failure in ToolExecutionError with the API key masked."""
+
+    def fake_serpapi_search(self, params: dict) -> dict:
+        raise Exception(error_message)  # noqa: TRY002
+
+    monkeypatch.setattr(serpapi.Client, "search", fake_serpapi_search)
+
+    with pytest.raises(ToolExecutionError) as excinfo:
+        call_serpapi(dummy_context, {})
+
+    assert excinfo.value.developer_message == sanitized_message