JWT auth for Engine->Actor communication (#11)
Implements: https://app.clickup.com/9014390315/v/dc/8cmtbhb-2714/8cmtbhb-5974 Todo: - [x] Initial demo - [x] Get API key config from `arcade.config` - [x] Get engine URL from config - [x] Final cleanup - [ ] Enforce auth for all requests (waiting for engine)
This commit is contained in:
parent
41cc749a6e
commit
1b67cee667
10 changed files with 239 additions and 43 deletions
0
arcade/arcade/actor/core/__init__.py
Normal file
0
arcade/arcade/actor/core/__init__.py
Normal file
43
arcade/arcade/actor/core/auth.py
Normal file
43
arcade/arcade/actor/core/auth.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
import jwt
|
||||
|
||||
from arcade.core.config import config
|
||||
|
||||
TOKEN_VER = "1" # noqa: S105 Possible hardcoded password assigned (false positive)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenValidationResult:
|
||||
valid: bool
|
||||
api_key: str | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class SigningAlgorithm(str, Enum):
|
||||
HS256 = "HS256"
|
||||
|
||||
|
||||
def validate_token(token: str) -> TokenValidationResult:
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
config.api.secret,
|
||||
algorithms=[SigningAlgorithm.HS256],
|
||||
verify=True,
|
||||
issuer=config.engine_url,
|
||||
audience="actor",
|
||||
)
|
||||
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError) as e:
|
||||
return TokenValidationResult(valid=False, error=str(e))
|
||||
|
||||
api_key = payload.get("api_key")
|
||||
if api_key != config.api.key:
|
||||
return TokenValidationResult(valid=False, error="Invalid API key")
|
||||
|
||||
token_ver = payload.get("ver")
|
||||
if token_ver != TOKEN_VER:
|
||||
return TokenValidationResult(valid=False, error=f"Unknown token version: {token_ver}")
|
||||
|
||||
return TokenValidationResult(valid=True, api_key=api_key)
|
||||
|
|
@ -3,7 +3,7 @@ from typing import Any, Callable
|
|||
|
||||
from fastapi import FastAPI, Request
|
||||
|
||||
from arcade.actor.base import BaseActor
|
||||
from arcade.actor.core.base import BaseActor
|
||||
|
||||
|
||||
class FastAPIActor(BaseActor):
|
||||
|
|
@ -14,13 +14,14 @@ class FastAPIActor(BaseActor):
|
|||
"""
|
||||
super().__init__()
|
||||
self.app = app
|
||||
self.router = FastAPIRouter(app)
|
||||
self.router = FastAPIRouter(app, self)
|
||||
self.register_routes(self.router)
|
||||
|
||||
|
||||
class FastAPIRouter: # TODO create an interface for this
|
||||
def __init__(self, app: FastAPI) -> None:
|
||||
def __init__(self, app: FastAPI, actor: BaseActor) -> None:
|
||||
self.app = app
|
||||
self.actor = actor
|
||||
|
||||
def add_route(self, path: str, handler: Callable, methods: str) -> None:
|
||||
"""
|
||||
|
|
@ -43,7 +44,10 @@ class FastAPIRouter: # TODO create an interface for this
|
|||
Wrap the handler to handle FastAPI-specific request and response.
|
||||
"""
|
||||
|
||||
async def wrapped_handler(request: Request) -> Any:
|
||||
async def wrapped_handler(
|
||||
request: Request,
|
||||
# api_key: str = Depends(get_api_key), # TODO re-enable when Engine supports auth
|
||||
) -> Any:
|
||||
if asyncio.iscoroutinefunction(handler) or (
|
||||
callable(handler) and asyncio.iscoroutinefunction(handler.__call__) # type: ignore[operator]
|
||||
):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,26 @@
|
|||
from typing import cast
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from arcade.actor.core.auth import validate_token
|
||||
|
||||
security = HTTPBearer() # Authorization: Bearer <xxx>
|
||||
|
||||
|
||||
# Dependency function to validate JWT and extract API key
|
||||
# The validator function is provided by the BaseActor class
|
||||
async def get_api_key(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
) -> str:
|
||||
jwt: str = credentials.credentials
|
||||
validation_result = validate_token(jwt)
|
||||
|
||||
if not validation_result.valid:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"Invalid token. Error: {validation_result.error}",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
return cast(str, validation_result.api_key)
|
||||
|
|
@ -1,63 +1,156 @@
|
|||
from pathlib import Path
|
||||
|
||||
import toml
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from arcade.core.env import settings
|
||||
|
||||
|
||||
class ApiConfig(BaseModel):
|
||||
"""
|
||||
Arcade API configuration.
|
||||
"""
|
||||
|
||||
key: str
|
||||
"""
|
||||
Arcade API key.
|
||||
"""
|
||||
secret: str
|
||||
"""
|
||||
Arcade API secret.
|
||||
"""
|
||||
|
||||
|
||||
class UserConfig(BaseModel):
|
||||
"""
|
||||
Arcade user configuration.
|
||||
"""
|
||||
|
||||
email: str | None = None
|
||||
"""
|
||||
User email.
|
||||
"""
|
||||
|
||||
|
||||
class EngineConfig(BaseModel):
|
||||
"""
|
||||
Arcade Engine configuration.
|
||||
"""
|
||||
|
||||
host: str = "localhost"
|
||||
"""
|
||||
Arcade Engine host.
|
||||
"""
|
||||
port: int = 6901
|
||||
"""
|
||||
Arcade Engine port.
|
||||
"""
|
||||
tls: bool = False
|
||||
"""
|
||||
Whether to use TLS for the connection to Arcade Engine.
|
||||
"""
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
api_key: str | None = None
|
||||
user_email: str | None = None
|
||||
engine_key: str | None = None
|
||||
engine_host: str = "localhost"
|
||||
engine_port: str = "6901"
|
||||
"""
|
||||
Configuration for Arcade.
|
||||
"""
|
||||
|
||||
config_dir: Path = settings.WORK_DIR if settings.WORK_DIR else Path.home() / ".arcade"
|
||||
config_file: Path = config_dir / "arcade.toml"
|
||||
api: ApiConfig
|
||||
"""
|
||||
Arcade API configuration.
|
||||
"""
|
||||
user: UserConfig | None = None
|
||||
"""
|
||||
Arcade user configuration.
|
||||
"""
|
||||
engine: EngineConfig | None = None
|
||||
"""
|
||||
Arcade Engine configuration.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_config_dir_path(cls) -> Path:
|
||||
"""
|
||||
Get the path to the Arcade configuration directory.
|
||||
"""
|
||||
return settings.WORK_DIR if settings.WORK_DIR else Path.home() / ".arcade"
|
||||
|
||||
@classmethod
|
||||
def get_config_file_path(cls) -> Path:
|
||||
"""
|
||||
Get the path to the Arcade configuration file.
|
||||
"""
|
||||
return cls.get_config_dir_path() / "arcade.toml"
|
||||
|
||||
@property
|
||||
def arcade_api_key(self) -> str:
|
||||
if not self.api_key:
|
||||
raise ValueError("Arcade API Key not set")
|
||||
return self.api_key
|
||||
def engine_url(self) -> str:
|
||||
"""
|
||||
Get the URL of the Arcade Engine.
|
||||
"""
|
||||
if self.engine is None:
|
||||
raise ValueError("Engine not set")
|
||||
protocol = "https" if self.engine.tls else "http"
|
||||
return f"{protocol}://{self.engine.host}:{self.engine.port}"
|
||||
|
||||
@property
|
||||
def engine_url(self, tls: bool = False) -> str:
|
||||
if tls:
|
||||
return f"https://{self.engine_host}:{self.engine_port}"
|
||||
return f"http://{self.engine_host}:{self.engine_port}"
|
||||
|
||||
@staticmethod
|
||||
def create_config_directory() -> None:
|
||||
@classmethod
|
||||
def ensure_config_dir_exists(cls) -> None:
|
||||
"""
|
||||
Create the configuration directory if it does not exist.
|
||||
"""
|
||||
config_dir = Config.config_dir
|
||||
config_dir = Config.get_config_dir_path()
|
||||
if not config_dir.exists():
|
||||
config_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def save_to_file(self) -> None:
|
||||
"""
|
||||
Save the configuration to the TOML file in the configuration directory.
|
||||
"""
|
||||
self.create_config_directory()
|
||||
config_file_path = self.config_file
|
||||
with config_file_path.open("w") as config_file:
|
||||
toml.dump(self.dict(), config_file)
|
||||
|
||||
@classmethod
|
||||
def load_from_file(cls) -> "Config":
|
||||
"""
|
||||
Load the configuration from the TOML file in the configuration directory.
|
||||
If no configuration file exists, create a new one with default values.
|
||||
"""
|
||||
cls.create_config_directory()
|
||||
config_file_path = cls.config_file
|
||||
if config_file_path.exists():
|
||||
with config_file_path.open("r") as config_file:
|
||||
config_data = toml.load(config_file)
|
||||
return cls(**config_data)
|
||||
return cls()
|
||||
cls.ensure_config_dir_exists()
|
||||
|
||||
config_file_path = cls.get_config_file_path()
|
||||
if not config_file_path.exists():
|
||||
# Create a file using the default configuration
|
||||
default_config = cls.model_construct(
|
||||
api=ApiConfig.model_construct(), engine=EngineConfig()
|
||||
)
|
||||
default_config.save_to_file()
|
||||
|
||||
config_data = toml.loads(config_file_path.read_text())
|
||||
|
||||
try:
|
||||
return cls(**config_data)
|
||||
except ValidationError as e:
|
||||
# Get only the errors with {type:missing} and combine them
|
||||
# into a nicely-formatted string message.
|
||||
# Any other errors without {type:missing} should just be str()ed
|
||||
missing_field_errors = [
|
||||
".".join(map(str, error["loc"]))
|
||||
for error in e.errors()
|
||||
if error["type"] == "missing"
|
||||
]
|
||||
other_errors = [str(error) for error in e.errors() if error["type"] != "missing"]
|
||||
|
||||
missing_field_errors_str = ", ".join(missing_field_errors)
|
||||
other_errors_str = "\n".join(other_errors)
|
||||
|
||||
pretty_str: str = "Invalid Arcade configuration."
|
||||
if missing_field_errors_str:
|
||||
pretty_str += f"\nMissing fields: {missing_field_errors_str}\n"
|
||||
if other_errors_str:
|
||||
pretty_str += f"\nOther errors:\n{other_errors_str}"
|
||||
|
||||
raise ValueError(pretty_str) from e
|
||||
|
||||
def save_to_file(self) -> None:
|
||||
"""
|
||||
Save the configuration to the TOML file in the configuration directory.
|
||||
"""
|
||||
Config.ensure_config_dir_exists()
|
||||
config_file_path = Config.get_config_file_path()
|
||||
config_file_path.write_text(toml.dumps(self.model_dump()))
|
||||
|
||||
|
||||
# Singleton instance of Config
|
||||
|
|
|
|||
19
arcade/poetry.lock
generated
19
arcade/poetry.lock
generated
|
|
@ -1124,6 +1124,23 @@ files = [
|
|||
[package.extras]
|
||||
windows-terminal = ["colorama (>=0.4.6)"]
|
||||
|
||||
[[package]]
|
||||
name = "pyjwt"
|
||||
version = "2.8.0"
|
||||
description = "JSON Web Token implementation in Python"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "PyJWT-2.8.0-py3-none-any.whl", hash = "sha256:59127c392cc44c2da5bb3192169a91f429924e17aff6534d70fdc02ab3e04320"},
|
||||
{file = "PyJWT-2.8.0.tar.gz", hash = "sha256:57e28d156e3d5c10088e0c68abb90bfac3df82b40a71bd0daa20c65ccd5c23de"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
crypto = ["cryptography (>=3.4.0)"]
|
||||
dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
|
||||
docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
|
||||
tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "pymdown-extensions"
|
||||
version = "10.8.1"
|
||||
|
|
@ -1694,4 +1711,4 @@ watchmedo = ["PyYAML (>=3.10)"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<4.0"
|
||||
content-hash = "8ae884652ca236023f88f4e4074cb831259b562ae634c62d298c39d1c47c5999"
|
||||
content-hash = "36b615b9076322778f3195849116d9248c088df80410c1424164b89ee534e93e"
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ tomlkit = "^0.12.4"
|
|||
requests = "^2.26.0"
|
||||
|
||||
openai = "^1.36.0"
|
||||
pyjwt = "^2.8.0"
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest = "^7.2.0"
|
||||
pytest-cov = "^4.0.0"
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ from arcade.actor.fastapi.actor import FastAPIActor
|
|||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from openai import AsyncOpenAI
|
||||
from arcade_example_nate.tools import arithmetic
|
||||
|
||||
from tools import arithmetic
|
||||
|
||||
|
||||
client = AsyncOpenAI(base_url="http://localhost:6901")
|
||||
|
|
@ -30,6 +31,7 @@ async def chat(request: ChatRequest):
|
|||
],
|
||||
model="gpt-4o-mini",
|
||||
max_tokens=150,
|
||||
tools=["add", "subtract", "multiply", "divide", "sqrt"],
|
||||
tool_choice="execute",
|
||||
)
|
||||
chat_completion = raw_response.parse()
|
||||
|
|
|
|||
|
|
@ -14,6 +14,16 @@ def add(
|
|||
return a + b
|
||||
|
||||
|
||||
@tool
|
||||
def subtract(
|
||||
a: Annotated[int, "The first number"], b: Annotated[int, "The second number"]
|
||||
) -> Annotated[int, "The difference of the two numbers"]:
|
||||
"""
|
||||
Subtract two numbers
|
||||
"""
|
||||
return a - b
|
||||
|
||||
|
||||
@tool
|
||||
def multiply(
|
||||
a: Annotated[int, "The first number"], b: Annotated[int, "The second number"]
|
||||
|
|
|
|||
Loading…
Reference in a new issue