Update langchain integration to 0.2.0 (#213)
**PR Description** This update bumps the integration’s version to `0.2.0` and brings several important changes to how `langchain-arcade` interfaces with Arcade tools: 1. **Updated Tool Definition Imports** • Replaces `arcadepy.types.shared.ToolDefinition` with `arcadepy.types.ToolGetResponse as ToolDefinition`. • The parameter extraction is now done via `tool_def.input.parameters` instead of the previous `tool_def.inputs.parameters`. 2. **Authorization Flow Adjustments** • Uses `auth_response.url` instead of `auth_response.authorization_url`. • The `authorize` and `is_authorized` methods now rely on the Arcade client’s updated arguments (`client.auth.status(id=authorization_id)`). 3. **Tool Execution Parameter Renaming** • The `execute` method now expects `input=kwargs` instead of `inputs=kwargs`, aligning with Arcade’s new API spec. 4. **Tool Retrieval Enhancements** • `_retrieve_tool_definitions` is revised to better handle pagination and tool listing (including when no tools/toolkits are explicitly provided). 5. **Version & Dependency Updates** • Increases `langchain-arcade` to `0.2.0`. • Switches `arcadepy` dependency to `~1.0.0rc1`. • Updates example requirements to consume `langchain-arcade[langgraph]>=0.2.0`. These changes may affect existing code that relies on older parameter names (`inputs.parameters` → `input.parameters`) and the renamed execute argument. Please ensure any integrations or custom usage of Arcade tools is updated accordingly.
This commit is contained in:
parent
adaa0da649
commit
6d8e943c96
9 changed files with 471 additions and 75 deletions
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Any, Callable
|
||||
|
||||
from arcadepy import NOT_GIVEN, Arcade
|
||||
from arcadepy.types.shared import ToolDefinition
|
||||
from arcadepy.types import ToolGetResponse as ToolDefinition
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
|
@ -50,7 +50,7 @@ def tool_definition_to_pydantic_model(tool_def: ToolDefinition) -> type[BaseMode
|
|||
"""
|
||||
try:
|
||||
fields: dict[str, Any] = {}
|
||||
for param in tool_def.inputs.parameters or []:
|
||||
for param in tool_def.input.parameters or []:
|
||||
param_type = get_python_type(param.value_schema.val_type)
|
||||
if param_type == list and param.value_schema.inner_val_type: # noqa: E721
|
||||
inner_type: type[Any] = get_python_type(param.value_schema.inner_val_type)
|
||||
|
|
@ -116,10 +116,7 @@ def create_tool_function(
|
|||
# Authorize the user for the tool
|
||||
auth_response = client.tools.authorize(tool_name=tool_name, user_id=user_id)
|
||||
if auth_response.status != "completed":
|
||||
auth_message = (
|
||||
"Please use the following link to authorize: "
|
||||
f"{auth_response.authorization_url}"
|
||||
)
|
||||
auth_message = f"Please use the following link to authorize: {auth_response.url}"
|
||||
if langgraph:
|
||||
raise NodeInterrupt(auth_message)
|
||||
return {"error": auth_message}
|
||||
|
|
@ -127,7 +124,7 @@ def create_tool_function(
|
|||
# Execute the tool with provided inputs
|
||||
execute_response = client.tools.execute(
|
||||
tool_name=tool_name,
|
||||
inputs=kwargs,
|
||||
input=kwargs,
|
||||
user_id=user_id if user_id is not None else NOT_GIVEN,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ from collections.abc import Iterator
|
|||
from typing import Any, Optional
|
||||
|
||||
from arcadepy import Arcade
|
||||
from arcadepy.types.shared import AuthorizationResponse, ToolDefinition
|
||||
from arcadepy.types import ToolGetResponse as ToolDefinition
|
||||
from arcadepy.types.shared import AuthAuthorizationResponse as AuthorizationResponse
|
||||
from langchain_core.tools import StructuredTool
|
||||
|
||||
from langchain_arcade._utilities import (
|
||||
|
|
@ -40,10 +41,16 @@ class ArcadeToolManager:
|
|||
|
||||
Args:
|
||||
client: Optional Arcade client instance.
|
||||
**kwargs: Additional keyword arguments to pass to the Arcade client.
|
||||
"""
|
||||
if not client:
|
||||
api_key = kwargs.get("api_key", os.getenv("ARCADE_API_KEY", None))
|
||||
client = Arcade(api_key=api_key) # type: ignore[arg-type]
|
||||
api_key = kwargs.get("api_key", os.getenv("ARCADE_API_KEY"))
|
||||
base_url = kwargs.get("base_url", os.getenv("ARCADE_BASE_URL"))
|
||||
arcade_kwargs = {"api_key": api_key, **kwargs}
|
||||
if base_url:
|
||||
arcade_kwargs["base_url"] = base_url
|
||||
|
||||
client = Arcade(**arcade_kwargs) # type: ignore[arg-type]
|
||||
self.client = client
|
||||
self._tools: dict[str, ToolDefinition] = {}
|
||||
|
||||
|
|
@ -143,7 +150,20 @@ class ArcadeToolManager:
|
|||
>>> manager.init_tools(toolkits=["Search"])
|
||||
>>> manager.is_authorized("auth_123")
|
||||
"""
|
||||
return self.client.auth.status(authorization_id=authorization_id).status == "completed"
|
||||
return self.client.auth.status(id=authorization_id).status == "completed"
|
||||
|
||||
def wait_for_auth(self, authorization_id: str) -> AuthorizationResponse:
|
||||
"""Wait for a tool authorization to complete.
|
||||
|
||||
Example:
|
||||
>>> manager = ArcadeToolManager(api_key="...")
|
||||
>>> manager.init_tools(toolkits=["Google.ListEmails"])
|
||||
>>> response = manager.authorize("Google.ListEmails", "user_123")
|
||||
>>> manager.wait_for_auth(response)
|
||||
>>> # or
|
||||
>>> manager.wait_for_auth(response.id)
|
||||
"""
|
||||
return self.client.auth.wait_for_completion(authorization_id)
|
||||
|
||||
def requires_auth(self, tool_name: str) -> bool:
|
||||
"""Check if a tool requires authorization."""
|
||||
|
|
@ -162,22 +182,33 @@ class ArcadeToolManager:
|
|||
def _retrieve_tool_definitions(
|
||||
self, tools: Optional[list[str]] = None, toolkits: Optional[list[str]] = None
|
||||
) -> dict[str, ToolDefinition]:
|
||||
"""Retrieve tool definitions from the Arcade client, accounting for pagination."""
|
||||
all_tools: list[ToolDefinition] = []
|
||||
if tools is not None or toolkits is not None:
|
||||
if tools:
|
||||
single_tools = [self.client.tools.get(tool_id=tool_id) for tool_id in tools]
|
||||
all_tools.extend(single_tools)
|
||||
if toolkits:
|
||||
for tk in toolkits:
|
||||
all_tools.extend(self.client.tools.list(toolkit=tk))
|
||||
else:
|
||||
# retrieve all tools
|
||||
page_iterator = self.client.tools.list()
|
||||
all_tools.extend(page_iterator)
|
||||
|
||||
# First, gather single tools if the user specifically requested them.
|
||||
if tools:
|
||||
for tool_id in tools:
|
||||
# ToolsResource.get(...) returns a single ToolGetResponse.
|
||||
single_tool = self.client.tools.get(name=tool_id)
|
||||
all_tools.append(single_tool)
|
||||
|
||||
# Next, gather tool definitions from any requested toolkits.
|
||||
if toolkits:
|
||||
for tk in toolkits:
|
||||
# tools.list(...) returns a paginated response (SyncOffsetPage),
|
||||
# so we iterate over its items to accumulate tool definitions.
|
||||
paginated_tools = self.client.tools.list(toolkit=tk)
|
||||
all_tools.extend(paginated_tools.items) # type: ignore[arg-type]
|
||||
|
||||
# If no specific tools or toolkits were requested, retrieve *all* tools.
|
||||
if not tools and not toolkits:
|
||||
paginated_all_tools = self.client.tools.list()
|
||||
all_tools.extend(paginated_all_tools.items) # type: ignore[arg-type]
|
||||
# Build a dictionary that maps the "full_tool_name" to the tool definition.
|
||||
tool_definitions: dict[str, ToolDefinition] = {}
|
||||
|
||||
for tool in all_tools:
|
||||
# For items returned by .list(), the 'toolkit' and 'name' attributes
|
||||
# should be present as plain fields on the object. (No need to do toolkit.name)
|
||||
full_tool_name = f"{tool.toolkit.name}_{tool.name}"
|
||||
tool_definitions[full_tool_name] = tool
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "langchain-arcade"
|
||||
version = "0.1.2"
|
||||
version = "0.2.0"
|
||||
description = "An integration package connecting Arcade AI and LangChain/LangGraph"
|
||||
authors = ["Arcade AI <dev@arcade-ai.com>"]
|
||||
readme = "README.md"
|
||||
|
|
@ -10,7 +10,7 @@ license = "MIT"
|
|||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<3.13"
|
||||
langchain-core = "^0.3.0"
|
||||
arcadepy = "~0.2.0"
|
||||
arcadepy = "~1.0.0rc1"
|
||||
langgraph = {version = ">=0.2.32,<0.3.0", optional = true}
|
||||
|
||||
[tool.poetry.extras]
|
||||
|
|
|
|||
33
contrib/langchain/tests/conftest.py
Normal file
33
contrib/langchain/tests/conftest.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
from arcadepy import Arcade
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def arcade_base_url():
|
||||
"""
|
||||
Retrieve the ARCADE_BASE_URL from the environment, falling back to a default
|
||||
if not found.
|
||||
"""
|
||||
return os.getenv("ARCADE_BASE_URL", "http://localhost:9099")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def arcade_api_key():
|
||||
"""
|
||||
Retrieve the ARCADE_API_KEY from the environment, falling back to a default
|
||||
if not found.
|
||||
"""
|
||||
return os.getenv("ARCADE_API_KEY", "test_api_key")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def arcade_client(arcade_base_url, arcade_api_key):
|
||||
"""
|
||||
Creates a single Arcade client instance for use in all tests.
|
||||
Any method calls on this client can be patched/mocked within the tests.
|
||||
"""
|
||||
client = Arcade(api_key=arcade_api_key, base_url=arcade_base_url)
|
||||
yield client
|
||||
# Teardown logic would go here if necessary
|
||||
311
contrib/langchain/tests/test_manager.py
Normal file
311
contrib/langchain/tests/test_manager.py
Normal file
|
|
@ -0,0 +1,311 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from arcadepy.pagination import SyncOffsetPage
|
||||
from arcadepy.types import ToolGetResponse as ToolDefinition
|
||||
from arcadepy.types.shared import AuthAuthorizationResponse as AuthorizationResponse
|
||||
from langchain_arcade.manager import ArcadeToolManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_arcade_client():
|
||||
"""
|
||||
A fixture to mock the Arcade client object for testing the ArcadeToolManager.
|
||||
|
||||
This mocks all relevant methods used by the manager, including:
|
||||
- tools.get
|
||||
- tools.list
|
||||
- tools.authorize
|
||||
- auth.status
|
||||
"""
|
||||
mock_client = MagicMock()
|
||||
# Mock the "tools" sub-client
|
||||
mock_client.tools.get = MagicMock()
|
||||
mock_client.tools.list = MagicMock()
|
||||
mock_client.tools.authorize = MagicMock()
|
||||
# Mock the "auth" sub-client
|
||||
mock_client.auth.status = MagicMock()
|
||||
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager(mock_arcade_client):
|
||||
"""
|
||||
A fixture that creates an ArcadeToolManager with the mocked Arcade client.
|
||||
"""
|
||||
return ArcadeToolManager(client=mock_arcade_client)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_tool():
|
||||
"""
|
||||
A factory fixture for creating a valid ToolDefinition with a given
|
||||
fully qualified name. Because the underlying ToolDefinition model
|
||||
expects "toolkit" to be a dictionary with at least one field (for example "slug"),
|
||||
and "requirements.authorization" to be a valid dictionary if present, we set them up
|
||||
accordingly.
|
||||
"""
|
||||
|
||||
def _make_tool(fully_qualified_name="Search_SearchGoogle", **kwargs):
|
||||
# Split on the first dot to derive a 'toolkit' slug and a tool 'name'
|
||||
if "." in fully_qualified_name:
|
||||
raw_toolkit, raw_tool_name = fully_qualified_name.split(".", 1)
|
||||
elif "_" in fully_qualified_name:
|
||||
# Convert from "_" to "." to match the expected format of tool name when
|
||||
# using Langchain models for LLM inference.
|
||||
raw_toolkit, raw_tool_name = fully_qualified_name.split("_", 1)
|
||||
|
||||
else:
|
||||
raw_toolkit, raw_tool_name = fully_qualified_name, fully_qualified_name
|
||||
|
||||
# Provide a default toolkit dict unless one already exists in kwargs
|
||||
toolkit = kwargs.pop("toolkit", {"name": raw_toolkit})
|
||||
|
||||
# Provide a default input
|
||||
# arcadepy.types.ToolGetResponse expects "input" to be a valid structure (dict).
|
||||
tool_input = kwargs.pop("input", {"parameters": []})
|
||||
|
||||
# Convert MagicMock-based requirements (with authorization) to an appropriate dict,
|
||||
# or use what's passed. If none is passed, default to None.
|
||||
requirements = kwargs.pop("requirements", None)
|
||||
if requirements is not None and not isinstance(requirements, dict):
|
||||
# If it's e.g. a MagicMock(authorization="xyz"), convert it to a dict
|
||||
req_auth = getattr(requirements, "authorization", None)
|
||||
# If the test expects an authorization presence, represent it as a dict
|
||||
# that Pydantic can parse
|
||||
if req_auth is not None:
|
||||
requirements = {"authorization": {"type": req_auth}}
|
||||
else:
|
||||
requirements = {"authorization": None}
|
||||
|
||||
# Provide a default description if none is supplied
|
||||
description = kwargs.pop("description", "Mock tool for testing")
|
||||
|
||||
# Build the pydantic fields
|
||||
data = {
|
||||
"fully_qualified_name": fully_qualified_name,
|
||||
"name": raw_tool_name,
|
||||
"toolkit": toolkit,
|
||||
"input": tool_input,
|
||||
"description": description,
|
||||
"requirements": requirements,
|
||||
}
|
||||
data.update(kwargs) # merge any extras
|
||||
|
||||
return ToolDefinition(**data)
|
||||
|
||||
return _make_tool
|
||||
|
||||
|
||||
def test_init_tools(manager, mock_arcade_client, make_tool):
|
||||
"""
|
||||
Test that init_tools clears any existing tools and retrieves new ones
|
||||
from either an explicit list of tools or an entire toolkit.
|
||||
"""
|
||||
# Arrange
|
||||
mock_tool = make_tool("Search_SearchGoogle")
|
||||
mock_arcade_client.tools.get.return_value = mock_tool
|
||||
mock_arcade_client.tools.list.return_value = SyncOffsetPage(items=[mock_tool])
|
||||
# Act
|
||||
manager.init_tools(tools=["Search_SearchGoogle"])
|
||||
|
||||
# Assert
|
||||
assert "Search_SearchGoogle" in manager.tools
|
||||
assert manager._tools["Search_SearchGoogle"] == mock_tool
|
||||
mock_arcade_client.tools.get.assert_called_once_with(name="Search_SearchGoogle")
|
||||
|
||||
|
||||
def test_get_tools_no_init(manager, mock_arcade_client, make_tool):
|
||||
"""
|
||||
If get_tools is called without init_tools and no tools are specified,
|
||||
it should call init_tools internally and fetch all available tools.
|
||||
"""
|
||||
# Arrange
|
||||
mock_tool = make_tool("Search_SearchGoogle")
|
||||
mock_arcade_client.tools.list.return_value = SyncOffsetPage(items=[mock_tool])
|
||||
|
||||
# Act
|
||||
tools = manager.get_tools() # no param means manager auto-inits
|
||||
|
||||
# Assert
|
||||
assert len(tools) == 1
|
||||
assert "Search_SearchGoogle" in manager.tools
|
||||
assert manager._tools["Search_SearchGoogle"] == mock_tool
|
||||
mock_arcade_client.tools.list.assert_called_once()
|
||||
|
||||
|
||||
def test_get_tools_with_explicit(manager, mock_arcade_client, make_tool):
|
||||
"""
|
||||
If tools or toolkits are provided to get_tools, the manager should
|
||||
retrieve or update the internal _tools dictionary accordingly,
|
||||
then return them as StructuredTool objects.
|
||||
"""
|
||||
# Arrange
|
||||
mock_tool_google = make_tool("Search_SearchGoogle")
|
||||
mock_tool_bing = make_tool("Search_SearchBing")
|
||||
mock_arcade_client.tools.get.side_effect = [mock_tool_google, mock_tool_bing]
|
||||
|
||||
# Act
|
||||
retrieved_tools = manager.get_tools(tools=["Search_SearchGoogle", "Search_SearchBing"])
|
||||
|
||||
# Assert
|
||||
assert len(retrieved_tools) == 2
|
||||
assert set(manager.tools) == {"Search_SearchGoogle", "Search_SearchBing"}
|
||||
mock_arcade_client.tools.get.assert_any_call(name="Search_SearchGoogle")
|
||||
mock_arcade_client.tools.get.assert_any_call(name="Search_SearchBing")
|
||||
|
||||
|
||||
def test_authorize(manager, mock_arcade_client):
|
||||
"""
|
||||
Test the authorize method to ensure it calls the Arcade client's
|
||||
tools.authorize method correctly.
|
||||
"""
|
||||
# Arrange
|
||||
mock_arcade_client.tools.authorize.return_value = AuthorizationResponse(
|
||||
id="auth_123", status="pending", tool_fully_qualified_name="Search_SearchGoogle"
|
||||
)
|
||||
|
||||
# Act
|
||||
response = manager.authorize(tool_name="Search_SearchGoogle", user_id="user_123")
|
||||
|
||||
# Assert
|
||||
assert response.id == "auth_123"
|
||||
assert response.status == "pending"
|
||||
mock_arcade_client.tools.authorize.assert_called_once_with(
|
||||
tool_name="Search_SearchGoogle", user_id="user_123"
|
||||
)
|
||||
|
||||
|
||||
def test_is_authorized(manager, mock_arcade_client):
|
||||
"""
|
||||
Test the is_authorized method which checks if authorization
|
||||
has completed for a given authorization ID.
|
||||
"""
|
||||
# Arrange
|
||||
mock_arcade_client.auth.status.return_value = MagicMock(status="completed")
|
||||
|
||||
# Act
|
||||
status_result = manager.is_authorized("auth_abc")
|
||||
|
||||
# Assert
|
||||
assert status_result is True
|
||||
mock_arcade_client.auth.status.assert_called_once_with(id="auth_abc")
|
||||
|
||||
|
||||
def test_requires_auth_true(manager, make_tool):
|
||||
"""
|
||||
Test the requires_auth method returning True if
|
||||
the stored tool definition's requirements contain an authorization entry.
|
||||
"""
|
||||
# Arrange
|
||||
tool_name = "Search_SearchGoogle"
|
||||
# Pass a MagicMock with 'authorization' to ensure it gets converted
|
||||
mock_tool_def = make_tool(tool_name, requirements=MagicMock(authorization="some_required_auth"))
|
||||
manager._tools[tool_name] = mock_tool_def
|
||||
|
||||
# Act
|
||||
result = manager.requires_auth(tool_name)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_requires_auth_false(manager, make_tool):
|
||||
"""
|
||||
Test the requires_auth method returning False if authorization
|
||||
is not required in the tool definition.
|
||||
"""
|
||||
# Arrange
|
||||
tool_name = "Search_SearchGoogle"
|
||||
mock_tool_def = make_tool(tool_name, requirements=MagicMock(authorization=None))
|
||||
manager._tools[tool_name] = mock_tool_def
|
||||
|
||||
# Act
|
||||
result = manager.requires_auth(tool_name)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_get_tool_definition_existing(manager, make_tool):
|
||||
"""
|
||||
Test the internal _get_tool_definition method retrieving
|
||||
an existing tool definition by name.
|
||||
"""
|
||||
# Arrange
|
||||
tool_name = "Search_SearchGoogle"
|
||||
mock_tool_def = make_tool(tool_name)
|
||||
manager._tools[tool_name] = mock_tool_def
|
||||
|
||||
# Act
|
||||
definition = manager._get_tool_definition(tool_name)
|
||||
|
||||
# Assert
|
||||
assert definition == mock_tool_def
|
||||
|
||||
|
||||
def test_get_tool_definition_missing(manager):
|
||||
"""
|
||||
Test the internal _get_tool_definition method raising a ValueError
|
||||
if the tool is not in the manager.
|
||||
"""
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
manager._get_tool_definition("Nonexistent.Tool")
|
||||
|
||||
assert "Tool 'Nonexistent.Tool' not found" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_retrieve_tool_definitions_tools_only(manager, mock_arcade_client, make_tool):
|
||||
"""
|
||||
Test the internal _retrieve_tool_definitions method by specifying tools only.
|
||||
"""
|
||||
# Arrange
|
||||
mock_tool = make_tool("Search_SearchGoogle")
|
||||
mock_arcade_client.tools.get.return_value = mock_tool
|
||||
|
||||
# Act
|
||||
results = manager._retrieve_tool_definitions(tools=["Search_SearchGoogle"], toolkits=None)
|
||||
|
||||
# Assert
|
||||
assert len(results) == 1
|
||||
assert "Search_SearchGoogle" in results
|
||||
mock_arcade_client.tools.get.assert_called_once_with(name="Search_SearchGoogle")
|
||||
|
||||
|
||||
def test_retrieve_tool_definitions_toolkits_only(manager, mock_arcade_client, make_tool):
|
||||
"""
|
||||
Test the internal _retrieve_tool_definitions method by specifying toolkits.
|
||||
"""
|
||||
# Arrange
|
||||
mock_tool = make_tool("Search_SearchBing")
|
||||
mock_arcade_client.tools.list.return_value = SyncOffsetPage(items=[mock_tool])
|
||||
|
||||
# Act
|
||||
results = manager._retrieve_tool_definitions(tools=None, toolkits=["Search"])
|
||||
|
||||
# Assert
|
||||
assert len(results) == 1
|
||||
assert "Search_SearchBing" in results
|
||||
mock_arcade_client.tools.list.assert_called_once_with(toolkit="Search")
|
||||
|
||||
|
||||
def test_retrieve_tool_definitions_no_args(manager, mock_arcade_client, make_tool):
|
||||
"""
|
||||
Test the internal _retrieve_tool_definitions method when no
|
||||
arguments are provided, retrieving all available tools.
|
||||
"""
|
||||
# Arrange
|
||||
mock_tool1 = make_tool("Search_SearchGoogle")
|
||||
mock_tool2 = make_tool("Search_SearchBing")
|
||||
mock_arcade_client.tools.list.return_value = SyncOffsetPage(items=[mock_tool1, mock_tool2])
|
||||
|
||||
# Act
|
||||
results = manager._retrieve_tool_definitions()
|
||||
|
||||
# Assert
|
||||
assert len(results) == 2
|
||||
assert "Search_SearchGoogle" in results
|
||||
assert "Search_SearchBing" in results
|
||||
mock_arcade_client.tools.list.assert_called_once()
|
||||
10
contrib/langchain/tox.ini
Normal file
10
contrib/langchain/tox.ini
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
[tox]
|
||||
envlist = py310, py311, py312
|
||||
|
||||
[testenv]
|
||||
deps =
|
||||
pytest
|
||||
pytest-cov
|
||||
|
||||
commands =
|
||||
pytest --cov=langchain_arcade --cov-report=term-missing
|
||||
|
|
@ -1,5 +1,4 @@
|
|||
import os
|
||||
import time
|
||||
|
||||
# Import necessary classes and modules
|
||||
from langchain_arcade import ArcadeToolManager
|
||||
|
|
@ -15,7 +14,7 @@ openai_api_key = os.environ["OPENAI_API_KEY"]
|
|||
# Initialize the tool manager and fetch tools compatible with langgraph
|
||||
tool_manager = ArcadeToolManager(api_key=arcade_api_key)
|
||||
tools = tool_manager.get_tools(
|
||||
toolkits=["Github"],
|
||||
toolkits=["Github", "Google"],
|
||||
langgraph=True, # use langgraph-specific behavior
|
||||
)
|
||||
tool_node = ToolNode(tools)
|
||||
|
|
@ -25,8 +24,22 @@ model = ChatOpenAI(model="gpt-4o", api_key=openai_api_key)
|
|||
model_with_tools = model.bind_tools(tools)
|
||||
|
||||
|
||||
#### Helpers ####
|
||||
def get_nth_tool_call(state: MessagesState, n: int = 0):
|
||||
last_message = state["messages"][-1]
|
||||
return last_message.tool_calls[n]
|
||||
|
||||
|
||||
def has_tool_calls(state: MessagesState):
|
||||
last_message = state["messages"][-1]
|
||||
return last_message.tool_calls is not None and len(last_message.tool_calls) > 0
|
||||
|
||||
|
||||
#### Workflow ####
|
||||
|
||||
|
||||
# Function to invoke the model and get a response
|
||||
def call_agent(state):
|
||||
def call_agent(state: MessagesState):
|
||||
messages = state["messages"]
|
||||
response = model_with_tools.invoke(messages)
|
||||
# Return the updated message history
|
||||
|
|
@ -35,9 +48,8 @@ def call_agent(state):
|
|||
|
||||
# Function to determine the next step in the workflow based on the last message
|
||||
def should_continue(state: MessagesState):
|
||||
last_message = state["messages"][-1]
|
||||
if last_message.tool_calls:
|
||||
tool_name = last_message.tool_calls[0]["name"]
|
||||
if has_tool_calls(state):
|
||||
tool_name = get_nth_tool_call(state)["name"]
|
||||
if tool_manager.requires_auth(tool_name):
|
||||
return "authorization" # Proceed to authorization if required
|
||||
else:
|
||||
|
|
@ -48,54 +60,57 @@ def should_continue(state: MessagesState):
|
|||
# Function to handle authorization for tools that require it
|
||||
def authorize(state: MessagesState, config: dict):
|
||||
user_id = config["configurable"].get("user_id")
|
||||
tool_name = state["messages"][-1].tool_calls[0]["name"]
|
||||
tool_name = get_nth_tool_call(state)["name"]
|
||||
auth_response = tool_manager.authorize(tool_name, user_id)
|
||||
if auth_response.status == "completed":
|
||||
# Authorization completed successfully; continue
|
||||
return {"messages": state["messages"]}
|
||||
else:
|
||||
if auth_response.status != "completed":
|
||||
# Prompt the user to visit the authorization URL
|
||||
print(f"Visit the following URL to authorize: {auth_response.authorization_url}")
|
||||
# Wait until authorization is completed
|
||||
while not tool_manager.is_authorized(auth_response.authorization_id):
|
||||
time.sleep(1)
|
||||
return {"messages": state["messages"]}
|
||||
print(f"Visit the following URL to authorize: {auth_response.url}")
|
||||
|
||||
# wait for the user to complete the authorization
|
||||
# and then check the authorization status again
|
||||
tool_manager.wait_for_auth(auth_response.id)
|
||||
if not tool_manager.is_authorized(auth_response.id):
|
||||
# node interrupt?
|
||||
raise ValueError("Authorization failed")
|
||||
|
||||
return {"messages": state["messages"]}
|
||||
|
||||
|
||||
# Build the workflow graph using StateGraph
|
||||
workflow = StateGraph(MessagesState)
|
||||
if __name__ == "__main__":
|
||||
# Build the workflow graph using StateGraph
|
||||
workflow = StateGraph(MessagesState)
|
||||
|
||||
# Add nodes (steps) to the graph
|
||||
workflow.add_node("agent", call_agent)
|
||||
workflow.add_node("tools", tool_node)
|
||||
workflow.add_node("authorization", authorize)
|
||||
# Add nodes (steps) to the graph
|
||||
workflow.add_node("agent", call_agent)
|
||||
workflow.add_node("tools", tool_node)
|
||||
workflow.add_node("authorization", authorize)
|
||||
|
||||
# Define the edges and control flow between nodes
|
||||
workflow.add_edge(START, "agent")
|
||||
workflow.add_conditional_edges("agent", should_continue, ["authorization", "tools", END])
|
||||
workflow.add_edge("authorization", "tools")
|
||||
workflow.add_edge("tools", "agent")
|
||||
# Define the edges and control flow between nodes
|
||||
workflow.add_edge(START, "agent")
|
||||
workflow.add_conditional_edges("agent", should_continue, ["authorization", "tools", END])
|
||||
workflow.add_edge("authorization", "tools")
|
||||
workflow.add_edge("tools", "agent")
|
||||
|
||||
# Set up memory for checkpointing the state
|
||||
memory = MemorySaver()
|
||||
# Set up memory for checkpointing the state
|
||||
memory = MemorySaver()
|
||||
|
||||
# Compile the graph with the checkpointer
|
||||
graph = workflow.compile(checkpointer=memory)
|
||||
# Compile the graph with the checkpointer
|
||||
graph = workflow.compile(checkpointer=memory)
|
||||
|
||||
# Define the input messages from the user
|
||||
inputs = {
|
||||
"messages": [HumanMessage(content="Star arcadeai/arcade-ai on GitHub!")],
|
||||
}
|
||||
|
||||
# Configuration with thread and user IDs for authorization purposes
|
||||
config = {
|
||||
"configurable": {
|
||||
"thread_id": "4",
|
||||
"user_id": "user@example.com",
|
||||
# Define the input messages from the user
|
||||
inputs = {
|
||||
"messages": [HumanMessage(content="what's on my calendar today?")],
|
||||
}
|
||||
}
|
||||
|
||||
# Run the graph and stream the outputs
|
||||
for chunk in graph.stream(inputs, config=config, stream_mode="values"):
|
||||
# Pretty-print the last message in the chunk
|
||||
chunk["messages"][-1].pretty_print()
|
||||
# Configuration with thread and user IDs for authorization purposes
|
||||
config = {
|
||||
"configurable": {
|
||||
"thread_id": "4",
|
||||
"user_id": "user@example.comd",
|
||||
}
|
||||
}
|
||||
|
||||
# Run the graph and stream the outputs
|
||||
for chunk in graph.stream(inputs, config=config, stream_mode="values"):
|
||||
# Pretty-print the last message in the chunk
|
||||
chunk["messages"][-1].pretty_print()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
langchain>=0.3.0
|
||||
arcadepy>=0.1.2
|
||||
langchain-google-community[gmail]>=0.1.1
|
||||
langchain-openai>=0.1.1
|
||||
langgraph>=0.1.1
|
||||
langchain-arcade[langgraph]>=0.2.0
|
||||
|
|
|
|||
|
|
@ -39,4 +39,5 @@ for chunk in graph.stream(inputs, stream_mode="values", config=config):
|
|||
# Access the latest message from the conversation
|
||||
last_message = chunk["messages"][-1]
|
||||
# Print the assistant's message content
|
||||
print(last_message.content)
|
||||
if last_message.content:
|
||||
print(last_message.content)
|
||||
|
|
|
|||
Loading…
Reference in a new issue