From 1b67cee6672c30acbad32abe66e3fee926128116 Mon Sep 17 00:00:00 2001 From: Nate Barbettini Date: Thu, 1 Aug 2024 09:14:37 -0700 Subject: [PATCH] JWT auth for Engine->Actor communication (#11) Implements: https://app.clickup.com/9014390315/v/dc/8cmtbhb-2714/8cmtbhb-5974 Todo: - [x] Initial demo - [x] Get API key config from `arcade.config` - [x] Get engine URL from config - [x] Final cleanup - [ ] Enforce auth for all requests (waiting for engine) --- arcade/arcade/actor/core/__init__.py | 0 arcade/arcade/actor/core/auth.py | 43 +++++ arcade/arcade/actor/{ => core}/base.py | 0 arcade/arcade/actor/fastapi/actor.py | 12 +- arcade/arcade/actor/fastapi/auth.py | 26 +++ arcade/arcade/core/config.py | 167 ++++++++++++++---- arcade/poetry.lock | 19 +- arcade/pyproject.toml | 1 + examples/math/arcade_arithmetic/main.py | 4 +- .../arcade_arithmetic/tools/arithmetic.py | 10 ++ 10 files changed, 239 insertions(+), 43 deletions(-) create mode 100644 arcade/arcade/actor/core/__init__.py create mode 100644 arcade/arcade/actor/core/auth.py rename arcade/arcade/actor/{ => core}/base.py (100%) diff --git a/arcade/arcade/actor/core/__init__.py b/arcade/arcade/actor/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/arcade/arcade/actor/core/auth.py b/arcade/arcade/actor/core/auth.py new file mode 100644 index 00000000..8216c6a4 --- /dev/null +++ b/arcade/arcade/actor/core/auth.py @@ -0,0 +1,43 @@ +from dataclasses import dataclass +from enum import Enum + +import jwt + +from arcade.core.config import config + +TOKEN_VER = "1" # noqa: S105 Possible hardcoded password assigned (false positive) + + +@dataclass +class TokenValidationResult: + valid: bool + api_key: str | None = None + error: str | None = None + + +class SigningAlgorithm(str, Enum): + HS256 = "HS256" + + +def validate_token(token: str) -> TokenValidationResult: + try: + payload = jwt.decode( + token, + config.api.secret, + algorithms=[SigningAlgorithm.HS256], + verify=True, + issuer=config.engine_url, + audience="actor", + ) + except (jwt.ExpiredSignatureError, jwt.InvalidTokenError) as e: + return TokenValidationResult(valid=False, error=str(e)) + + api_key = payload.get("api_key") + if api_key != config.api.key: + return TokenValidationResult(valid=False, error="Invalid API key") + + token_ver = payload.get("ver") + if token_ver != TOKEN_VER: + return TokenValidationResult(valid=False, error=f"Unknown token version: {token_ver}") + + return TokenValidationResult(valid=True, api_key=api_key) diff --git a/arcade/arcade/actor/base.py b/arcade/arcade/actor/core/base.py similarity index 100% rename from arcade/arcade/actor/base.py rename to arcade/arcade/actor/core/base.py diff --git a/arcade/arcade/actor/fastapi/actor.py b/arcade/arcade/actor/fastapi/actor.py index 8346d5f6..f9b0fced 100644 --- a/arcade/arcade/actor/fastapi/actor.py +++ b/arcade/arcade/actor/fastapi/actor.py @@ -3,7 +3,7 @@ from typing import Any, Callable from fastapi import FastAPI, Request -from arcade.actor.base import BaseActor +from arcade.actor.core.base import BaseActor class FastAPIActor(BaseActor): @@ -14,13 +14,14 @@ class FastAPIActor(BaseActor): """ super().__init__() self.app = app - self.router = FastAPIRouter(app) + self.router = FastAPIRouter(app, self) self.register_routes(self.router) class FastAPIRouter: # TODO create an interface for this - def __init__(self, app: FastAPI) -> None: + def __init__(self, app: FastAPI, actor: BaseActor) -> None: self.app = app + self.actor = actor def add_route(self, path: str, handler: Callable, methods: str) -> None: """ @@ -43,7 +44,10 @@ class FastAPIRouter: # TODO create an interface for this Wrap the handler to handle FastAPI-specific request and response. """ - async def wrapped_handler(request: Request) -> Any: + async def wrapped_handler( + request: Request, + # api_key: str = Depends(get_api_key), # TODO re-enable when Engine supports auth + ) -> Any: if asyncio.iscoroutinefunction(handler) or ( callable(handler) and asyncio.iscoroutinefunction(handler.__call__) # type: ignore[operator] ): diff --git a/arcade/arcade/actor/fastapi/auth.py b/arcade/arcade/actor/fastapi/auth.py index e69de29b..126a8fe3 100644 --- a/arcade/arcade/actor/fastapi/auth.py +++ b/arcade/arcade/actor/fastapi/auth.py @@ -0,0 +1,26 @@ +from typing import cast + +from fastapi import Depends, HTTPException +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + +from arcade.actor.core.auth import validate_token + +security = HTTPBearer() # Authorization: Bearer + + +# Dependency function to validate JWT and extract API key +# The validator function is provided by the BaseActor class +async def get_api_key( + credentials: HTTPAuthorizationCredentials = Depends(security), +) -> str: + jwt: str = credentials.credentials + validation_result = validate_token(jwt) + + if not validation_result.valid: + raise HTTPException( + status_code=401, + detail=f"Invalid token. Error: {validation_result.error}", + headers={"WWW-Authenticate": "Bearer"}, + ) + + return cast(str, validation_result.api_key) diff --git a/arcade/arcade/core/config.py b/arcade/arcade/core/config.py index b163e62f..5139a6c9 100644 --- a/arcade/arcade/core/config.py +++ b/arcade/arcade/core/config.py @@ -1,63 +1,156 @@ from pathlib import Path import toml -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from arcade.core.env import settings +class ApiConfig(BaseModel): + """ + Arcade API configuration. + """ + + key: str + """ + Arcade API key. + """ + secret: str + """ + Arcade API secret. + """ + + +class UserConfig(BaseModel): + """ + Arcade user configuration. + """ + + email: str | None = None + """ + User email. + """ + + +class EngineConfig(BaseModel): + """ + Arcade Engine configuration. + """ + + host: str = "localhost" + """ + Arcade Engine host. + """ + port: int = 6901 + """ + Arcade Engine port. + """ + tls: bool = False + """ + Whether to use TLS for the connection to Arcade Engine. + """ + + class Config(BaseModel): - api_key: str | None = None - user_email: str | None = None - engine_key: str | None = None - engine_host: str = "localhost" - engine_port: str = "6901" + """ + Configuration for Arcade. + """ - config_dir: Path = settings.WORK_DIR if settings.WORK_DIR else Path.home() / ".arcade" - config_file: Path = config_dir / "arcade.toml" + api: ApiConfig + """ + Arcade API configuration. + """ + user: UserConfig | None = None + """ + Arcade user configuration. + """ + engine: EngineConfig | None = None + """ + Arcade Engine configuration. + """ + + @classmethod + def get_config_dir_path(cls) -> Path: + """ + Get the path to the Arcade configuration directory. + """ + return settings.WORK_DIR if settings.WORK_DIR else Path.home() / ".arcade" + + @classmethod + def get_config_file_path(cls) -> Path: + """ + Get the path to the Arcade configuration file. + """ + return cls.get_config_dir_path() / "arcade.toml" @property - def arcade_api_key(self) -> str: - if not self.api_key: - raise ValueError("Arcade API Key not set") - return self.api_key + def engine_url(self) -> str: + """ + Get the URL of the Arcade Engine. + """ + if self.engine is None: + raise ValueError("Engine not set") + protocol = "https" if self.engine.tls else "http" + return f"{protocol}://{self.engine.host}:{self.engine.port}" - @property - def engine_url(self, tls: bool = False) -> str: - if tls: - return f"https://{self.engine_host}:{self.engine_port}" - return f"http://{self.engine_host}:{self.engine_port}" - - @staticmethod - def create_config_directory() -> None: + @classmethod + def ensure_config_dir_exists(cls) -> None: """ Create the configuration directory if it does not exist. """ - config_dir = Config.config_dir + config_dir = Config.get_config_dir_path() if not config_dir.exists(): config_dir.mkdir(parents=True, exist_ok=True) - def save_to_file(self) -> None: - """ - Save the configuration to the TOML file in the configuration directory. - """ - self.create_config_directory() - config_file_path = self.config_file - with config_file_path.open("w") as config_file: - toml.dump(self.dict(), config_file) - @classmethod def load_from_file(cls) -> "Config": """ Load the configuration from the TOML file in the configuration directory. + If no configuration file exists, create a new one with default values. """ - cls.create_config_directory() - config_file_path = cls.config_file - if config_file_path.exists(): - with config_file_path.open("r") as config_file: - config_data = toml.load(config_file) - return cls(**config_data) - return cls() + cls.ensure_config_dir_exists() + + config_file_path = cls.get_config_file_path() + if not config_file_path.exists(): + # Create a file using the default configuration + default_config = cls.model_construct( + api=ApiConfig.model_construct(), engine=EngineConfig() + ) + default_config.save_to_file() + + config_data = toml.loads(config_file_path.read_text()) + + try: + return cls(**config_data) + except ValidationError as e: + # Get only the errors with {type:missing} and combine them + # into a nicely-formatted string message. + # Any other errors without {type:missing} should just be str()ed + missing_field_errors = [ + ".".join(map(str, error["loc"])) + for error in e.errors() + if error["type"] == "missing" + ] + other_errors = [str(error) for error in e.errors() if error["type"] != "missing"] + + missing_field_errors_str = ", ".join(missing_field_errors) + other_errors_str = "\n".join(other_errors) + + pretty_str: str = "Invalid Arcade configuration." + if missing_field_errors_str: + pretty_str += f"\nMissing fields: {missing_field_errors_str}\n" + if other_errors_str: + pretty_str += f"\nOther errors:\n{other_errors_str}" + + raise ValueError(pretty_str) from e + + def save_to_file(self) -> None: + """ + Save the configuration to the TOML file in the configuration directory. + """ + Config.ensure_config_dir_exists() + config_file_path = Config.get_config_file_path() + config_file_path.write_text(toml.dumps(self.model_dump())) # Singleton instance of Config diff --git a/arcade/poetry.lock b/arcade/poetry.lock index e31da5c4..efc5fec4 100644 --- a/arcade/poetry.lock +++ b/arcade/poetry.lock @@ -1124,6 +1124,23 @@ files = [ [package.extras] windows-terminal = ["colorama (>=0.4.6)"] +[[package]] +name = "pyjwt" +version = "2.8.0" +description = "JSON Web Token implementation in Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "PyJWT-2.8.0-py3-none-any.whl", hash = "sha256:59127c392cc44c2da5bb3192169a91f429924e17aff6534d70fdc02ab3e04320"}, + {file = "PyJWT-2.8.0.tar.gz", hash = "sha256:57e28d156e3d5c10088e0c68abb90bfac3df82b40a71bd0daa20c65ccd5c23de"}, +] + +[package.extras] +crypto = ["cryptography (>=3.4.0)"] +dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] +docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] +tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] + [[package]] name = "pymdown-extensions" version = "10.8.1" @@ -1694,4 +1711,4 @@ watchmedo = ["PyYAML (>=3.10)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<4.0" -content-hash = "8ae884652ca236023f88f4e4074cb831259b562ae634c62d298c39d1c47c5999" +content-hash = "36b615b9076322778f3195849116d9248c088df80410c1424164b89ee534e93e" diff --git a/arcade/pyproject.toml b/arcade/pyproject.toml index ca761eff..1d2b8437 100644 --- a/arcade/pyproject.toml +++ b/arcade/pyproject.toml @@ -24,6 +24,7 @@ tomlkit = "^0.12.4" requests = "^2.26.0" openai = "^1.36.0" +pyjwt = "^2.8.0" [tool.poetry.group.dev.dependencies] pytest = "^7.2.0" pytest-cov = "^4.0.0" diff --git a/examples/math/arcade_arithmetic/main.py b/examples/math/arcade_arithmetic/main.py index cc1d844f..26c5966e 100644 --- a/examples/math/arcade_arithmetic/main.py +++ b/examples/math/arcade_arithmetic/main.py @@ -2,7 +2,8 @@ from arcade.actor.fastapi.actor import FastAPIActor from fastapi import FastAPI, HTTPException from pydantic import BaseModel from openai import AsyncOpenAI -from arcade_example_nate.tools import arithmetic + +from tools import arithmetic client = AsyncOpenAI(base_url="http://localhost:6901") @@ -30,6 +31,7 @@ async def chat(request: ChatRequest): ], model="gpt-4o-mini", max_tokens=150, + tools=["add", "subtract", "multiply", "divide", "sqrt"], tool_choice="execute", ) chat_completion = raw_response.parse() diff --git a/examples/math/arcade_arithmetic/tools/arithmetic.py b/examples/math/arcade_arithmetic/tools/arithmetic.py index 31c472bb..7f8433b3 100644 --- a/examples/math/arcade_arithmetic/tools/arithmetic.py +++ b/examples/math/arcade_arithmetic/tools/arithmetic.py @@ -14,6 +14,16 @@ def add( return a + b +@tool +def subtract( + a: Annotated[int, "The first number"], b: Annotated[int, "The second number"] +) -> Annotated[int, "The difference of the two numbers"]: + """ + Subtract two numbers + """ + return a - b + + @tool def multiply( a: Annotated[int, "The first number"], b: Annotated[int, "The second number"]