diff --git a/contrib/langchain/langchain_arcade/_utilities.py b/contrib/langchain/langchain_arcade/_utilities.py index 48c340f0..3c636300 100644 --- a/contrib/langchain/langchain_arcade/_utilities.py +++ b/contrib/langchain/langchain_arcade/_utilities.py @@ -1,7 +1,7 @@ from typing import Any, Callable from arcadepy import NOT_GIVEN, Arcade -from arcadepy.types.shared import ToolDefinition +from arcadepy.types import ToolGetResponse as ToolDefinition from langchain_core.runnables import RunnableConfig from langchain_core.tools import StructuredTool from pydantic import BaseModel, Field, create_model @@ -50,7 +50,7 @@ def tool_definition_to_pydantic_model(tool_def: ToolDefinition) -> type[BaseMode """ try: fields: dict[str, Any] = {} - for param in tool_def.inputs.parameters or []: + for param in tool_def.input.parameters or []: param_type = get_python_type(param.value_schema.val_type) if param_type == list and param.value_schema.inner_val_type: # noqa: E721 inner_type: type[Any] = get_python_type(param.value_schema.inner_val_type) @@ -116,10 +116,7 @@ def create_tool_function( # Authorize the user for the tool auth_response = client.tools.authorize(tool_name=tool_name, user_id=user_id) if auth_response.status != "completed": - auth_message = ( - "Please use the following link to authorize: " - f"{auth_response.authorization_url}" - ) + auth_message = f"Please use the following link to authorize: {auth_response.url}" if langgraph: raise NodeInterrupt(auth_message) return {"error": auth_message} @@ -127,7 +124,7 @@ def create_tool_function( # Execute the tool with provided inputs execute_response = client.tools.execute( tool_name=tool_name, - inputs=kwargs, + input=kwargs, user_id=user_id if user_id is not None else NOT_GIVEN, ) diff --git a/contrib/langchain/langchain_arcade/manager.py b/contrib/langchain/langchain_arcade/manager.py index 7882368e..322b4431 100644 --- a/contrib/langchain/langchain_arcade/manager.py +++ b/contrib/langchain/langchain_arcade/manager.py @@ -3,7 +3,8 @@ from collections.abc import Iterator from typing import Any, Optional from arcadepy import Arcade -from arcadepy.types.shared import AuthorizationResponse, ToolDefinition +from arcadepy.types import ToolGetResponse as ToolDefinition +from arcadepy.types.shared import AuthAuthorizationResponse as AuthorizationResponse from langchain_core.tools import StructuredTool from langchain_arcade._utilities import ( @@ -40,10 +41,16 @@ class ArcadeToolManager: Args: client: Optional Arcade client instance. + **kwargs: Additional keyword arguments to pass to the Arcade client. """ if not client: - api_key = kwargs.get("api_key", os.getenv("ARCADE_API_KEY", None)) - client = Arcade(api_key=api_key) # type: ignore[arg-type] + api_key = kwargs.get("api_key", os.getenv("ARCADE_API_KEY")) + base_url = kwargs.get("base_url", os.getenv("ARCADE_BASE_URL")) + arcade_kwargs = {"api_key": api_key, **kwargs} + if base_url: + arcade_kwargs["base_url"] = base_url + + client = Arcade(**arcade_kwargs) # type: ignore[arg-type] self.client = client self._tools: dict[str, ToolDefinition] = {} @@ -143,7 +150,20 @@ class ArcadeToolManager: >>> manager.init_tools(toolkits=["Search"]) >>> manager.is_authorized("auth_123") """ - return self.client.auth.status(authorization_id=authorization_id).status == "completed" + return self.client.auth.status(id=authorization_id).status == "completed" + + def wait_for_auth(self, authorization_id: str) -> AuthorizationResponse: + """Wait for a tool authorization to complete. + + Example: + >>> manager = ArcadeToolManager(api_key="...") + >>> manager.init_tools(toolkits=["Google.ListEmails"]) + >>> response = manager.authorize("Google.ListEmails", "user_123") + >>> manager.wait_for_auth(response) + >>> # or + >>> manager.wait_for_auth(response.id) + """ + return self.client.auth.wait_for_completion(authorization_id) def requires_auth(self, tool_name: str) -> bool: """Check if a tool requires authorization.""" @@ -162,22 +182,33 @@ class ArcadeToolManager: def _retrieve_tool_definitions( self, tools: Optional[list[str]] = None, toolkits: Optional[list[str]] = None ) -> dict[str, ToolDefinition]: + """Retrieve tool definitions from the Arcade client, accounting for pagination.""" all_tools: list[ToolDefinition] = [] - if tools is not None or toolkits is not None: - if tools: - single_tools = [self.client.tools.get(tool_id=tool_id) for tool_id in tools] - all_tools.extend(single_tools) - if toolkits: - for tk in toolkits: - all_tools.extend(self.client.tools.list(toolkit=tk)) - else: - # retrieve all tools - page_iterator = self.client.tools.list() - all_tools.extend(page_iterator) + # First, gather single tools if the user specifically requested them. + if tools: + for tool_id in tools: + # ToolsResource.get(...) returns a single ToolGetResponse. + single_tool = self.client.tools.get(name=tool_id) + all_tools.append(single_tool) + + # Next, gather tool definitions from any requested toolkits. + if toolkits: + for tk in toolkits: + # tools.list(...) returns a paginated response (SyncOffsetPage), + # so we iterate over its items to accumulate tool definitions. + paginated_tools = self.client.tools.list(toolkit=tk) + all_tools.extend(paginated_tools.items) # type: ignore[arg-type] + + # If no specific tools or toolkits were requested, retrieve *all* tools. + if not tools and not toolkits: + paginated_all_tools = self.client.tools.list() + all_tools.extend(paginated_all_tools.items) # type: ignore[arg-type] + # Build a dictionary that maps the "full_tool_name" to the tool definition. tool_definitions: dict[str, ToolDefinition] = {} - for tool in all_tools: + # For items returned by .list(), the 'toolkit' and 'name' attributes + # should be present as plain fields on the object. (No need to do toolkit.name) full_tool_name = f"{tool.toolkit.name}_{tool.name}" tool_definitions[full_tool_name] = tool diff --git a/contrib/langchain/pyproject.toml b/contrib/langchain/pyproject.toml index 8891579b..d5a7a566 100644 --- a/contrib/langchain/pyproject.toml +++ b/contrib/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-arcade" -version = "0.1.2" +version = "0.2.0" description = "An integration package connecting Arcade AI and LangChain/LangGraph" authors = ["Arcade AI "] readme = "README.md" @@ -10,7 +10,7 @@ license = "MIT" [tool.poetry.dependencies] python = ">=3.10,<3.13" langchain-core = "^0.3.0" -arcadepy = "~0.2.0" +arcadepy = "~1.0.0rc1" langgraph = {version = ">=0.2.32,<0.3.0", optional = true} [tool.poetry.extras] diff --git a/contrib/langchain/tests/conftest.py b/contrib/langchain/tests/conftest.py new file mode 100644 index 00000000..43d0fdb6 --- /dev/null +++ b/contrib/langchain/tests/conftest.py @@ -0,0 +1,33 @@ +import os + +import pytest +from arcadepy import Arcade + + +@pytest.fixture(scope="session") +def arcade_base_url(): + """ + Retrieve the ARCADE_BASE_URL from the environment, falling back to a default + if not found. + """ + return os.getenv("ARCADE_BASE_URL", "http://localhost:9099") + + +@pytest.fixture(scope="session") +def arcade_api_key(): + """ + Retrieve the ARCADE_API_KEY from the environment, falling back to a default + if not found. + """ + return os.getenv("ARCADE_API_KEY", "test_api_key") + + +@pytest.fixture(scope="session") +def arcade_client(arcade_base_url, arcade_api_key): + """ + Creates a single Arcade client instance for use in all tests. + Any method calls on this client can be patched/mocked within the tests. + """ + client = Arcade(api_key=arcade_api_key, base_url=arcade_base_url) + yield client + # Teardown logic would go here if necessary diff --git a/contrib/langchain/tests/test_manager.py b/contrib/langchain/tests/test_manager.py new file mode 100644 index 00000000..a32a9113 --- /dev/null +++ b/contrib/langchain/tests/test_manager.py @@ -0,0 +1,311 @@ +from unittest.mock import MagicMock + +import pytest +from arcadepy.pagination import SyncOffsetPage +from arcadepy.types import ToolGetResponse as ToolDefinition +from arcadepy.types.shared import AuthAuthorizationResponse as AuthorizationResponse +from langchain_arcade.manager import ArcadeToolManager + + +@pytest.fixture +def mock_arcade_client(): + """ + A fixture to mock the Arcade client object for testing the ArcadeToolManager. + + This mocks all relevant methods used by the manager, including: + - tools.get + - tools.list + - tools.authorize + - auth.status + """ + mock_client = MagicMock() + # Mock the "tools" sub-client + mock_client.tools.get = MagicMock() + mock_client.tools.list = MagicMock() + mock_client.tools.authorize = MagicMock() + # Mock the "auth" sub-client + mock_client.auth.status = MagicMock() + + return mock_client + + +@pytest.fixture +def manager(mock_arcade_client): + """ + A fixture that creates an ArcadeToolManager with the mocked Arcade client. + """ + return ArcadeToolManager(client=mock_arcade_client) + + +@pytest.fixture +def make_tool(): + """ + A factory fixture for creating a valid ToolDefinition with a given + fully qualified name. Because the underlying ToolDefinition model + expects "toolkit" to be a dictionary with at least one field (for example "slug"), + and "requirements.authorization" to be a valid dictionary if present, we set them up + accordingly. + """ + + def _make_tool(fully_qualified_name="Search_SearchGoogle", **kwargs): + # Split on the first dot to derive a 'toolkit' slug and a tool 'name' + if "." in fully_qualified_name: + raw_toolkit, raw_tool_name = fully_qualified_name.split(".", 1) + elif "_" in fully_qualified_name: + # Convert from "_" to "." to match the expected format of tool name when + # using Langchain models for LLM inference. + raw_toolkit, raw_tool_name = fully_qualified_name.split("_", 1) + + else: + raw_toolkit, raw_tool_name = fully_qualified_name, fully_qualified_name + + # Provide a default toolkit dict unless one already exists in kwargs + toolkit = kwargs.pop("toolkit", {"name": raw_toolkit}) + + # Provide a default input + # arcadepy.types.ToolGetResponse expects "input" to be a valid structure (dict). + tool_input = kwargs.pop("input", {"parameters": []}) + + # Convert MagicMock-based requirements (with authorization) to an appropriate dict, + # or use what's passed. If none is passed, default to None. + requirements = kwargs.pop("requirements", None) + if requirements is not None and not isinstance(requirements, dict): + # If it's e.g. a MagicMock(authorization="xyz"), convert it to a dict + req_auth = getattr(requirements, "authorization", None) + # If the test expects an authorization presence, represent it as a dict + # that Pydantic can parse + if req_auth is not None: + requirements = {"authorization": {"type": req_auth}} + else: + requirements = {"authorization": None} + + # Provide a default description if none is supplied + description = kwargs.pop("description", "Mock tool for testing") + + # Build the pydantic fields + data = { + "fully_qualified_name": fully_qualified_name, + "name": raw_tool_name, + "toolkit": toolkit, + "input": tool_input, + "description": description, + "requirements": requirements, + } + data.update(kwargs) # merge any extras + + return ToolDefinition(**data) + + return _make_tool + + +def test_init_tools(manager, mock_arcade_client, make_tool): + """ + Test that init_tools clears any existing tools and retrieves new ones + from either an explicit list of tools or an entire toolkit. + """ + # Arrange + mock_tool = make_tool("Search_SearchGoogle") + mock_arcade_client.tools.get.return_value = mock_tool + mock_arcade_client.tools.list.return_value = SyncOffsetPage(items=[mock_tool]) + # Act + manager.init_tools(tools=["Search_SearchGoogle"]) + + # Assert + assert "Search_SearchGoogle" in manager.tools + assert manager._tools["Search_SearchGoogle"] == mock_tool + mock_arcade_client.tools.get.assert_called_once_with(name="Search_SearchGoogle") + + +def test_get_tools_no_init(manager, mock_arcade_client, make_tool): + """ + If get_tools is called without init_tools and no tools are specified, + it should call init_tools internally and fetch all available tools. + """ + # Arrange + mock_tool = make_tool("Search_SearchGoogle") + mock_arcade_client.tools.list.return_value = SyncOffsetPage(items=[mock_tool]) + + # Act + tools = manager.get_tools() # no param means manager auto-inits + + # Assert + assert len(tools) == 1 + assert "Search_SearchGoogle" in manager.tools + assert manager._tools["Search_SearchGoogle"] == mock_tool + mock_arcade_client.tools.list.assert_called_once() + + +def test_get_tools_with_explicit(manager, mock_arcade_client, make_tool): + """ + If tools or toolkits are provided to get_tools, the manager should + retrieve or update the internal _tools dictionary accordingly, + then return them as StructuredTool objects. + """ + # Arrange + mock_tool_google = make_tool("Search_SearchGoogle") + mock_tool_bing = make_tool("Search_SearchBing") + mock_arcade_client.tools.get.side_effect = [mock_tool_google, mock_tool_bing] + + # Act + retrieved_tools = manager.get_tools(tools=["Search_SearchGoogle", "Search_SearchBing"]) + + # Assert + assert len(retrieved_tools) == 2 + assert set(manager.tools) == {"Search_SearchGoogle", "Search_SearchBing"} + mock_arcade_client.tools.get.assert_any_call(name="Search_SearchGoogle") + mock_arcade_client.tools.get.assert_any_call(name="Search_SearchBing") + + +def test_authorize(manager, mock_arcade_client): + """ + Test the authorize method to ensure it calls the Arcade client's + tools.authorize method correctly. + """ + # Arrange + mock_arcade_client.tools.authorize.return_value = AuthorizationResponse( + id="auth_123", status="pending", tool_fully_qualified_name="Search_SearchGoogle" + ) + + # Act + response = manager.authorize(tool_name="Search_SearchGoogle", user_id="user_123") + + # Assert + assert response.id == "auth_123" + assert response.status == "pending" + mock_arcade_client.tools.authorize.assert_called_once_with( + tool_name="Search_SearchGoogle", user_id="user_123" + ) + + +def test_is_authorized(manager, mock_arcade_client): + """ + Test the is_authorized method which checks if authorization + has completed for a given authorization ID. + """ + # Arrange + mock_arcade_client.auth.status.return_value = MagicMock(status="completed") + + # Act + status_result = manager.is_authorized("auth_abc") + + # Assert + assert status_result is True + mock_arcade_client.auth.status.assert_called_once_with(id="auth_abc") + + +def test_requires_auth_true(manager, make_tool): + """ + Test the requires_auth method returning True if + the stored tool definition's requirements contain an authorization entry. + """ + # Arrange + tool_name = "Search_SearchGoogle" + # Pass a MagicMock with 'authorization' to ensure it gets converted + mock_tool_def = make_tool(tool_name, requirements=MagicMock(authorization="some_required_auth")) + manager._tools[tool_name] = mock_tool_def + + # Act + result = manager.requires_auth(tool_name) + + # Assert + assert result is True + + +def test_requires_auth_false(manager, make_tool): + """ + Test the requires_auth method returning False if authorization + is not required in the tool definition. + """ + # Arrange + tool_name = "Search_SearchGoogle" + mock_tool_def = make_tool(tool_name, requirements=MagicMock(authorization=None)) + manager._tools[tool_name] = mock_tool_def + + # Act + result = manager.requires_auth(tool_name) + + # Assert + assert result is False + + +def test_get_tool_definition_existing(manager, make_tool): + """ + Test the internal _get_tool_definition method retrieving + an existing tool definition by name. + """ + # Arrange + tool_name = "Search_SearchGoogle" + mock_tool_def = make_tool(tool_name) + manager._tools[tool_name] = mock_tool_def + + # Act + definition = manager._get_tool_definition(tool_name) + + # Assert + assert definition == mock_tool_def + + +def test_get_tool_definition_missing(manager): + """ + Test the internal _get_tool_definition method raising a ValueError + if the tool is not in the manager. + """ + # Act & Assert + with pytest.raises(ValueError) as excinfo: + manager._get_tool_definition("Nonexistent.Tool") + + assert "Tool 'Nonexistent.Tool' not found" in str(excinfo.value) + + +def test_retrieve_tool_definitions_tools_only(manager, mock_arcade_client, make_tool): + """ + Test the internal _retrieve_tool_definitions method by specifying tools only. + """ + # Arrange + mock_tool = make_tool("Search_SearchGoogle") + mock_arcade_client.tools.get.return_value = mock_tool + + # Act + results = manager._retrieve_tool_definitions(tools=["Search_SearchGoogle"], toolkits=None) + + # Assert + assert len(results) == 1 + assert "Search_SearchGoogle" in results + mock_arcade_client.tools.get.assert_called_once_with(name="Search_SearchGoogle") + + +def test_retrieve_tool_definitions_toolkits_only(manager, mock_arcade_client, make_tool): + """ + Test the internal _retrieve_tool_definitions method by specifying toolkits. + """ + # Arrange + mock_tool = make_tool("Search_SearchBing") + mock_arcade_client.tools.list.return_value = SyncOffsetPage(items=[mock_tool]) + + # Act + results = manager._retrieve_tool_definitions(tools=None, toolkits=["Search"]) + + # Assert + assert len(results) == 1 + assert "Search_SearchBing" in results + mock_arcade_client.tools.list.assert_called_once_with(toolkit="Search") + + +def test_retrieve_tool_definitions_no_args(manager, mock_arcade_client, make_tool): + """ + Test the internal _retrieve_tool_definitions method when no + arguments are provided, retrieving all available tools. + """ + # Arrange + mock_tool1 = make_tool("Search_SearchGoogle") + mock_tool2 = make_tool("Search_SearchBing") + mock_arcade_client.tools.list.return_value = SyncOffsetPage(items=[mock_tool1, mock_tool2]) + + # Act + results = manager._retrieve_tool_definitions() + + # Assert + assert len(results) == 2 + assert "Search_SearchGoogle" in results + assert "Search_SearchBing" in results + mock_arcade_client.tools.list.assert_called_once() diff --git a/contrib/langchain/tox.ini b/contrib/langchain/tox.ini new file mode 100644 index 00000000..f6876a6d --- /dev/null +++ b/contrib/langchain/tox.ini @@ -0,0 +1,10 @@ +[tox] +envlist = py310, py311, py312 + +[testenv] +deps = + pytest + pytest-cov + +commands = + pytest --cov=langchain_arcade --cov-report=term-missing diff --git a/examples/langchain/custom_graph_with_auth.py b/examples/langchain/custom_graph_with_auth.py index a8be23ea..47958b6f 100644 --- a/examples/langchain/custom_graph_with_auth.py +++ b/examples/langchain/custom_graph_with_auth.py @@ -1,5 +1,4 @@ import os -import time # Import necessary classes and modules from langchain_arcade import ArcadeToolManager @@ -15,7 +14,7 @@ openai_api_key = os.environ["OPENAI_API_KEY"] # Initialize the tool manager and fetch tools compatible with langgraph tool_manager = ArcadeToolManager(api_key=arcade_api_key) tools = tool_manager.get_tools( - toolkits=["Github"], + toolkits=["Github", "Google"], langgraph=True, # use langgraph-specific behavior ) tool_node = ToolNode(tools) @@ -25,8 +24,22 @@ model = ChatOpenAI(model="gpt-4o", api_key=openai_api_key) model_with_tools = model.bind_tools(tools) +#### Helpers #### +def get_nth_tool_call(state: MessagesState, n: int = 0): + last_message = state["messages"][-1] + return last_message.tool_calls[n] + + +def has_tool_calls(state: MessagesState): + last_message = state["messages"][-1] + return last_message.tool_calls is not None and len(last_message.tool_calls) > 0 + + +#### Workflow #### + + # Function to invoke the model and get a response -def call_agent(state): +def call_agent(state: MessagesState): messages = state["messages"] response = model_with_tools.invoke(messages) # Return the updated message history @@ -35,9 +48,8 @@ def call_agent(state): # Function to determine the next step in the workflow based on the last message def should_continue(state: MessagesState): - last_message = state["messages"][-1] - if last_message.tool_calls: - tool_name = last_message.tool_calls[0]["name"] + if has_tool_calls(state): + tool_name = get_nth_tool_call(state)["name"] if tool_manager.requires_auth(tool_name): return "authorization" # Proceed to authorization if required else: @@ -48,54 +60,57 @@ def should_continue(state: MessagesState): # Function to handle authorization for tools that require it def authorize(state: MessagesState, config: dict): user_id = config["configurable"].get("user_id") - tool_name = state["messages"][-1].tool_calls[0]["name"] + tool_name = get_nth_tool_call(state)["name"] auth_response = tool_manager.authorize(tool_name, user_id) - if auth_response.status == "completed": - # Authorization completed successfully; continue - return {"messages": state["messages"]} - else: + if auth_response.status != "completed": # Prompt the user to visit the authorization URL - print(f"Visit the following URL to authorize: {auth_response.authorization_url}") - # Wait until authorization is completed - while not tool_manager.is_authorized(auth_response.authorization_id): - time.sleep(1) - return {"messages": state["messages"]} + print(f"Visit the following URL to authorize: {auth_response.url}") + + # wait for the user to complete the authorization + # and then check the authorization status again + tool_manager.wait_for_auth(auth_response.id) + if not tool_manager.is_authorized(auth_response.id): + # node interrupt? + raise ValueError("Authorization failed") + + return {"messages": state["messages"]} -# Build the workflow graph using StateGraph -workflow = StateGraph(MessagesState) +if __name__ == "__main__": + # Build the workflow graph using StateGraph + workflow = StateGraph(MessagesState) -# Add nodes (steps) to the graph -workflow.add_node("agent", call_agent) -workflow.add_node("tools", tool_node) -workflow.add_node("authorization", authorize) + # Add nodes (steps) to the graph + workflow.add_node("agent", call_agent) + workflow.add_node("tools", tool_node) + workflow.add_node("authorization", authorize) -# Define the edges and control flow between nodes -workflow.add_edge(START, "agent") -workflow.add_conditional_edges("agent", should_continue, ["authorization", "tools", END]) -workflow.add_edge("authorization", "tools") -workflow.add_edge("tools", "agent") + # Define the edges and control flow between nodes + workflow.add_edge(START, "agent") + workflow.add_conditional_edges("agent", should_continue, ["authorization", "tools", END]) + workflow.add_edge("authorization", "tools") + workflow.add_edge("tools", "agent") -# Set up memory for checkpointing the state -memory = MemorySaver() + # Set up memory for checkpointing the state + memory = MemorySaver() -# Compile the graph with the checkpointer -graph = workflow.compile(checkpointer=memory) + # Compile the graph with the checkpointer + graph = workflow.compile(checkpointer=memory) -# Define the input messages from the user -inputs = { - "messages": [HumanMessage(content="Star arcadeai/arcade-ai on GitHub!")], -} - -# Configuration with thread and user IDs for authorization purposes -config = { - "configurable": { - "thread_id": "4", - "user_id": "user@example.com", + # Define the input messages from the user + inputs = { + "messages": [HumanMessage(content="what's on my calendar today?")], } -} -# Run the graph and stream the outputs -for chunk in graph.stream(inputs, config=config, stream_mode="values"): - # Pretty-print the last message in the chunk - chunk["messages"][-1].pretty_print() + # Configuration with thread and user IDs for authorization purposes + config = { + "configurable": { + "thread_id": "4", + "user_id": "user@example.comd", + } + } + + # Run the graph and stream the outputs + for chunk in graph.stream(inputs, config=config, stream_mode="values"): + # Pretty-print the last message in the chunk + chunk["messages"][-1].pretty_print() diff --git a/examples/langchain/requirements.txt b/examples/langchain/requirements.txt index d445ac7b..f006e011 100644 --- a/examples/langchain/requirements.txt +++ b/examples/langchain/requirements.txt @@ -1,5 +1,3 @@ -langchain>=0.3.0 -arcadepy>=0.1.2 langchain-google-community[gmail]>=0.1.1 langchain-openai>=0.1.1 -langgraph>=0.1.1 +langchain-arcade[langgraph]>=0.2.0 diff --git a/examples/langchain/simple_graph.py b/examples/langchain/simple_graph.py index f1824576..12204354 100644 --- a/examples/langchain/simple_graph.py +++ b/examples/langchain/simple_graph.py @@ -39,4 +39,5 @@ for chunk in graph.stream(inputs, stream_mode="values", config=config): # Access the latest message from the conversation last_message = chunk["messages"][-1] # Print the assistant's message content - print(last_message.content) + if last_message.content: + print(last_message.content)