Ignore Toolkits (#219)

This commit is contained in:
Sterling Dreyer 2025-01-23 15:37:15 -08:00 committed by GitHub
parent 09a0784cd5
commit 130858a958
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 42 additions and 0 deletions

View file

@ -118,10 +118,12 @@ class ToolCatalog(BaseModel):
_tools: dict[FullyQualifiedName, MaterializedTool] = {}
_disabled_tools: set[str] = set()
_disabled_toolkits: set[str] = set()
def __init__(self, **data) -> None: # type: ignore[no-untyped-def]
super().__init__(**data)
self._load_disabled_tools()
self._load_disabled_toolkits()
def _load_disabled_tools(self) -> None:
"""Load disabled tools from the environment variable.
@ -145,6 +147,23 @@ class ToolCatalog(BaseModel):
self._disabled_tools.add(tool.lower())
def _load_disabled_toolkits(self) -> None:
"""Load disabled toolkits from the environment variable.
The ARCADE_DISABLED_TOOLKITS environment variable should contain a
comma-separated list of toolkits that are to be excluded from the
catalog.
The expected format for each disabled toolkit is:
- [CamelCaseToolkitName]
"""
disabled_toolkits = os.getenv("ARCADE_DISABLED_TOOLKITS", "").strip().split(",")
if not disabled_toolkits:
return
for toolkit in disabled_toolkits:
self._disabled_toolkits.add(toolkit.lower())
def add_tool(
self,
tool_func: Callable,
@ -183,6 +202,10 @@ class ToolCatalog(BaseModel):
logger.info(f"Tool '{fully_qualified_name!s}' is disabled and will not be cataloged.")
return
if str(toolkit_name).lower() in self._disabled_toolkits:
logger.info(f"Toolkit '{toolkit_name!s}' is disabled and will not be cataloged.")
return
self._tools[fully_qualified_name] = MaterializedTool(
definition=definition,
tool=tool_func,
@ -208,6 +231,10 @@ class ToolCatalog(BaseModel):
Add the tools from a loaded toolkit to the catalog.
"""
if str(toolkit).lower() in self._disabled_toolkits:
logger.info(f"Toolkit '{toolkit.name!s}' is disabled and will not be cataloged.")
return
for module_name, tool_names in toolkit.tools.items():
for tool_name in tool_names:
try:

View file

@ -180,3 +180,18 @@ def test_add_tool_with_whitespace_disabled_tools(monkeypatch):
catalog = ToolCatalog()
catalog.add_tool(sample_tool, "SampleToolkitOne")
assert len(catalog._tools) == 1
def test_add_tool_with_disabled_toolkit(monkeypatch):
monkeypatch.setenv("ARCADE_DISABLED_TOOLKITS", "SampleToolkitOne")
catalog = ToolCatalog()
catalog.add_toolkit(
Toolkit(
name="SampleToolkitOne",
package_name="sample_toolkit_one",
version="1.0.0",
description="A sample toolkit",
)
)
assert len(catalog._tools) == 0