- Add Prometheus metrics for marketplace API throughput and error rates with new dashboard panels - Implement confidential transaction models with encryption support and access control - Add key management system with registration, rotation, and audit logging - Create services and registry routers for service discovery and management - Integrate ZK proof generation for privacy-preserving receipts - Add metrics instru
208 lines
6.5 KiB
Python
208 lines
6.5 KiB
Python
"""
|
|
Authentication handlers for AITBC Enterprise Connectors
|
|
"""
|
|
|
|
import base64
|
|
import hashlib
|
|
import hmac
|
|
import json
|
|
import time
|
|
from abc import ABC, abstractmethod
|
|
from typing import Dict, Any, Optional
|
|
from datetime import datetime, timedelta
|
|
|
|
from .core import ConnectorConfig
|
|
from .exceptions import AuthenticationError
|
|
|
|
|
|
class AuthHandler(ABC):
    """Common interface implemented by every authentication handler.

    Concrete handlers turn their configured credentials into the HTTP
    headers that must accompany an outgoing request.
    """

    @abstractmethod
    async def get_headers(self) -> Dict[str, str]:
        """Build the authentication headers for a request.

        Returns:
            A mapping of header name to header value. Subclasses must
            override this coroutine; the base implementation has no body.
        """
|
|
|
|
|
|
class BearerAuthHandler(AuthHandler):
    """Bearer token authentication.

    The token is read from ``config.api_key`` when the handler is created.
    """

    def __init__(self, config: ConnectorConfig):
        """Store the bearer token taken from ``config.api_key``."""
        self.api_key = config.api_key

    async def get_headers(self) -> Dict[str, str]:
        """Return the ``Authorization`` header carrying the bearer token.

        Returns:
            ``{"Authorization": "Bearer <api_key>"}``.

        Raises:
            AuthenticationError: If no API key is configured.
        """
        # Fail loudly instead of sending the literal text "Bearer None";
        # this mirrors how the Basic and HMAC handlers treat missing
        # credentials.
        if not self.api_key:
            raise AuthenticationError("API key required for Bearer auth")

        return {"Authorization": f"Bearer {self.api_key}"}
|
|
|
|
|
|
class BasicAuthHandler(AuthHandler):
    """HTTP Basic authentication.

    The user name and password are read from the ``username`` and
    ``password`` entries of ``config.auth_config``.
    """

    def __init__(self, config: ConnectorConfig):
        """Capture the Basic-auth credentials from ``config.auth_config``."""
        auth_settings = config.auth_config
        self.username = auth_settings.get("username")
        self.password = auth_settings.get("password")

    async def get_headers(self) -> Dict[str, str]:
        """Return the ``Authorization: Basic ...`` header.

        Raises:
            AuthenticationError: If the user name or the password is
                missing or empty.
        """
        if not (self.username and self.password):
            raise AuthenticationError("Username and password required for Basic auth")

        # "user:password", base64-encoded, is the Basic credential token.
        raw_pair = f"{self.username}:{self.password}".encode()
        token = base64.b64encode(raw_pair).decode()
        return {"Authorization": f"Basic {token}"}
|
|
|
|
|
|
class APIKeyAuthHandler(AuthHandler):
    """API key authentication sent in a configurable request header.

    The key comes from ``config.api_key``; the header name comes from the
    ``header_name`` entry of ``config.auth_config`` and defaults to
    ``X-API-Key``.
    """

    def __init__(self, config: ConnectorConfig):
        """Store the API key and the name of the header that carries it."""
        self.api_key = config.api_key
        # ``or`` also covers an explicit ``header_name: None`` / empty value,
        # which would otherwise become an unusable header name.
        self.header_name = config.auth_config.get("header_name") or "X-API-Key"

    async def get_headers(self) -> Dict[str, str]:
        """Return the single header that carries the API key.

        Returns:
            ``{<header_name>: <api_key>}``.

        Raises:
            AuthenticationError: If no API key is configured.
        """
        # A missing key would otherwise be emitted as a ``None`` header
        # value; raise the same error type the other handlers use for
        # missing credentials.
        if not self.api_key:
            raise AuthenticationError("API key required for API key auth")

        return {self.header_name: self.api_key}
|
|
|
|
|
|
class HMACAuthHandler(AuthHandler):
    """HMAC signature authentication.

    Each request is signed over ``"<timestamp>:<api_key>"`` using the shared
    ``secret`` from ``config.auth_config``. The hash is selected by the
    ``algorithm`` entry of ``config.auth_config`` and defaults to ``sha256``.
    """

    def __init__(self, config: ConnectorConfig):
        """Store the API key, the shared secret and the hash algorithm name."""
        self.api_key = config.api_key
        self.secret = config.auth_config.get("secret")
        self.algorithm = config.auth_config.get("algorithm", "sha256")

    async def get_headers(self) -> Dict[str, str]:
        """Return the API key, timestamp and signature headers.

        Returns:
            A dict with ``X-API-Key``, ``X-Timestamp`` (Unix time in whole
            seconds, as a string) and ``X-Signature`` (hex HMAC digest).

        Raises:
            AuthenticationError: If the secret or the API key is missing, or
                if ``algorithm`` does not name a hash constructor that
                ``hashlib`` guarantees.
        """
        if not self.secret:
            raise AuthenticationError("Secret required for HMAC auth")
        # Without a key the signed message would contain the text "None" and
        # the X-API-Key header value would not be a string.
        if not self.api_key:
            raise AuthenticationError("API key required for HMAC auth")

        # The algorithm name comes from configuration. Check it against the
        # guaranteed hashlib constructors so that a typo surfaces as an
        # AuthenticationError, and so that arbitrary hashlib attributes
        # (e.g. "new") can never be picked up by the getattr below.
        if (
            not isinstance(self.algorithm, str)
            or self.algorithm not in hashlib.algorithms_guaranteed
        ):
            raise AuthenticationError(
                f"Unsupported HMAC algorithm: {self.algorithm}"
            )
        digest_constructor = getattr(hashlib, self.algorithm)

        timestamp = str(int(time.time()))
        message = f"{timestamp}:{self.api_key}"

        signature = hmac.new(
            self.secret.encode(),
            message.encode(),
            digest_constructor,
        ).hexdigest()

        return {
            "X-API-Key": self.api_key,
            "X-Timestamp": timestamp,
            "X-Signature": signature,
        }
|
|
|
|
|
|
class OAuth2Handler(AuthHandler):
    """OAuth 2.0 authentication using the ``client_credentials`` grant.

    ``client_id``, ``client_secret``, ``token_url`` and the optional
    ``scope`` are read from ``config.auth_config``. The access token is
    cached on the instance and requested again shortly before it expires.
    """

    # Preferred head start for renewing a token before it actually expires.
    _REFRESH_MARGIN = timedelta(minutes=5)
    # Lifetime assumed when the token response has no usable ``expires_in``.
    _DEFAULT_LIFETIME_SECONDS = 3600

    def __init__(self, config: ConnectorConfig):
        """Read the OAuth client settings and start with an empty token cache."""
        self.client_id = config.auth_config.get("client_id")
        self.client_secret = config.auth_config.get("client_secret")
        self.token_url = config.auth_config.get("token_url")
        self.scope = config.auth_config.get("scope", "")

        self._access_token = None
        self._refresh_token = None
        self._expires_at = None  # naive UTC datetime once a token is held
        self._refresh_margin = self._REFRESH_MARGIN

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Same value ``datetime.utcnow()`` produced, without calling the
        deprecated method; ``_expires_at`` therefore stays naive UTC.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None)

    async def get_headers(self) -> Dict[str, str]:
        """Return the ``Authorization`` header, fetching a token if needed.

        Raises:
            AuthenticationError: If a new token is required and cannot be
                obtained (see ``_refresh_access_token``).
        """
        if not self._is_token_valid():
            await self._refresh_access_token()

        return {"Authorization": f"Bearer {self._access_token}"}

    def _is_token_valid(self) -> bool:
        """Return True while a cached token is outside its refresh margin."""
        if not self._access_token or not self._expires_at:
            return False

        # Renew ahead of the real expiry so a token does not lapse in flight.
        return self._utcnow() < (self._expires_at - self._refresh_margin)

    async def _refresh_access_token(self):
        """Request a new access token from ``token_url`` and cache it.

        Raises:
            AuthenticationError: If the client settings are incomplete, the
                token endpoint does not answer with HTTP 200, or the
                response carries no ``access_token``.
        """
        import aiohttp

        # Reject incomplete settings up front instead of letting them fail
        # with an unrelated error inside the HTTP client.
        if not (self.client_id and self.client_secret and self.token_url):
            raise AuthenticationError(
                "client_id, client_secret and token_url required for OAuth 2.0 auth"
            )

        data = {
            "grant_type": "client_credentials",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "scope": self.scope,
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(self.token_url, data=data) as response:
                if response.status != 200:
                    raise AuthenticationError(
                        f"OAuth token request failed: {response.status}"
                    )

                token_data = await response.json()

        access_token = (
            token_data.get("access_token") if isinstance(token_data, dict) else None
        )
        if not access_token:
            raise AuthenticationError(
                "OAuth token response did not contain an access_token"
            )

        # ``expires_in`` may arrive as a number, a numeric string, or be
        # absent/null; anything unusable falls back to the default lifetime.
        try:
            lifetime_seconds = float(
                token_data.get("expires_in", self._DEFAULT_LIFETIME_SECONDS)
            )
        except (TypeError, ValueError):
            lifetime_seconds = float(self._DEFAULT_LIFETIME_SECONDS)
        lifetime = timedelta(seconds=lifetime_seconds)

        self._access_token = access_token
        self._refresh_token = token_data.get("refresh_token")
        self._expires_at = self._utcnow() + lifetime
        # With a fixed 5-minute margin, a token living 5 minutes or less
        # would never count as valid and every call would fetch a new one.
        # Cap the margin at half the token's lifetime.
        self._refresh_margin = min(self._REFRESH_MARGIN, lifetime / 2)
|
|
|
|
|
|
class CertificateAuthHandler(AuthHandler):
    """Client-certificate authentication.

    ``cert_path``, ``key_path`` and ``passphrase`` are read from
    ``config.auth_config``. Authentication is carried by the SSL context,
    not by request headers.
    """

    def __init__(self, config: ConnectorConfig):
        """Remember where the certificate and key live and the key passphrase."""
        settings = config.auth_config
        self.cert_path = settings.get("cert_path")
        self.key_path = settings.get("key_path")
        self.passphrase = settings.get("passphrase")

    async def get_headers(self) -> Dict[str, str]:
        """Return no headers; the client certificate does the authenticating."""
        return {}

    def get_ssl_context(self):
        """Build the SSL context to use for certificate authentication.

        Returns:
            A default ``ssl.SSLContext``. The client certificate chain is
            loaded into it only when both ``cert_path`` and ``key_path``
            are set.
        """
        import ssl

        ssl_ctx = ssl.create_default_context()

        # Without both paths the plain default context is handed back.
        if not (self.cert_path and self.key_path):
            return ssl_ctx

        ssl_ctx.load_cert_chain(
            certfile=self.cert_path,
            keyfile=self.key_path,
            password=self.passphrase,
        )
        return ssl_ctx
|
|
|
|
|
|
class AuthHandlerFactory:
    """Factory for creating authentication handlers."""

    @staticmethod
    def create(config: ConnectorConfig) -> AuthHandler:
        """Create the auth handler selected by ``config.auth_type``.

        The type name is matched case-insensitively against ``bearer``,
        ``basic``, ``api_key``, ``hmac``, ``oauth2`` and ``certificate``.

        Args:
            config: Connector configuration; it is passed unchanged to the
                selected handler's constructor.

        Returns:
            A new handler instance for the requested auth type.

        Raises:
            AuthenticationError: If ``config.auth_type`` is not a string or
                does not name a supported auth type.
        """
        raw_type = config.auth_type
        # A missing or non-string auth type is a configuration problem;
        # report it the same way as an unknown name rather than failing
        # with AttributeError on ``.lower()``.
        if not isinstance(raw_type, str):
            raise AuthenticationError(f"Unsupported auth type: {raw_type}")

        auth_type = raw_type.lower()

        # Built per call so the handler classes are resolved by name at call
        # time, exactly as the previous if/elif chain did.
        handlers = {
            "bearer": BearerAuthHandler,
            "basic": BasicAuthHandler,
            "api_key": APIKeyAuthHandler,
            "hmac": HMACAuthHandler,
            "oauth2": OAuth2Handler,
            "certificate": CertificateAuthHandler,
        }

        handler_cls = handlers.get(auth_type)
        if handler_cls is None:
            raise AuthenticationError(f"Unsupported auth type: {auth_type}")

        return handler_cls(config)
|