diff --git a/arcade/arcade/core/catalog.py b/arcade/arcade/core/catalog.py index 4602a930..5ea7042a 100644 --- a/arcade/arcade/core/catalog.py +++ b/arcade/arcade/core/catalog.py @@ -118,10 +118,12 @@ class ToolCatalog(BaseModel): _tools: dict[FullyQualifiedName, MaterializedTool] = {} _disabled_tools: set[str] = set() + _disabled_toolkits: set[str] = set() def __init__(self, **data) -> None: # type: ignore[no-untyped-def] super().__init__(**data) self._load_disabled_tools() + self._load_disabled_toolkits() def _load_disabled_tools(self) -> None: """Load disabled tools from the environment variable. @@ -145,6 +147,23 @@ class ToolCatalog(BaseModel): self._disabled_tools.add(tool.lower()) + def _load_disabled_toolkits(self) -> None: + """Load disabled toolkits from the environment variable. + + The ARCADE_DISABLED_TOOLKITS environment variable should contain a + comma-separated list of toolkits that are to be excluded from the + catalog. + + The expected format for each disabled toolkit is: + - [CamelCaseToolkitName] + """ + disabled_toolkits = os.getenv("ARCADE_DISABLED_TOOLKITS", "").strip().split(",") + if not disabled_toolkits: + return + + for toolkit in disabled_toolkits: + self._disabled_toolkits.add(toolkit.lower()) + def add_tool( self, tool_func: Callable, @@ -183,6 +202,10 @@ class ToolCatalog(BaseModel): logger.info(f"Tool '{fully_qualified_name!s}' is disabled and will not be cataloged.") return + if str(toolkit_name).lower() in self._disabled_toolkits: + logger.info(f"Toolkit '{toolkit_name!s}' is disabled and will not be cataloged.") + return + self._tools[fully_qualified_name] = MaterializedTool( definition=definition, tool=tool_func, @@ -208,6 +231,10 @@ class ToolCatalog(BaseModel): Add the tools from a loaded toolkit to the catalog. """ + if str(toolkit).lower() in self._disabled_toolkits: + logger.info(f"Toolkit '{toolkit.name!s}' is disabled and will not be cataloged.") + return + for module_name, tool_names in toolkit.tools.items(): for tool_name in tool_names: try: diff --git a/arcade/tests/core/test_catalog.py b/arcade/tests/core/test_catalog.py index 45eb2524..a2a0bb7a 100644 --- a/arcade/tests/core/test_catalog.py +++ b/arcade/tests/core/test_catalog.py @@ -180,3 +180,18 @@ def test_add_tool_with_whitespace_disabled_tools(monkeypatch): catalog = ToolCatalog() catalog.add_tool(sample_tool, "SampleToolkitOne") assert len(catalog._tools) == 1 + + +def test_add_tool_with_disabled_toolkit(monkeypatch): + monkeypatch.setenv("ARCADE_DISABLED_TOOLKITS", "SampleToolkitOne") + catalog = ToolCatalog() + + catalog.add_toolkit( + Toolkit( + name="SampleToolkitOne", + package_name="sample_toolkit_one", + version="1.0.0", + description="A sample toolkit", + ) + ) + assert len(catalog._tools) == 0