JWT auth for Engine->Actor communication (#11)

Implements:
https://app.clickup.com/9014390315/v/dc/8cmtbhb-2714/8cmtbhb-5974

Todo:
- [x] Initial demo
- [x] Get API key config from `arcade.config`
- [x] Get engine URL from config
- [x] Final cleanup
- [ ] Enforce auth for all requests (waiting for engine)
This commit is contained in:
Nate Barbettini 2024-08-01 09:14:37 -07:00 committed by GitHub
parent 41cc749a6e
commit 1b67cee667
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 239 additions and 43 deletions

View file

View file

@ -0,0 +1,43 @@
from dataclasses import dataclass
from enum import Enum
import jwt
from arcade.core.config import config
TOKEN_VER = "1" # noqa: S105 Possible hardcoded password assigned (false positive)
@dataclass
class TokenValidationResult:
valid: bool
api_key: str | None = None
error: str | None = None
class SigningAlgorithm(str, Enum):
HS256 = "HS256"
def validate_token(token: str) -> TokenValidationResult:
try:
payload = jwt.decode(
token,
config.api.secret,
algorithms=[SigningAlgorithm.HS256],
verify=True,
issuer=config.engine_url,
audience="actor",
)
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError) as e:
return TokenValidationResult(valid=False, error=str(e))
api_key = payload.get("api_key")
if api_key != config.api.key:
return TokenValidationResult(valid=False, error="Invalid API key")
token_ver = payload.get("ver")
if token_ver != TOKEN_VER:
return TokenValidationResult(valid=False, error=f"Unknown token version: {token_ver}")
return TokenValidationResult(valid=True, api_key=api_key)

View file

@ -3,7 +3,7 @@ from typing import Any, Callable
from fastapi import FastAPI, Request
from arcade.actor.base import BaseActor
from arcade.actor.core.base import BaseActor
class FastAPIActor(BaseActor):
@ -14,13 +14,14 @@ class FastAPIActor(BaseActor):
"""
super().__init__()
self.app = app
self.router = FastAPIRouter(app)
self.router = FastAPIRouter(app, self)
self.register_routes(self.router)
class FastAPIRouter: # TODO create an interface for this
def __init__(self, app: FastAPI) -> None:
def __init__(self, app: FastAPI, actor: BaseActor) -> None:
self.app = app
self.actor = actor
def add_route(self, path: str, handler: Callable, methods: str) -> None:
"""
@ -43,7 +44,10 @@ class FastAPIRouter: # TODO create an interface for this
Wrap the handler to handle FastAPI-specific request and response.
"""
async def wrapped_handler(request: Request) -> Any:
async def wrapped_handler(
request: Request,
# api_key: str = Depends(get_api_key), # TODO re-enable when Engine supports auth
) -> Any:
if asyncio.iscoroutinefunction(handler) or (
callable(handler) and asyncio.iscoroutinefunction(handler.__call__) # type: ignore[operator]
):

View file

@ -0,0 +1,26 @@
from typing import cast
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from arcade.actor.core.auth import validate_token
security = HTTPBearer() # Authorization: Bearer <xxx>
# Dependency function to validate JWT and extract API key
# The validator function is provided by the BaseActor class
async def get_api_key(
credentials: HTTPAuthorizationCredentials = Depends(security),
) -> str:
jwt: str = credentials.credentials
validation_result = validate_token(jwt)
if not validation_result.valid:
raise HTTPException(
status_code=401,
detail=f"Invalid token. Error: {validation_result.error}",
headers={"WWW-Authenticate": "Bearer"},
)
return cast(str, validation_result.api_key)

View file

@ -1,63 +1,156 @@
from pathlib import Path
import toml
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError
from arcade.core.env import settings
class ApiConfig(BaseModel):
"""
Arcade API configuration.
"""
key: str
"""
Arcade API key.
"""
secret: str
"""
Arcade API secret.
"""
class UserConfig(BaseModel):
"""
Arcade user configuration.
"""
email: str | None = None
"""
User email.
"""
class EngineConfig(BaseModel):
"""
Arcade Engine configuration.
"""
host: str = "localhost"
"""
Arcade Engine host.
"""
port: int = 6901
"""
Arcade Engine port.
"""
tls: bool = False
"""
Whether to use TLS for the connection to Arcade Engine.
"""
class Config(BaseModel):
api_key: str | None = None
user_email: str | None = None
engine_key: str | None = None
engine_host: str = "localhost"
engine_port: str = "6901"
"""
Configuration for Arcade.
"""
config_dir: Path = settings.WORK_DIR if settings.WORK_DIR else Path.home() / ".arcade"
config_file: Path = config_dir / "arcade.toml"
api: ApiConfig
"""
Arcade API configuration.
"""
user: UserConfig | None = None
"""
Arcade user configuration.
"""
engine: EngineConfig | None = None
"""
Arcade Engine configuration.
"""
@classmethod
def get_config_dir_path(cls) -> Path:
"""
Get the path to the Arcade configuration directory.
"""
return settings.WORK_DIR if settings.WORK_DIR else Path.home() / ".arcade"
@classmethod
def get_config_file_path(cls) -> Path:
"""
Get the path to the Arcade configuration file.
"""
return cls.get_config_dir_path() / "arcade.toml"
@property
def arcade_api_key(self) -> str:
if not self.api_key:
raise ValueError("Arcade API Key not set")
return self.api_key
def engine_url(self) -> str:
"""
Get the URL of the Arcade Engine.
"""
if self.engine is None:
raise ValueError("Engine not set")
protocol = "https" if self.engine.tls else "http"
return f"{protocol}://{self.engine.host}:{self.engine.port}"
@property
def engine_url(self, tls: bool = False) -> str:
if tls:
return f"https://{self.engine_host}:{self.engine_port}"
return f"http://{self.engine_host}:{self.engine_port}"
@staticmethod
def create_config_directory() -> None:
@classmethod
def ensure_config_dir_exists(cls) -> None:
"""
Create the configuration directory if it does not exist.
"""
config_dir = Config.config_dir
config_dir = Config.get_config_dir_path()
if not config_dir.exists():
config_dir.mkdir(parents=True, exist_ok=True)
def save_to_file(self) -> None:
"""
Save the configuration to the TOML file in the configuration directory.
"""
self.create_config_directory()
config_file_path = self.config_file
with config_file_path.open("w") as config_file:
toml.dump(self.dict(), config_file)
@classmethod
def load_from_file(cls) -> "Config":
"""
Load the configuration from the TOML file in the configuration directory.
If no configuration file exists, create a new one with default values.
"""
cls.create_config_directory()
config_file_path = cls.config_file
if config_file_path.exists():
with config_file_path.open("r") as config_file:
config_data = toml.load(config_file)
return cls(**config_data)
return cls()
cls.ensure_config_dir_exists()
config_file_path = cls.get_config_file_path()
if not config_file_path.exists():
# Create a file using the default configuration
default_config = cls.model_construct(
api=ApiConfig.model_construct(), engine=EngineConfig()
)
default_config.save_to_file()
config_data = toml.loads(config_file_path.read_text())
try:
return cls(**config_data)
except ValidationError as e:
# Get only the errors with {type:missing} and combine them
# into a nicely-formatted string message.
# Any other errors without {type:missing} should just be str()ed
missing_field_errors = [
".".join(map(str, error["loc"]))
for error in e.errors()
if error["type"] == "missing"
]
other_errors = [str(error) for error in e.errors() if error["type"] != "missing"]
missing_field_errors_str = ", ".join(missing_field_errors)
other_errors_str = "\n".join(other_errors)
pretty_str: str = "Invalid Arcade configuration."
if missing_field_errors_str:
pretty_str += f"\nMissing fields: {missing_field_errors_str}\n"
if other_errors_str:
pretty_str += f"\nOther errors:\n{other_errors_str}"
raise ValueError(pretty_str) from e
def save_to_file(self) -> None:
"""
Save the configuration to the TOML file in the configuration directory.
"""
Config.ensure_config_dir_exists()
config_file_path = Config.get_config_file_path()
config_file_path.write_text(toml.dumps(self.model_dump()))
# Singleton instance of Config

19
arcade/poetry.lock generated
View file

@ -1124,6 +1124,23 @@ files = [
[package.extras]
windows-terminal = ["colorama (>=0.4.6)"]
[[package]]
name = "pyjwt"
version = "2.8.0"
description = "JSON Web Token implementation in Python"
optional = false
python-versions = ">=3.7"
files = [
{file = "PyJWT-2.8.0-py3-none-any.whl", hash = "sha256:59127c392cc44c2da5bb3192169a91f429924e17aff6534d70fdc02ab3e04320"},
{file = "PyJWT-2.8.0.tar.gz", hash = "sha256:57e28d156e3d5c10088e0c68abb90bfac3df82b40a71bd0daa20c65ccd5c23de"},
]
[package.extras]
crypto = ["cryptography (>=3.4.0)"]
dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"]
[[package]]
name = "pymdown-extensions"
version = "10.8.1"
@ -1694,4 +1711,4 @@ watchmedo = ["PyYAML (>=3.10)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<4.0"
content-hash = "8ae884652ca236023f88f4e4074cb831259b562ae634c62d298c39d1c47c5999"
content-hash = "36b615b9076322778f3195849116d9248c088df80410c1424164b89ee534e93e"

View file

@ -24,6 +24,7 @@ tomlkit = "^0.12.4"
requests = "^2.26.0"
openai = "^1.36.0"
pyjwt = "^2.8.0"
[tool.poetry.group.dev.dependencies]
pytest = "^7.2.0"
pytest-cov = "^4.0.0"

View file

@ -2,7 +2,8 @@ from arcade.actor.fastapi.actor import FastAPIActor
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from openai import AsyncOpenAI
from arcade_example_nate.tools import arithmetic
from tools import arithmetic
client = AsyncOpenAI(base_url="http://localhost:6901")
@ -30,6 +31,7 @@ async def chat(request: ChatRequest):
],
model="gpt-4o-mini",
max_tokens=150,
tools=["add", "subtract", "multiply", "divide", "sqrt"],
tool_choice="execute",
)
chat_completion = raw_response.parse()

View file

@ -14,6 +14,16 @@ def add(
return a + b
@tool
def subtract(
a: Annotated[int, "The first number"], b: Annotated[int, "The second number"]
) -> Annotated[int, "The difference of the two numbers"]:
"""
Subtract two numbers
"""
return a - b
@tool
def multiply(
a: Annotated[int, "The first number"], b: Annotated[int, "The second number"]