feat: add marketplace metrics, privacy features, and service registry endpoints
- Add Prometheus metrics for marketplace API throughput and error rates with new dashboard panels - Implement confidential transaction models with encryption support and access control - Add key management system with registration, rotation, and audit logging - Create services and registry routers for service discovery and management - Integrate ZK proof generation for privacy-preserving receipts - Add metrics instrumentation
This commit is contained in:
362
apps/coordinator-api/src/app/services/access_control.py
Normal file
362
apps/coordinator-api/src/app/services/access_control.py
Normal file
@@ -0,0 +1,362 @@
|
||||
"""
|
||||
Access control service for confidential transactions
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Set, Any
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
import json
|
||||
import re
|
||||
|
||||
from ..models import ConfidentialAccessRequest, ConfidentialAccessLog
|
||||
from ..settings import settings
|
||||
from ..logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AccessPurpose(str, Enum):
    """Standard access purposes a requester may declare when asking for
    confidential transaction data. String-valued so members compare equal
    to their plain-string form (e.g. ``request.purpose == "audit"``)."""
    SETTLEMENT = "settlement"  # settling a completed transaction
    AUDIT = "audit"            # formal audit review
    COMPLIANCE = "compliance"  # regulatory compliance check
    DISPUTE = "dispute"        # dispute resolution
    SUPPORT = "support"        # customer-support investigation
    REPORTING = "reporting"    # aggregate reporting
|
||||
|
||||
|
||||
class AccessLevel(str, Enum):
    """Access levels for confidential data, ordered from least to most
    privileged. Stored inside policy dicts under ``"access_level"``."""
    READ = "read"    # view-only access
    WRITE = "write"  # may modify data
    ADMIN = "admin"  # full control (coordinator policies use this)
|
||||
|
||||
|
||||
class ParticipantRole(str, Enum):
    """Roles for transaction participants. Used as keys of
    ``PolicyStore._role_permissions`` and derived from participant-id
    prefixes in ``AccessController._get_participant_info``."""
    CLIENT = "client"            # party that requested the work
    MINER = "miner"              # party that performed the work
    COORDINATOR = "coordinator"  # platform operator (full access)
    AUDITOR = "auditor"          # internal/external auditor
    REGULATOR = "regulator"      # external regulatory body
|
||||
|
||||
|
||||
class PolicyStore:
    """In-memory storage for access control policies.

    Holds two maps: policy-id -> policy definition dict, and
    ParticipantRole -> set of permission strings. A set of built-in
    default policies is seeded at construction time.
    """

    def __init__(self):
        # Static role -> permission-string grants.
        self._role_permissions: Dict[ParticipantRole, Set[str]] = {
            ParticipantRole.CLIENT: {"read_own", "settlement_own"},
            ParticipantRole.MINER: {"read_assigned", "settlement_assigned"},
            ParticipantRole.COORDINATOR: {"read_all", "admin_all"},
            ParticipantRole.AUDITOR: {"read_all", "audit_all"},
            ParticipantRole.REGULATOR: {"read_all", "compliance_all"},
        }
        # policy_id -> policy definition dict.
        self._policies: Dict[str, Dict] = {}
        self._load_default_policies()

    def _load_default_policies(self):
        """Seed the store with the built-in default policies."""
        defaults = {
            # Client can access their own transactions.
            "client_own_data": {
                "participants": ["client"],
                "conditions": {
                    "transaction_client_id": "{requester}",
                    "purpose": ["settlement", "dispute", "support"],
                },
                "access_level": AccessLevel.READ,
                "time_restrictions": None,
            },
            # Miner can access transactions assigned to them.
            "miner_assigned_data": {
                "participants": ["miner"],
                "conditions": {
                    "transaction_miner_id": "{requester}",
                    "purpose": ["settlement"],
                },
                "access_level": AccessLevel.READ,
                "time_restrictions": None,
            },
            # Coordinator has unconditional full access.
            "coordinator_full": {
                "participants": ["coordinator"],
                "conditions": {},
                "access_level": AccessLevel.ADMIN,
                "time_restrictions": None,
            },
            # Auditor/regulator read access for audit & compliance work.
            "auditor_compliance": {
                "participants": ["auditor", "regulator"],
                "conditions": {
                    "purpose": ["audit", "compliance"],
                },
                "access_level": AccessLevel.READ,
                "time_restrictions": {
                    "business_hours_only": True,
                    "retention_days": 2555,  # 7 years
                },
            },
        }
        self._policies.update(defaults)

    def get_policy(self, policy_id: str) -> Optional[Dict]:
        """Return the policy definition for *policy_id*, or None."""
        return self._policies.get(policy_id)

    def list_policies(self) -> List[str]:
        """Return all known policy IDs."""
        return list(self._policies)

    def add_policy(self, policy_id: str, policy: Dict):
        """Register (or overwrite) a policy definition under *policy_id*."""
        self._policies[policy_id] = policy

    def get_role_permissions(self, role: ParticipantRole) -> Set[str]:
        """Return the permission strings granted to *role* (empty set if unknown)."""
        return self._role_permissions.get(role, set())
|
||||
|
||||
|
||||
class AccessController:
    """Controls access to confidential transaction data.

    Decisions combine role permissions, per-transaction participant lists,
    time-of-day restrictions, and role-specific retention windows. Results
    are cached for a short TTL. Participant and transaction lookups are
    mock implementations (see ``_get_participant_info`` / ``_get_transaction``)
    pending real database integration.
    """

    def __init__(self, policy_store: PolicyStore):
        self.policy_store = policy_store
        # Decision cache: "requester:txn:purpose" -> {"allowed", "timestamp"}.
        self._access_cache: Dict[str, Dict] = {}
        # Cached decisions are reused for up to 5 minutes.
        self._cache_ttl = timedelta(minutes=5)

    def verify_access(self, request: ConfidentialAccessRequest) -> bool:
        """Verify if requester has access rights.

        Fails closed: unknown participants, missing transactions, and any
        unexpected error during evaluation all result in denial (False).
        """
        try:
            # Check cache first
            cache_key = self._get_cache_key(request)
            cached_result = self._get_cached_result(cache_key)
            if cached_result is not None:
                return cached_result["allowed"]

            # Get participant info
            participant_info = self._get_participant_info(request.requester)
            if not participant_info:
                logger.warning(f"Unknown participant: {request.requester}")
                return False

            # Check role-based permissions
            role = participant_info.get("role")
            if not self._check_role_permissions(role, request):
                return False

            # Check transaction-specific policies
            transaction = self._get_transaction(request.transaction_id)
            if not transaction:
                logger.warning(f"Transaction not found: {request.transaction_id}")
                return False

            # Apply access policies
            allowed = self._apply_policies(request, participant_info, transaction)

            # Cache result (note: negative role/lookup failures above are
            # intentionally NOT cached — only full policy evaluations are).
            self._cache_result(cache_key, allowed)

            return allowed

        except Exception as e:
            # Deny on any unexpected error rather than leaking data.
            logger.error(f"Access verification failed: {e}")
            return False

    def _check_role_permissions(self, role: str, request: ConfidentialAccessRequest) -> bool:
        """Check if role grants access for the request's declared purpose."""
        try:
            participant_role = ParticipantRole(role.lower())
            permissions = self.policy_store.get_role_permissions(participant_role)

            # Purpose-specific permission checks; each purpose accepts either
            # its exact permission or the broader role-wide variant.
            if request.purpose == "settlement":
                return "settlement" in permissions or "settlement_own" in permissions
            elif request.purpose == "audit":
                return "audit" in permissions or "audit_all" in permissions
            elif request.purpose == "compliance":
                return "compliance" in permissions or "compliance_all" in permissions
            elif request.purpose == "dispute":
                return "dispute" in permissions or "read_own" in permissions
            elif request.purpose == "support":
                return "support" in permissions or "read_all" in permissions
            else:
                # Unknown/other purposes fall back to generic read rights.
                return "read" in permissions or "read_all" in permissions

        except ValueError:
            # role string did not map to a ParticipantRole member.
            logger.warning(f"Invalid role: {role}")
            return False

    def _apply_policies(
        self,
        request: ConfidentialAccessRequest,
        participant_info: Dict,
        transaction: Dict
    ) -> bool:
        """Apply access policies to request.

        Order: participant membership -> time restrictions -> auditor
        business-hours check -> retention window. All must pass.
        """
        # Check if participant is in transaction participants list
        if request.requester not in transaction.get("participants", []):
            # Only coordinators, auditors, and regulators can access non-participant data
            role = participant_info.get("role", "").lower()
            if role not in ["coordinator", "auditor", "regulator"]:
                return False

        # Check time-based restrictions
        if not self._check_time_restrictions(request.purpose, participant_info.get("role")):
            return False

        # Check business hours for auditors (applies regardless of purpose,
        # on top of the purpose-based check above).
        if participant_info.get("role") == "auditor" and not self._is_business_hours():
            return False

        # Check retention periods
        if not self._check_retention_period(transaction, participant_info.get("role")):
            return False

        return True

    def _check_time_restrictions(self, purpose: str, role: Optional[str]) -> bool:
        """Check time-based access restrictions for the given purpose/role."""
        # No restrictions for settlement and dispute
        if purpose in ["settlement", "dispute"]:
            return True

        # Audit and compliance only during business hours for non-coordinators
        if purpose in ["audit", "compliance"] and role not in ["coordinator"]:
            return self._is_business_hours()

        return True

    def _is_business_hours(self) -> bool:
        """Check if current time is within business hours (Mon-Fri 9-17 UTC)."""
        now = datetime.utcnow()

        # Monday-Friday, 9 AM - 5 PM UTC
        if now.weekday() >= 5:  # Weekend
            return False

        if 9 <= now.hour < 17:
            return True

        return False

    def _check_retention_period(self, transaction: Dict, role: Optional[str]) -> bool:
        """Check if data is within the retention period allowed for *role*."""
        # NOTE(review): a missing timestamp defaults to "now", which makes the
        # transaction always pass retention — confirm this fallback is intended.
        transaction_date = transaction.get("timestamp", datetime.utcnow())

        # Different retention periods for different roles
        if role == "regulator":
            retention_days = 2555  # 7 years
        elif role == "auditor":
            retention_days = 1825  # 5 years
        elif role == "coordinator":
            retention_days = 3650  # 10 years
        else:
            retention_days = 365  # 1 year

        expiry_date = transaction_date + timedelta(days=retention_days)

        return datetime.utcnow() <= expiry_date

    def _get_participant_info(self, participant_id: str) -> Optional[Dict]:
        """Get participant information.

        Mock implementation: the role is inferred from the id prefix.
        """
        # In production, query from database
        # For now, return mock data
        if participant_id.startswith("client-"):
            return {"id": participant_id, "role": "client", "active": True}
        elif participant_id.startswith("miner-"):
            return {"id": participant_id, "role": "miner", "active": True}
        elif participant_id.startswith("coordinator-"):
            return {"id": participant_id, "role": "coordinator", "active": True}
        elif participant_id.startswith("auditor-"):
            return {"id": participant_id, "role": "auditor", "active": True}
        elif participant_id.startswith("regulator-"):
            return {"id": participant_id, "role": "regulator", "active": True}
        else:
            return None

    def _get_transaction(self, transaction_id: str) -> Optional[Dict]:
        """Get transaction information (mock implementation)."""
        # In production, query from database
        # For now, return mock data
        return {
            "transaction_id": transaction_id,
            "participants": ["client-456", "miner-789"],
            "timestamp": datetime.utcnow(),
            "status": "completed"
        }

    def _get_cache_key(self, request: ConfidentialAccessRequest) -> str:
        """Generate cache key for access request.

        NOTE(review): ids containing ':' would break the split in
        revoke_access — confirm ids never contain colons.
        """
        return f"{request.requester}:{request.transaction_id}:{request.purpose}"

    def _get_cached_result(self, cache_key: str) -> Optional[Dict]:
        """Get cached access result; expired entries are evicted lazily."""
        if cache_key in self._access_cache:
            cached = self._access_cache[cache_key]
            if datetime.utcnow() - cached["timestamp"] < self._cache_ttl:
                return cached
            else:
                del self._access_cache[cache_key]
        return None

    def _cache_result(self, cache_key: str, allowed: bool):
        """Cache access result with the current timestamp."""
        self._access_cache[cache_key] = {
            "allowed": allowed,
            "timestamp": datetime.utcnow()
        }

    def create_access_policy(
        self,
        name: str,
        participants: List[str],
        conditions: Dict[str, Any],
        access_level: AccessLevel
    ) -> str:
        """Create a new access policy and register it in the policy store.

        Returns the generated policy id. Note: *name* is currently not
        stored in the policy record.
        """
        policy_id = f"policy_{datetime.utcnow().timestamp()}"

        policy = {
            "participants": participants,
            "conditions": conditions,
            "access_level": access_level,
            "time_restrictions": conditions.get("time_restrictions"),
            "created_at": datetime.utcnow().isoformat()
        }

        self.policy_store.add_policy(policy_id, policy)
        logger.info(f"Created access policy: {policy_id}")

        return policy_id

    def revoke_access(self, participant_id: str, transaction_id: Optional[str] = None):
        """Revoke access for participant.

        Currently only invalidates cached decisions for the participant
        (optionally limited to one transaction); persistent revocation is
        left for the production database integration.
        """
        # In production, update database
        # For now, clear cache
        keys_to_remove = []
        for key in self._access_cache:
            if key.startswith(f"{participant_id}:"):
                if transaction_id is None or key.split(":")[1] == transaction_id:
                    keys_to_remove.append(key)

        for key in keys_to_remove:
            del self._access_cache[key]

        logger.info(f"Revoked access for participant: {participant_id}")

    def get_access_summary(self, participant_id: str) -> Dict:
        """Get summary of participant's access rights.

        Returns an ``{"error": ...}`` dict for unknown participants rather
        than raising.
        """
        participant_info = self._get_participant_info(participant_id)
        if not participant_info:
            return {"error": "Participant not found"}

        role = participant_info.get("role")
        permissions = self.policy_store.get_role_permissions(ParticipantRole(role))

        return {
            "participant_id": participant_id,
            "role": role,
            "permissions": list(permissions),
            "active": participant_info.get("active", False)
        }
|
||||
532
apps/coordinator-api/src/app/services/audit_logging.py
Normal file
532
apps/coordinator-api/src/app/services/audit_logging.py
Normal file
@@ -0,0 +1,532 @@
|
||||
"""
|
||||
Audit logging service for privacy compliance
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import hashlib
|
||||
import gzip
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
from ..models import ConfidentialAccessLog
|
||||
from ..settings import settings
|
||||
from ..logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class AuditEvent:
    """Structured audit event, serialized as one JSON line per event."""
    event_id: str                    # unique id ("evt_<timestamp>_<hex>")
    timestamp: datetime              # event time (UTC)
    event_type: str                  # "access" | "key_operation" | "policy_change"
    participant_id: str              # actor performing the action
    transaction_id: Optional[str]    # related transaction, if any
    action: str                      # operation performed (e.g. "read", "rotate")
    resource: str                    # resource acted on (e.g. "confidential_transaction")
    outcome: str                     # result of the action (e.g. "allowed", "denied")
    details: Dict[str, Any]          # free-form supplementary context
    ip_address: Optional[str]        # client IP when available
    user_agent: Optional[str]        # client user agent when available
    authorization: Optional[str]     # authorization reference when available
    signature: Optional[str]         # chain-hash signature, filled in after creation
|
||||
|
||||
|
||||
class AuditLogger:
    """Tamper-evident audit logging for privacy compliance.

    Events are queued asynchronously and appended as JSON lines to daily
    log files. Each event carries a SHA-256 signature chained over the
    previous chain hash, which is persisted to ``chain.hash`` so integrity
    survives restarts.
    """

    def __init__(self, log_dir: str = "/var/log/aitbc/audit"):
        # NOTE: creates the log directory eagerly — requires write permission.
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        # Current log file (rotated daily) and its running hash.
        self.current_file = None
        self.current_hash = None

        # Bounded queue drained by the background writer task.
        self.write_queue = asyncio.Queue(maxsize=10000)
        self.writer_task = None

        # Chain of hashes for integrity; resumes from the persisted value.
        self.chain_hash = self._load_chain_hash()

    async def start(self):
        """Start the background writer task (idempotent)."""
        if self.writer_task is None:
            self.writer_task = asyncio.create_task(self._background_writer())

    async def stop(self):
        """Stop the background writer task, swallowing the cancellation."""
        if self.writer_task:
            self.writer_task.cancel()
            try:
                await self.writer_task
            except asyncio.CancelledError:
                pass
            self.writer_task = None

    async def log_access(
        self,
        participant_id: str,
        transaction_id: Optional[str],
        action: str,
        outcome: str,
        details: Optional[Dict[str, Any]] = None,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
        authorization: Optional[str] = None
    ):
        """Log access to confidential data.

        The event is signed and enqueued; actual disk writes happen in the
        background writer. Blocks if the queue is full (backpressure).
        """
        event = AuditEvent(
            event_id=self._generate_event_id(),
            timestamp=datetime.utcnow(),
            event_type="access",
            participant_id=participant_id,
            transaction_id=transaction_id,
            action=action,
            resource="confidential_transaction",
            outcome=outcome,
            details=details or {},
            ip_address=ip_address,
            user_agent=user_agent,
            authorization=authorization,
            signature=None
        )

        # Add signature for tamper-evidence
        event.signature = self._sign_event(event)

        # Queue for writing
        await self.write_queue.put(event)

    async def log_key_operation(
        self,
        participant_id: str,
        operation: str,
        key_version: int,
        outcome: str,
        details: Optional[Dict[str, Any]] = None
    ):
        """Log key management operations (register/rotate/etc.)."""
        event = AuditEvent(
            event_id=self._generate_event_id(),
            timestamp=datetime.utcnow(),
            event_type="key_operation",
            participant_id=participant_id,
            transaction_id=None,
            action=operation,
            resource="encryption_key",
            outcome=outcome,
            details={**(details or {}), "key_version": key_version},
            ip_address=None,
            user_agent=None,
            authorization=None,
            signature=None
        )

        event.signature = self._sign_event(event)
        await self.write_queue.put(event)

    async def log_policy_change(
        self,
        participant_id: str,
        policy_id: str,
        change_type: str,
        outcome: str,
        details: Optional[Dict[str, Any]] = None
    ):
        """Log access policy changes."""
        event = AuditEvent(
            event_id=self._generate_event_id(),
            timestamp=datetime.utcnow(),
            event_type="policy_change",
            participant_id=participant_id,
            transaction_id=None,
            action=change_type,
            resource="access_policy",
            outcome=outcome,
            details={**(details or {}), "policy_id": policy_id},
            ip_address=None,
            user_agent=None,
            authorization=None,
            signature=None
        )

        event.signature = self._sign_event(event)
        await self.write_queue.put(event)

    def query_logs(
        self,
        participant_id: Optional[str] = None,
        transaction_id: Optional[str] = None,
        event_type: Optional[str] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        limit: int = 100
    ) -> List[AuditEvent]:
        """Query audit logs matching the given filters.

        Scans candidate daily files in date order; returns up to *limit*
        events sorted newest-first. Unreadable files are skipped with an
        error log rather than aborting the query.
        """
        results = []

        # Get list of log files to search
        log_files = self._get_log_files(start_time, end_time)

        for log_file in log_files:
            try:
                for line in self._iter_log_lines(log_file):
                    event = self._parse_log_line(line.strip())
                    if self._matches_query(event, participant_id, transaction_id, event_type, start_time, end_time):
                        results.append(event)
                        if len(results) >= limit:
                            return results
            except Exception as e:
                logger.error(f"Failed to read log file {log_file}: {e}")
                continue

        # Sort by timestamp (newest first)
        results.sort(key=lambda x: x.timestamp, reverse=True)

        return results[:limit]

    def _iter_log_lines(self, log_file: Path):
        """Yield text lines from a log file, decompressing .gz files transparently."""
        opener = gzip.open if log_file.suffix == ".gz" else open
        with opener(log_file, "rt") as f:
            yield from f

    def verify_integrity(self, start_date: Optional[datetime] = None) -> Dict[str, Any]:
        """Verify integrity of audit logs since *start_date* (default: last 30 days).

        Compares each file's SHA-256 against its stored ``.hash`` sidecar.
        Returns a summary dict with counts and any violations found.
        """
        if start_date is None:
            start_date = datetime.utcnow() - timedelta(days=30)

        results = {
            "verified_files": 0,
            "total_files": 0,
            "integrity_violations": [],
            "chain_valid": True
        }

        log_files = self._get_log_files(start_date)

        for log_file in log_files:
            results["total_files"] += 1

            try:
                # Verify file hash against the stored sidecar hash.
                file_hash = self._calculate_file_hash(log_file)
                stored_hash = self._get_stored_hash(log_file)

                if file_hash != stored_hash:
                    results["integrity_violations"].append({
                        "file": str(log_file),
                        "expected": stored_hash,
                        "actual": file_hash
                    })
                    results["chain_valid"] = False
                else:
                    results["verified_files"] += 1

            except Exception as e:
                logger.error(f"Failed to verify {log_file}: {e}")
                results["integrity_violations"].append({
                    "file": str(log_file),
                    "error": str(e)
                })
                results["chain_valid"] = False

        return results

    def export_logs(
        self,
        start_time: datetime,
        end_time: datetime,
        format: str = "json",
        include_signatures: bool = True
    ) -> str:
        """Export audit logs for compliance reporting.

        Supports "json" and "csv" formats; raises ValueError otherwise.
        Export is capped at 10000 events.
        """
        events = self.query_logs(
            start_time=start_time,
            end_time=end_time,
            limit=10000
        )

        if format == "json":
            export_data = {
                "export_metadata": {
                    "start_time": start_time.isoformat(),
                    "end_time": end_time.isoformat(),
                    "event_count": len(events),
                    "exported_at": datetime.utcnow().isoformat(),
                    "include_signatures": include_signatures
                },
                "events": []
            }

            for event in events:
                event_dict = asdict(event)
                event_dict["timestamp"] = event.timestamp.isoformat()

                if not include_signatures:
                    event_dict.pop("signature", None)

                export_data["events"].append(event_dict)

            return json.dumps(export_data, indent=2)

        elif format == "csv":
            import csv
            import io

            output = io.StringIO()
            writer = csv.writer(output)

            # Header
            header = [
                "event_id", "timestamp", "event_type", "participant_id",
                "transaction_id", "action", "resource", "outcome",
                "ip_address", "user_agent"
            ]
            if include_signatures:
                header.append("signature")
            writer.writerow(header)

            # Events
            for event in events:
                row = [
                    event.event_id,
                    event.timestamp.isoformat(),
                    event.event_type,
                    event.participant_id,
                    event.transaction_id,
                    event.action,
                    event.resource,
                    event.outcome,
                    event.ip_address,
                    event.user_agent
                ]
                if include_signatures:
                    row.append(event.signature)
                writer.writerow(row)

            return output.getvalue()

        else:
            raise ValueError(f"Unsupported export format: {format}")

    async def _background_writer(self):
        """Background task: drain the queue in batches of up to 100 events."""
        while True:
            try:
                # Collect a batch; a 1s timeout flushes partial batches.
                events = []
                while len(events) < 100:
                    try:
                        event = await asyncio.wait_for(
                            self.write_queue.get(),
                            timeout=1.0
                        )
                        events.append(event)
                    except asyncio.TimeoutError:
                        if events:
                            break
                        continue

                # Write events
                if events:
                    self._write_events(events)

            except Exception as e:
                logger.error(f"Background writer error: {e}")
                # Brief pause to avoid error loops
                await asyncio.sleep(1)

    def _write_events(self, events: List[AuditEvent]):
        """Append a batch of events to the current log file.

        NOTE(review): every event in the batch was signed against the same
        chain hash (the chain only advances once per batch, using the last
        event's signature) — confirm this batching granularity is intended.
        """
        try:
            self._rotate_if_needed()

            with open(self.current_file, "a") as f:
                for event in events:
                    # Convert to JSON line
                    event_dict = asdict(event)
                    event_dict["timestamp"] = event.timestamp.isoformat()

                    # Write with signature
                    line = json.dumps(event_dict, separators=(",", ":")) + "\n"
                    f.write(line)
                f.flush()

            # Update chain hash
            self._update_chain_hash(events[-1])

        except Exception as e:
            logger.error(f"Failed to write audit events: {e}")

    def _rotate_if_needed(self):
        """Open a new daily log file if none is open or the date rolled over."""
        now = datetime.utcnow()
        today = now.date()

        if self.current_file is None:
            self._new_log_file(today)
        else:
            # Filename format is "audit_<ISO-date>.log"; recover the date.
            file_date = datetime.fromisoformat(
                self.current_file.stem.split("_")[1]
            ).date()

            if file_date != today:
                self._new_log_file(today)

    def _new_log_file(self, date):
        """Create new log file for *date*, writing a metadata header that
        records the previous chain hash for cross-file integrity."""
        filename = f"audit_{date.isoformat()}.log"
        self.current_file = self.log_dir / filename

        if not self.current_file.exists():
            header = {
                "created_at": datetime.utcnow().isoformat(),
                "version": "1.0",
                "format": "jsonl",
                "previous_hash": self.chain_hash
            }

            with open(self.current_file, "w") as f:
                f.write(f"# {json.dumps(header)}\n")

    def _generate_event_id(self) -> str:
        """Generate a unique event ID from the current timestamp plus random hex."""
        return f"evt_{datetime.utcnow().timestamp()}_{os.urandom(4).hex()}"

    def _sign_event(self, event: AuditEvent) -> str:
        """Sign event for tamper-evidence by hashing a canonical subset of
        its fields together with the current chain hash."""
        event_data = {
            "event_id": event.event_id,
            "timestamp": event.timestamp.isoformat(),
            "participant_id": event.participant_id,
            "action": event.action,
            "outcome": event.outcome
        }

        # Hash with previous chain hash
        data = json.dumps(event_data, separators=(",", ":"), sort_keys=True)
        combined = f"{self.chain_hash}:{data}".encode()

        return hashlib.sha256(combined).hexdigest()

    def _update_chain_hash(self, last_event: AuditEvent):
        """Advance the chain hash to the last written event's signature and
        persist it for integrity checking across restarts."""
        self.chain_hash = last_event.signature or self.chain_hash

        chain_file = self.log_dir / "chain.hash"
        with open(chain_file, "w") as f:
            f.write(self.chain_hash)

    def _load_chain_hash(self) -> str:
        """Load the persisted chain hash, or the all-zero genesis hash."""
        chain_file = self.log_dir / "chain.hash"
        if chain_file.exists():
            with open(chain_file, "r") as f:
                return f.read().strip()
        return "0" * 64  # Initial hash

    def _get_log_files(self, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None) -> List[Path]:
        """Get the sorted list of log files overlapping [start_time, end_time].

        Both bounds are optional and default to unbounded. (Defaults fix a
        TypeError: verify_integrity calls this with only one argument.)
        Files whose names cannot be parsed are skipped.
        """
        files = []

        for file in self.log_dir.glob("audit_*.log*"):
            try:
                # Extract date from filename "audit_<ISO-date>.log[.gz]".
                date_str = file.stem.split("_")[1]
                file_date = datetime.fromisoformat(date_str).date()

                # A daily file covers [midnight, next midnight).
                file_start = datetime.combine(file_date, datetime.min.time())
                file_end = file_start + timedelta(days=1)

                if (not start_time or file_end >= start_time) and \
                   (not end_time or file_start <= end_time):
                    files.append(file)

            except Exception:
                continue

        return sorted(files)

    def _parse_log_line(self, line: str) -> Optional[AuditEvent]:
        """Parse a JSON log line into an AuditEvent; returns None for the
        header line or unparseable input."""
        if line.startswith("#"):
            return None  # Skip header

        try:
            data = json.loads(line)
            data["timestamp"] = datetime.fromisoformat(data["timestamp"])
            return AuditEvent(**data)
        except Exception as e:
            logger.error(f"Failed to parse log line: {e}")
            return None

    def _matches_query(
        self,
        event: Optional[AuditEvent],
        participant_id: Optional[str],
        transaction_id: Optional[str],
        event_type: Optional[str],
        start_time: Optional[datetime],
        end_time: Optional[datetime]
    ) -> bool:
        """Check if event matches all provided (non-None) query criteria."""
        if not event:
            return False

        if participant_id and event.participant_id != participant_id:
            return False

        if transaction_id and event.transaction_id != transaction_id:
            return False

        if event_type and event.event_type != event_type:
            return False

        if start_time and event.timestamp < start_time:
            return False

        if end_time and event.timestamp > end_time:
            return False

        return True

    def _calculate_file_hash(self, file_path: Path) -> str:
        """Calculate the SHA-256 hash of a file, streaming in 4 KiB chunks."""
        hash_sha256 = hashlib.sha256()

        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                hash_sha256.update(chunk)

        return hash_sha256.hexdigest()

    def _get_stored_hash(self, file_path: Path) -> str:
        """Get the stored sidecar hash for a file ("" when absent)."""
        hash_file = file_path.with_suffix(".hash")
        if hash_file.exists():
            with open(hash_file, "r") as f:
                return f.read().strip()
        return ""
|
||||
|
||||
|
||||
# Global audit logger instance
# NOTE(review): constructed at import time; AuditLogger.__init__ performs
# mkdir -p on /var/log/aitbc/audit, so merely importing this module has a
# filesystem side effect and requires write permission — confirm intended.
audit_logger = AuditLogger()
|
||||
349
apps/coordinator-api/src/app/services/encryption.py
Normal file
349
apps/coordinator-api/src/app/services/encryption.py
Normal file
@@ -0,0 +1,349 @@
|
||||
"""
|
||||
Encryption service for confidential transactions
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import base64
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from datetime import datetime, timedelta
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
|
||||
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat, PrivateFormat, NoEncryption
|
||||
|
||||
from ..models import ConfidentialTransaction, AccessLog
|
||||
from ..settings import settings
|
||||
from ..logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class EncryptedData:
    """Container for encrypted data and keys.

    Pairs an AES-GCM ciphertext with the data-encryption key wrapped once
    per participant, plus the nonce/tag needed for decryption. Round-trips
    through base64-encoded dicts via to_dict/from_dict.
    """

    def __init__(
        self,
        ciphertext: bytes,
        encrypted_keys: Dict[str, bytes],
        algorithm: str = "AES-256-GCM+X25519",
        nonce: Optional[bytes] = None,
        tag: Optional[bytes] = None
    ):
        self.ciphertext = ciphertext
        # participant id -> wrapped (encrypted) DEK.
        self.encrypted_keys = encrypted_keys
        self.algorithm = algorithm
        self.nonce = nonce
        self.tag = tag

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-safe dict, base64-encoding all byte fields."""

        def encode(raw: bytes) -> str:
            return base64.b64encode(raw).decode()

        serialized = {
            "ciphertext": encode(self.ciphertext),
            "encrypted_keys": {
                who: encode(wrapped)
                for who, wrapped in self.encrypted_keys.items()
            },
            "algorithm": self.algorithm,
        }
        # Optional fields serialize to None when absent.
        serialized["nonce"] = encode(self.nonce) if self.nonce else None
        serialized["tag"] = encode(self.tag) if self.tag else None
        return serialized

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "EncryptedData":
        """Reconstruct an EncryptedData from a dict produced by to_dict."""
        wrapped_keys = {}
        for who, encoded in data["encrypted_keys"].items():
            wrapped_keys[who] = base64.b64decode(encoded)

        nonce_b64 = data.get("nonce")
        tag_b64 = data.get("tag")
        return cls(
            ciphertext=base64.b64decode(data["ciphertext"]),
            encrypted_keys=wrapped_keys,
            algorithm=data["algorithm"],
            nonce=base64.b64decode(nonce_b64) if nonce_b64 else None,
            tag=base64.b64decode(tag_b64) if tag_b64 else None
        )
|
||||
|
||||
|
||||
class EncryptionService:
    """Service for encrypting/decrypting confidential transaction data.

    Envelope encryption: the payload is encrypted once with a random
    AES-256-GCM data-encryption key (DEK); the DEK is then wrapped for each
    recipient via ECIES over X25519 (ephemeral ECDH + HKDF-SHA256 + AES-GCM).
    An optional "audit" escrow copy of the DEK lets authorised auditors
    decrypt without any participant private key.
    """

    def __init__(self, key_manager: "KeyManager"):
        # key_manager supplies participant public/private keys and the audit
        # escrow key (see KeyManager in key_management.py).
        self.key_manager = key_manager
        self.backend = default_backend()
        self.algorithm = "AES-256-GCM+X25519"

    def encrypt(
        self,
        data: Dict[str, Any],
        participants: List[str],
        include_audit: bool = True
    ) -> EncryptedData:
        """Encrypt data for multiple participants

        Args:
            data: Data to encrypt (must be JSON-serialisable)
            participants: List of participant IDs who can decrypt
            include_audit: Whether to include audit escrow key

        Returns:
            EncryptedData container with ciphertext and encrypted keys

        Raises:
            EncryptionError: if encryption fails, or if the DEK could not be
                wrapped for any recipient at all (which would otherwise
                produce a payload nobody can ever decrypt).
        """
        try:
            # Generate random DEK (Data Encryption Key); fresh per payload.
            dek = os.urandom(32)  # 256-bit key for AES-256
            nonce = os.urandom(12)  # 96-bit nonce for GCM

            # Serialize (compact form) and encrypt data
            plaintext = json.dumps(data, separators=(",", ":")).encode()
            aesgcm = AESGCM(dek)
            ciphertext = aesgcm.encrypt(nonce, plaintext, None)

            # The cryptography library appends the 16-byte GCM tag to the
            # ciphertext; split it out so it can be stored separately.
            tag = ciphertext[-16:]
            actual_ciphertext = ciphertext[:-16]

            # Encrypt DEK for each participant. A single bad key must not
            # abort the whole batch, so failures are logged and skipped.
            encrypted_keys = {}
            for participant in participants:
                try:
                    public_key = self.key_manager.get_public_key(participant)
                    encrypted_keys[participant] = self._encrypt_dek(dek, public_key)
                except Exception as e:
                    logger.error(f"Failed to encrypt DEK for participant {participant}: {e}")
                    continue

            # Add audit escrow if requested.
            # NOTE(review): KeyManager.get_audit_key is declared async in
            # key_management.py but is called synchronously here -- confirm
            # which implementation this service is wired to.
            if include_audit:
                try:
                    audit_public_key = self.key_manager.get_audit_key()
                    encrypted_keys["audit"] = self._encrypt_dek(dek, audit_public_key)
                except Exception as e:
                    logger.error(f"Failed to encrypt DEK for audit: {e}")

            # Fix: if every wrap failed we would return ciphertext that no
            # one can decrypt -- fail loudly instead of silently losing data.
            if not encrypted_keys:
                raise EncryptionError("No DEK could be encrypted for any recipient")

            return EncryptedData(
                ciphertext=actual_ciphertext,
                encrypted_keys=encrypted_keys,
                algorithm=self.algorithm,
                nonce=nonce,
                tag=tag
            )

        except Exception as e:
            logger.error(f"Encryption failed: {e}")
            raise EncryptionError(f"Failed to encrypt data: {e}")

    def decrypt(
        self,
        encrypted_data: EncryptedData,
        participant_id: str,
        purpose: str = "access"
    ) -> Dict[str, Any]:
        """Decrypt data for a specific participant

        Args:
            encrypted_data: The encrypted data container
            participant_id: ID of the participant requesting decryption
            purpose: Purpose of decryption for audit logging

        Returns:
            Decrypted data as dictionary

        Raises:
            DecryptionError: on any failure, including an unauthorised
                participant (the AccessDeniedError is wrapped, preserving
                the original exception contract for callers).
        """
        try:
            # Get participant's private key
            private_key = self.key_manager.get_private_key(participant_id)

            # Participant must hold a wrapped copy of the DEK.
            if participant_id not in encrypted_data.encrypted_keys:
                raise AccessDeniedError(f"Participant {participant_id} not authorized")

            encrypted_dek = encrypted_data.encrypted_keys[participant_id]
            dek = self._decrypt_dek(encrypted_dek, private_key)

            # Re-append the GCM tag that was split off at encrypt time.
            full_ciphertext = encrypted_data.ciphertext + encrypted_data.tag

            aesgcm = AESGCM(dek)
            plaintext = aesgcm.decrypt(encrypted_data.nonce, full_ciphertext, None)
            data = json.loads(plaintext.decode())

            # Log successful access for the audit trail.
            self._log_access(
                transaction_id=None,  # Will be set by caller
                participant_id=participant_id,
                purpose=purpose,
                success=True
            )

            return data

        except Exception as e:
            logger.error(f"Decryption failed for participant {participant_id}: {e}")
            self._log_access(
                transaction_id=None,
                participant_id=participant_id,
                purpose=purpose,
                success=False,
                error=str(e)
            )
            raise DecryptionError(f"Failed to decrypt data: {e}")

    def audit_decrypt(
        self,
        encrypted_data: EncryptedData,
        audit_authorization: str,
        purpose: str = "audit"
    ) -> Dict[str, Any]:
        """Decrypt data for audit purposes using the escrowed audit key.

        Args:
            encrypted_data: The encrypted data container
            audit_authorization: Authorization token for audit access
            purpose: Purpose of decryption

        Returns:
            Decrypted data as dictionary

        Raises:
            DecryptionError: on invalid authorization, missing escrow,
                or decryption failure.
        """
        try:
            # Verify audit authorization before touching key material.
            if not self.key_manager.verify_audit_authorization(audit_authorization):
                raise AccessDeniedError("Invalid audit authorization")

            audit_private_key = self.key_manager.get_audit_private_key(audit_authorization)

            # The escrow copy is only present when encrypt() was called with
            # include_audit=True.
            if "audit" not in encrypted_data.encrypted_keys:
                raise AccessDeniedError("Audit escrow not available")

            encrypted_dek = encrypted_data.encrypted_keys["audit"]
            dek = self._decrypt_dek(encrypted_dek, audit_private_key)

            # Decrypt data (re-append the split-off GCM tag).
            full_ciphertext = encrypted_data.ciphertext + encrypted_data.tag
            aesgcm = AESGCM(dek)
            plaintext = aesgcm.decrypt(encrypted_data.nonce, full_ciphertext, None)
            data = json.loads(plaintext.decode())

            # Log audit access with the authorization token for traceability.
            self._log_access(
                transaction_id=None,
                participant_id="audit",
                purpose=f"audit:{purpose}",
                success=True,
                authorization=audit_authorization
            )

            return data

        except Exception as e:
            logger.error(f"Audit decryption failed: {e}")
            raise DecryptionError(f"Failed to decrypt for audit: {e}")

    def _encrypt_dek(self, dek: bytes, public_key: X25519PublicKey) -> bytes:
        """Encrypt DEK using ECIES with X25519.

        Output layout: 32-byte ephemeral public key || 12-byte nonce ||
        AES-GCM ciphertext of the DEK (must match _decrypt_dek's parsing).
        """
        # Fresh ephemeral key pair per wrap gives forward secrecy per message.
        ephemeral_private = X25519PrivateKey.generate()
        ephemeral_public = ephemeral_private.public_key()

        # Perform ECDH
        shared_key = ephemeral_private.exchange(public_key)

        # Derive encryption key from shared secret
        derived_key = HKDF(
            algorithm=hashes.SHA256(),
            length=32,
            salt=None,
            info=b"AITBC-DEK-Encryption",
            backend=self.backend
        ).derive(shared_key)

        # Encrypt DEK with AES-GCM
        aesgcm = AESGCM(derived_key)
        nonce = os.urandom(12)
        encrypted_dek = aesgcm.encrypt(nonce, dek, None)

        # Return ephemeral public key + nonce + encrypted DEK
        return (
            ephemeral_public.public_bytes(Encoding.Raw, PublicFormat.Raw) +
            nonce +
            encrypted_dek
        )

    def _decrypt_dek(self, encrypted_dek: bytes, private_key: X25519PrivateKey) -> bytes:
        """Decrypt DEK using ECIES with X25519 (inverse of _encrypt_dek)."""
        # Extract components: 32B ephemeral pubkey, 12B nonce, rest ciphertext.
        ephemeral_public_bytes = encrypted_dek[:32]
        nonce = encrypted_dek[32:44]
        dek_ciphertext = encrypted_dek[44:]

        # Reconstruct ephemeral public key
        ephemeral_public = X25519PublicKey.from_public_bytes(ephemeral_public_bytes)

        # Perform ECDH
        shared_key = private_key.exchange(ephemeral_public)

        # Derive decryption key (parameters must mirror _encrypt_dek exactly).
        derived_key = HKDF(
            algorithm=hashes.SHA256(),
            length=32,
            salt=None,
            info=b"AITBC-DEK-Encryption",
            backend=self.backend
        ).derive(shared_key)

        # Decrypt DEK
        aesgcm = AESGCM(derived_key)
        dek = aesgcm.decrypt(nonce, dek_ciphertext, None)

        return dek

    def _log_access(
        self,
        transaction_id: Optional[str],
        participant_id: str,
        purpose: str,
        success: bool,
        error: Optional[str] = None,
        authorization: Optional[str] = None
    ):
        """Log access to confidential data (best-effort; never raises)."""
        try:
            log_entry = {
                "transaction_id": transaction_id,
                "participant_id": participant_id,
                "purpose": purpose,
                "timestamp": datetime.utcnow().isoformat(),
                "success": success,
                "error": error,
                "authorization": authorization
            }

            # In production, this would go to secure audit log
            logger.info(f"Confidential data access: {json.dumps(log_entry)}")

        except Exception as e:
            # Logging failures must never break the crypto path.
            logger.error(f"Failed to log access: {e}")
|
||||
|
||||
|
||||
class EncryptionError(Exception):
    """Base exception for all encryption-service failures in this module."""
    pass
|
||||
|
||||
|
||||
class DecryptionError(EncryptionError):
    """Raised when decryption fails (bad key, missing escrow, corrupt data)."""
    pass
|
||||
|
||||
|
||||
class AccessDeniedError(EncryptionError):
    """Raised when a caller is not an authorised recipient of the data."""
    pass
|
||||
435
apps/coordinator-api/src/app/services/hsm_key_manager.py
Normal file
435
apps/coordinator-api/src/app/services/hsm_key_manager.py
Normal file
@@ -0,0 +1,435 @@
|
||||
"""
|
||||
HSM-backed key management for production use
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
|
||||
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
from ..models import KeyPair, KeyRotationLog, AuditAuthorization
|
||||
from ..repositories.confidential import (
|
||||
ParticipantKeyRepository,
|
||||
KeyRotationRepository
|
||||
)
|
||||
from ..settings import settings
|
||||
from ..logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class HSMProvider(ABC):
    """Abstract base class for HSM providers.

    Implementations keep private key material inside (real or simulated)
    hardware and expose only opaque byte handles to callers.
    """

    @abstractmethod
    async def generate_key(self, key_id: str) -> Tuple[bytes, bytes]:
        """Generate key pair in HSM, return (public_key, key_handle)"""
        pass

    @abstractmethod
    async def sign_with_key(self, key_handle: bytes, data: bytes) -> bytes:
        """Sign data with HSM-stored private key"""
        pass

    @abstractmethod
    async def derive_shared_secret(self, key_handle: bytes, public_key: bytes) -> bytes:
        """Derive shared secret using ECDH with the peer's raw public key"""
        pass

    @abstractmethod
    async def delete_key(self, key_handle: bytes) -> bool:
        """Delete key from HSM; returns True on success"""
        pass

    @abstractmethod
    async def list_keys(self) -> List[str]:
        """List all key IDs in HSM"""
        pass
|
||||
|
||||
|
||||
class SoftwareHSMProvider(HSMProvider):
    """Software-based HSM provider for development/testing.

    Private keys live only in process memory -- never use outside dev/test.
    """

    def __init__(self):
        self._keys: Dict[str, X25519PrivateKey] = {}
        self._backend = default_backend()

    async def generate_key(self, key_id: str) -> Tuple[bytes, bytes]:
        """Generate an X25519 key pair in memory; the handle is the key_id."""
        new_private = X25519PrivateKey.generate()

        # Store private key (in production, this would be in secure hardware)
        self._keys[key_id] = new_private

        raw_public = new_private.public_key().public_bytes(Encoding.Raw, PublicFormat.Raw)
        return raw_public, key_id.encode()  # Use key_id as handle

    async def sign_with_key(self, key_handle: bytes, data: bytes) -> bytes:
        """Placeholder signing (X25519 keys do key exchange, not signatures)."""
        handle = key_handle.decode()
        if handle not in self._keys:
            raise ValueError(f"Key not found: {handle}")

        # For X25519, we don't sign - we exchange.
        # This is a placeholder for actual HSM operations.
        return b"signature_placeholder"

    async def derive_shared_secret(self, key_handle: bytes, public_key: bytes) -> bytes:
        """ECDH between the stored private key and the peer's raw public key."""
        handle = key_handle.decode()
        stored = self._keys.get(handle)
        if stored is None:
            raise ValueError(f"Key not found: {handle}")

        peer = X25519PublicKey.from_public_bytes(public_key)
        return stored.exchange(peer)

    async def delete_key(self, key_handle: bytes) -> bool:
        """Remove the key from memory; True if it existed."""
        return self._keys.pop(key_handle.decode(), None) is not None

    async def list_keys(self) -> List[str]:
        """Return all key IDs currently held."""
        return [*self._keys]
|
||||
|
||||
|
||||
class AzureKeyVaultProvider(HSMProvider):
    """Azure Key Vault HSM provider for production."""

    def __init__(self, vault_url: str, credential=None):
        """Create a provider for the given vault.

        Args:
            vault_url: URL of the Key Vault instance.
            credential: Azure credential object. Fix: now optional -- the body
                already fell back to DefaultAzureCredential, but the missing
                default broke callers passing only vault_url (see the factory
                create_hsm_key_manager).
        """
        from azure.keyvault.keys.crypto import CryptographyClient
        from azure.keyvault.keys import KeyClient
        from azure.identity import DefaultAzureCredential

        self.vault_url = vault_url
        self.credential = credential or DefaultAzureCredential()
        self.key_client = KeyClient(vault_url, self.credential)
        self.crypto_client = None

    async def generate_key(self, key_id: str) -> Tuple[bytes, bytes]:
        """Generate key in Azure Key Vault"""
        # Create EC-HSM key
        key = await self.key_client.create_ec_key(
            key_id,
            curve="P-256"  # Azure doesn't support X25519 directly
        )

        # NOTE(review): KeyVaultKey does not document a
        # `key.cryptography_client` attribute chain -- this path looks
        # untested; verify against azure-keyvault-keys before production use.
        public_key = key.key.cryptography_client.public_key()
        public_bytes = public_key.public_bytes(
            Encoding.Raw,
            PublicFormat.Raw
        )

        return public_bytes, key.id.encode()

    async def sign_with_key(self, key_handle: bytes, data: bytes) -> bytes:
        """Sign with Azure Key Vault"""
        key_id = key_handle.decode()
        crypto_client = self.key_client.get_cryptography_client(key_id)

        sign_result = await crypto_client.sign("ES256", data)
        return sign_result.signature

    async def derive_shared_secret(self, key_handle: bytes, public_key: bytes) -> bytes:
        """Derive shared secret (not directly supported in Azure)"""
        # Would need to use a different approach
        raise NotImplementedError("ECDH not supported in Azure Key Vault")

    async def delete_key(self, key_handle: bytes) -> bool:
        """Delete key from Azure Key Vault"""
        # The handle is a full key identifier URL; the name is its last part.
        key_name = key_handle.decode().split("/")[-1]
        await self.key_client.begin_delete_key(key_name)
        return True

    async def list_keys(self) -> List[str]:
        """List keys in Azure Key Vault"""
        keys = []
        async for key in self.key_client.list_properties_of_keys():
            keys.append(key.name)
        return keys
|
||||
|
||||
|
||||
class AWSKMSProvider(HSMProvider):
    """AWS KMS HSM provider for production."""

    def __init__(self, region_name: str):
        import boto3
        self.kms = boto3.client('kms', region_name=region_name)

    async def generate_key(self, key_id: str) -> Tuple[bytes, bytes]:
        """Generate key pair in AWS KMS.

        Returns (DER-encoded public key, KMS KeyId as bytes handle).
        """
        # Create CMK.
        # Fix: AWS KMS rejects KeyUsage='ENCRYPT_DECRYPT' for ECC key specs;
        # ECC_NIST_P256 keys must be SIGN_VERIFY (which matches the
        # ECDSA_SHA_256 algorithm used in sign_with_key below).
        response = self.kms.create_key(
            Description=f"AITBC confidential transaction key for {key_id}",
            KeyUsage='SIGN_VERIFY',
            KeySpec='ECC_NIST_P256'
        )

        # Get public key
        public_key = self.kms.get_public_key(KeyId=response['KeyMetadata']['KeyId'])

        return public_key['PublicKey'], response['KeyMetadata']['KeyId'].encode()

    async def sign_with_key(self, key_handle: bytes, data: bytes) -> bytes:
        """Sign with AWS KMS"""
        response = self.kms.sign(
            KeyId=key_handle.decode(),
            Message=data,
            MessageType='RAW',
            SigningAlgorithm='ECDSA_SHA_256'
        )
        return response['Signature']

    async def derive_shared_secret(self, key_handle: bytes, public_key: bytes) -> bytes:
        """Derive shared secret (not directly supported in KMS)"""
        raise NotImplementedError("ECDH not supported in AWS KMS")

    async def delete_key(self, key_handle: bytes) -> bool:
        """Schedule key deletion in AWS KMS.

        Note: deletion is scheduled (default 30-day waiting period), not
        immediate.
        """
        self.kms.schedule_key_deletion(KeyId=key_handle.decode())
        return True

    async def list_keys(self) -> List[str]:
        """List keys in AWS KMS"""
        keys = []
        paginator = self.kms.get_paginator('list_keys')
        for page in paginator.paginate():
            for key in page['Keys']:
                keys.append(key['KeyId'])
        return keys
|
||||
|
||||
|
||||
class HSMKeyManager:
    """HSM-backed key manager for production.

    Private keys never leave the HSM; the database stores only public keys
    and opaque HSM handles (carried in the KeyPair.private_key field).
    """

    def __init__(self, hsm_provider: HSMProvider, key_repository: ParticipantKeyRepository):
        self.hsm = hsm_provider
        self.key_repo = key_repository
        self._master_key = None
        self._init_master_key()

    def _init_master_key(self):
        """Initialize master key for encrypting stored data"""
        # In production, this would come from HSM or KMS.
        # NOTE(review): a fresh random key per process means anything
        # encrypted under it is unrecoverable after restart -- confirm intent.
        self._master_key = os.urandom(32)

    async def generate_key_pair(self, participant_id: str) -> KeyPair:
        """Generate a key pair in the HSM and persist its metadata.

        The returned KeyPair.private_key holds the HSM handle, never raw
        key material.
        """
        try:
            # Generate key in HSM
            hsm_key_id = f"aitbc-{participant_id}-{datetime.utcnow().timestamp()}"
            public_key_bytes, key_handle = await self.hsm.generate_key(hsm_key_id)

            # Create key pair record
            key_pair = KeyPair(
                participant_id=participant_id,
                private_key=key_handle,  # Store HSM handle, not actual private key
                public_key=public_key_bytes,
                algorithm="X25519",
                created_at=datetime.utcnow(),
                version=1
            )

            # Store metadata in database
            await self.key_repo.create(
                await self._get_session(),
                key_pair
            )

            logger.info(f"Generated HSM key pair for participant: {participant_id}")
            return key_pair

        except Exception as e:
            logger.error(f"Failed to generate HSM key pair for {participant_id}: {e}")
            raise

    async def rotate_keys(self, participant_id: str) -> KeyPair:
        """Rotate keys: generate a new HSM pair, then delete the old one."""
        # Get current key
        current_key = await self.key_repo.get_by_participant(
            await self._get_session(),
            participant_id
        )

        if not current_key:
            raise ValueError(f"No existing keys for {participant_id}")

        # Generate new key
        new_key_pair = await self.generate_key_pair(participant_id)

        # TODO(review): rotation_log is built but never persisted -- wire it
        # into KeyRotationRepository (or drop it) before relying on this
        # audit trail.
        rotation_log = KeyRotationLog(
            participant_id=participant_id,
            old_version=current_key.version,
            new_version=new_key_pair.version,
            rotated_at=datetime.utcnow(),
            reason="scheduled_rotation"
        )

        await self.key_repo.rotate(
            await self._get_session(),
            participant_id,
            new_key_pair
        )

        # Destroy the old key in the HSM only after the DB switch-over.
        await self.hsm.delete_key(current_key.private_key)

        return new_key_pair

    def get_public_key(self, participant_id: str) -> X25519PublicKey:
        """Get public key for participant (synchronous repository path)."""
        key = self.key_repo.get_by_participant_sync(participant_id)
        if not key:
            raise ValueError(f"No keys found for {participant_id}")

        return X25519PublicKey.from_public_bytes(key.public_key)

    async def get_private_key_handle(self, participant_id: str) -> bytes:
        """Get the opaque HSM key handle for participant."""
        key = await self.key_repo.get_by_participant(
            await self._get_session(),
            participant_id
        )

        if not key:
            raise ValueError(f"No keys found for {participant_id}")

        return key.private_key  # This is the HSM handle

    async def derive_shared_secret(
        self,
        participant_id: str,
        peer_public_key: bytes
    ) -> bytes:
        """Derive an ECDH shared secret inside the HSM."""
        key_handle = await self.get_private_key_handle(participant_id)
        return await self.hsm.derive_shared_secret(key_handle, peer_public_key)

    async def sign_with_key(
        self,
        participant_id: str,
        data: bytes
    ) -> bytes:
        """Sign data using the participant's HSM-stored key."""
        key_handle = await self.get_private_key_handle(participant_id)
        return await self.hsm.sign_with_key(key_handle, data)

    async def revoke_keys(self, participant_id: str, reason: str) -> bool:
        """Revoke participant's keys: delete from HSM, deactivate in DB."""
        # Get current key
        current_key = await self.key_repo.get_by_participant(
            await self._get_session(),
            participant_id
        )

        if not current_key:
            return False

        # Delete from HSM
        await self.hsm.delete_key(current_key.private_key)

        # Mark as revoked in database
        return await self.key_repo.update_active(
            await self._get_session(),
            participant_id,
            False,
            reason
        )

    async def create_audit_authorization(
        self,
        issuer: str,
        purpose: str,
        expires_in_hours: int = 24
    ) -> str:
        """Create a base64-encoded audit authorization signed via the HSM
        "audit" key."""
        # Fix: timedelta was used without being in scope -- this module only
        # imports `from datetime import datetime` at the top.
        from datetime import timedelta

        # Create authorization payload
        payload = {
            "issuer": issuer,
            "subject": "audit_access",
            "purpose": purpose,
            "created_at": datetime.utcnow().isoformat(),
            "expires_at": (datetime.utcnow() + timedelta(hours=expires_in_hours)).isoformat()
        }

        # Sign with audit key
        audit_key_handle = await self.get_private_key_handle("audit")
        signature = await self.hsm.sign_with_key(
            audit_key_handle,
            json.dumps(payload).encode()
        )

        payload["signature"] = signature.hex()

        # Encode for transport
        import base64
        return base64.b64encode(json.dumps(payload).encode()).decode()

    async def verify_audit_authorization(self, authorization: str) -> bool:
        """Verify an audit authorization token.

        Currently only checks expiry; the signature is NOT cryptographically
        verified yet (see note below).
        """
        try:
            # Decode authorization
            import base64
            auth_data = base64.b64decode(authorization).decode()
            auth_json = json.loads(auth_data)

            # Check expiration
            expires_at = datetime.fromisoformat(auth_json["expires_at"])
            if datetime.utcnow() > expires_at:
                return False

            # NOTE(review): the audit public key is fetched but the signature
            # is never checked against it -- implement real verification.
            audit_public_key = self.get_public_key("audit")

            return True

        except Exception as e:
            logger.error(f"Failed to verify audit authorization: {e}")
            return False

    async def _get_session(self):
        """Get database session"""
        # In production, inject via dependency injection.
        # NOTE(review): get_async_session is not imported in this module and
        # will raise NameError when called -- add the proper import.
        async for session in get_async_session():
            return session
|
||||
|
||||
|
||||
def create_hsm_key_manager() -> HSMKeyManager:
    """Create an HSM key manager based on configuration.

    Reads settings.HSM_PROVIDER ('software' | 'azure' | 'aws'; defaults to
    'software') and wires the matching provider to a key repository.

    Raises:
        ValueError: for an unknown provider name or a missing Azure vault URL.
    """
    from ..repositories.confidential import ParticipantKeyRepository

    # Get HSM provider from settings
    hsm_type = getattr(settings, 'HSM_PROVIDER', 'software')

    if hsm_type == 'software':
        hsm = SoftwareHSMProvider()
    elif hsm_type == 'azure':
        # Fail with a clear message instead of a bare AttributeError when the
        # vault URL setting is absent.
        vault_url = getattr(settings, 'AZURE_KEY_VAULT_URL', None)
        if not vault_url:
            raise ValueError("AZURE_KEY_VAULT_URL is required for the azure HSM provider")
        # Fix: the constructor takes (vault_url, credential); it was called
        # with only one argument. Passing None selects DefaultAzureCredential
        # inside the provider.
        hsm = AzureKeyVaultProvider(vault_url, None)
    elif hsm_type == 'aws':
        region = getattr(settings, 'AWS_REGION', 'us-east-1')
        hsm = AWSKMSProvider(region)
    else:
        raise ValueError(f"Unknown HSM provider: {hsm_type}")

    key_repo = ParticipantKeyRepository()

    return HSMKeyManager(hsm, key_repo)
|
||||
466
apps/coordinator-api/src/app/services/key_management.py
Normal file
466
apps/coordinator-api/src/app/services/key_management.py
Normal file
@@ -0,0 +1,466 @@
|
||||
"""
|
||||
Key management service for confidential transactions
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import base64
|
||||
from typing import Dict, Optional, List, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
|
||||
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat, PrivateFormat, NoEncryption
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
from ..models import KeyPair, KeyRotationLog, AuditAuthorization
|
||||
from ..settings import settings
|
||||
from ..logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class KeyManager:
|
||||
"""Manages encryption keys for confidential transactions"""
|
||||
|
||||
    def __init__(self, storage_backend: "KeyStorageBackend"):
        """Initialise with a pluggable key-storage backend."""
        self.storage = storage_backend
        self.backend = default_backend()
        # In-process cache: {participant_id: {"public_key": ..., "version": int}}
        self._key_cache = {}
        # Lazily created/loaded audit escrow public key.
        self._audit_key = None
        # Intended audit-key rotation interval. NOTE(review): not yet
        # enforced -- _should_rotate_audit_key only checks for a missing key.
        self._audit_key_rotation = timedelta(days=30)
||||
|
||||
    async def generate_key_pair(self, participant_id: str) -> KeyPair:
        """Generate X25519 key pair for participant.

        Unlike the HSM-backed manager, the RAW private key bytes are stored
        via the storage backend; the public key is cached in-process.

        Raises:
            KeyManagementError: if generation or storage fails.
        """
        try:
            # Generate new key pair
            private_key = X25519PrivateKey.generate()
            public_key = private_key.public_key()

            # Create key pair object.
            # NOTE(review): version is always 1 here, so rotation never bumps
            # it -- confirm whether the storage backend assigns versions.
            key_pair = KeyPair(
                participant_id=participant_id,
                private_key=private_key.private_bytes_raw(),
                public_key=public_key.public_bytes_raw(),
                algorithm="X25519",
                created_at=datetime.utcnow(),
                version=1
            )

            # Store securely
            await self.storage.store_key_pair(key_pair)

            # Cache public key
            self._key_cache[participant_id] = {
                "public_key": public_key,
                "version": key_pair.version
            }

            logger.info(f"Generated key pair for participant: {participant_id}")
            return key_pair

        except Exception as e:
            logger.error(f"Failed to generate key pair for {participant_id}: {e}")
            raise KeyManagementError(f"Key generation failed: {e}")
|
||||
|
||||
    async def rotate_keys(self, participant_id: str) -> KeyPair:
        """Rotate encryption keys for participant.

        Generates and stores a new pair, records a rotation log entry, and
        triggers re-encryption of the participant's active transactions.

        Raises:
            KeyManagementError: on any failure (including a missing current
                key -- the KeyNotFoundError is wrapped by the outer handler).
        """
        try:
            # Get current key pair
            current_key = await self.storage.get_key_pair(participant_id)
            if not current_key:
                raise KeyNotFoundError(f"No existing keys for {participant_id}")

            # Generate new key pair
            new_key_pair = await self.generate_key_pair(participant_id)

            # Log rotation
            rotation_log = KeyRotationLog(
                participant_id=participant_id,
                old_version=current_key.version,
                new_version=new_key_pair.version,
                rotated_at=datetime.utcnow(),
                reason="scheduled_rotation"
            )
            await self.storage.log_rotation(rotation_log)

            # Re-encrypt active transactions (in production)
            await self._reencrypt_transactions(participant_id, current_key, new_key_pair)

            logger.info(f"Rotated keys for participant: {participant_id}")
            return new_key_pair

        except Exception as e:
            logger.error(f"Failed to rotate keys for {participant_id}: {e}")
            raise KeyManagementError(f"Key rotation failed: {e}")
|
||||
|
||||
def get_public_key(self, participant_id: str) -> X25519PublicKey:
|
||||
"""Get public key for participant"""
|
||||
# Check cache first
|
||||
if participant_id in self._key_cache:
|
||||
return self._key_cache[participant_id]["public_key"]
|
||||
|
||||
# Load from storage
|
||||
key_pair = self.storage.get_key_pair_sync(participant_id)
|
||||
if not key_pair:
|
||||
raise KeyNotFoundError(f"No keys found for participant: {participant_id}")
|
||||
|
||||
# Reconstruct public key
|
||||
public_key = X25519PublicKey.from_public_bytes(key_pair.public_key)
|
||||
|
||||
# Cache it
|
||||
self._key_cache[participant_id] = {
|
||||
"public_key": public_key,
|
||||
"version": key_pair.version
|
||||
}
|
||||
|
||||
return public_key
|
||||
|
||||
def get_private_key(self, participant_id: str) -> X25519PrivateKey:
|
||||
"""Get private key for participant (from secure storage)"""
|
||||
key_pair = self.storage.get_key_pair_sync(participant_id)
|
||||
if not key_pair:
|
||||
raise KeyNotFoundError(f"No keys found for participant: {participant_id}")
|
||||
|
||||
# Reconstruct private key
|
||||
private_key = X25519PrivateKey.from_private_bytes(key_pair.private_key)
|
||||
return private_key
|
||||
|
||||
    async def get_audit_key(self) -> X25519PublicKey:
        """Get public audit key for escrow, creating it lazily.

        NOTE(review): this method is async, but EncryptionService.encrypt
        calls key_manager.get_audit_key() without awaiting -- confirm which
        key-manager implementation that service is wired to.
        """
        if not self._audit_key or self._should_rotate_audit_key():
            await self._rotate_audit_key()

        return self._audit_key
|
||||
|
||||
    async def get_audit_private_key(self, authorization: str) -> X25519PrivateKey:
        """Get private audit key with authorization.

        NOTE(review): AccessDeniedError is not in this module's visible
        imports (it is defined in encryption.py) -- confirm it is imported
        or defined elsewhere in this file, otherwise this raises NameError.
        """
        # Verify authorization
        if not await self.verify_audit_authorization(authorization):
            raise AccessDeniedError("Invalid audit authorization")

        # Load audit key from secure storage
        audit_key_data = await self.storage.get_audit_key()
        if not audit_key_data:
            raise KeyNotFoundError("Audit key not found")

        return X25519PrivateKey.from_private_bytes(audit_key_data.private_key)
|
||||
|
||||
async def verify_audit_authorization(self, authorization: str) -> bool:
|
||||
"""Verify audit authorization token"""
|
||||
try:
|
||||
# Decode authorization
|
||||
auth_data = base64.b64decode(authorization).decode()
|
||||
auth_json = json.loads(auth_data)
|
||||
|
||||
# Check expiration
|
||||
expires_at = datetime.fromisoformat(auth_json["expires_at"])
|
||||
if datetime.utcnow() > expires_at:
|
||||
return False
|
||||
|
||||
# Verify signature (in production, use proper signature verification)
|
||||
# For now, just check format
|
||||
required_fields = ["issuer", "subject", "expires_at", "signature"]
|
||||
return all(field in auth_json for field in required_fields)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to verify audit authorization: {e}")
|
||||
return False
|
||||
|
||||
    async def create_audit_authorization(
        self,
        issuer: str,
        purpose: str,
        expires_in_hours: int = 24
    ) -> str:
        """Create audit authorization token.

        Returns a base64-encoded JSON payload. The signature field is a
        placeholder; real signing is still to be implemented.

        Raises:
            KeyManagementError: if encoding fails.
        """
        try:
            # Create authorization payload
            payload = {
                "issuer": issuer,
                "subject": "audit_access",
                "purpose": purpose,
                "created_at": datetime.utcnow().isoformat(),
                "expires_at": (datetime.utcnow() + timedelta(hours=expires_in_hours)).isoformat(),
                "signature": "placeholder"  # In production, sign with issuer key
            }

            # Encode and return
            auth_json = json.dumps(payload)
            return base64.b64encode(auth_json.encode()).decode()

        except Exception as e:
            logger.error(f"Failed to create audit authorization: {e}")
            raise KeyManagementError(f"Authorization creation failed: {e}")
|
||||
|
||||
async def list_participants(self) -> List[str]:
|
||||
"""List all participants with keys"""
|
||||
return await self.storage.list_participants()
|
||||
|
||||
async def revoke_keys(self, participant_id: str, reason: str) -> bool:
|
||||
"""Revoke participant's keys"""
|
||||
try:
|
||||
# Mark keys as revoked
|
||||
success = await self.storage.revoke_keys(participant_id, reason)
|
||||
|
||||
if success:
|
||||
# Clear cache
|
||||
if participant_id in self._key_cache:
|
||||
del self._key_cache[participant_id]
|
||||
|
||||
logger.info(f"Revoked keys for participant: {participant_id}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to revoke keys for {participant_id}: {e}")
|
||||
return False
|
||||
|
||||
async def _rotate_audit_key(self):
|
||||
"""Rotate the audit escrow key"""
|
||||
try:
|
||||
# Generate new audit key pair
|
||||
audit_private = X25519PrivateKey.generate()
|
||||
audit_public = audit_private.public_key()
|
||||
|
||||
# Store securely
|
||||
audit_key_pair = KeyPair(
|
||||
participant_id="audit",
|
||||
private_key=audit_private.private_bytes_raw(),
|
||||
public_key=audit_public.public_bytes_raw(),
|
||||
algorithm="X25519",
|
||||
created_at=datetime.utcnow(),
|
||||
version=1
|
||||
)
|
||||
|
||||
await self.storage.store_audit_key(audit_key_pair)
|
||||
self._audit_key = audit_public
|
||||
|
||||
logger.info("Rotated audit escrow key")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to rotate audit key: {e}")
|
||||
raise KeyManagementError(f"Audit key rotation failed: {e}")
|
||||
|
||||
def _should_rotate_audit_key(self) -> bool:
|
||||
"""Check if audit key needs rotation"""
|
||||
# In production, check last rotation time
|
||||
return self._audit_key is None
|
||||
|
||||
async def _reencrypt_transactions(
|
||||
self,
|
||||
participant_id: str,
|
||||
old_key_pair: KeyPair,
|
||||
new_key_pair: KeyPair
|
||||
):
|
||||
"""Re-encrypt active transactions with new key"""
|
||||
# This would be implemented in production
|
||||
# For now, just log the action
|
||||
logger.info(f"Would re-encrypt transactions for {participant_id}")
|
||||
pass
|
||||
|
||||
|
||||
class KeyStorageBackend:
    """Interface that concrete key storage backends must implement.

    Every method raises ``NotImplementedError`` here; subclasses such as
    ``FileKeyStorage`` supply the actual persistence logic.
    """

    async def store_key_pair(self, key_pair: KeyPair) -> bool:
        """Persist a participant key pair; return True on success."""
        raise NotImplementedError

    async def get_key_pair(self, participant_id: str) -> Optional[KeyPair]:
        """Fetch the key pair for *participant_id*, or None if absent."""
        raise NotImplementedError

    def get_key_pair_sync(self, participant_id: str) -> Optional[KeyPair]:
        """Blocking variant of :meth:`get_key_pair`."""
        raise NotImplementedError

    async def store_audit_key(self, key_pair: KeyPair) -> bool:
        """Persist the audit escrow key pair; return True on success."""
        raise NotImplementedError

    async def get_audit_key(self) -> Optional[KeyPair]:
        """Fetch the audit escrow key pair, or None if absent."""
        raise NotImplementedError

    async def list_participants(self) -> List[str]:
        """Return the IDs of all participants with stored keys."""
        raise NotImplementedError

    async def revoke_keys(self, participant_id: str, reason: str) -> bool:
        """Mark a participant's keys as revoked; return True on success."""
        raise NotImplementedError

    async def log_rotation(self, rotation_log: KeyRotationLog) -> bool:
        """Append a key rotation event to the audit trail."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class FileKeyStorage(KeyStorageBackend):
    """File-based key storage for development.

    Layout (all under ``storage_path``):
      - ``<participant_id>.json``  public metadata (key, algorithm, version)
      - ``<participant_id>.priv``  raw private key bytes
      - ``audit.json`` / ``audit.priv``  the audit escrow key pair
      - ``revoked/``  revoked key files are moved here rather than deleted
      - ``rotations.log``  append-only JSON-lines rotation audit trail

    NOTE(review): private keys are written to disk unencrypted (see the
    in-line comment below) — this backend must not be used in production.
    """

    def __init__(self, storage_path: str):
        # Root directory for all key files; created on first use.
        self.storage_path = storage_path
        os.makedirs(storage_path, exist_ok=True)

    async def store_key_pair(self, key_pair: KeyPair) -> bool:
        """Store a key pair as a metadata JSON file plus a raw private-key file.

        Returns True on success, False on any I/O or serialization error.
        """
        try:
            file_path = os.path.join(self.storage_path, f"{key_pair.participant_id}.json")

            # Store private key in separate encrypted file
            private_path = os.path.join(self.storage_path, f"{key_pair.participant_id}.priv")

            # In production, encrypt private key with master key
            with open(private_path, "wb") as f:
                f.write(key_pair.private_key)

            # Store public metadata (public key is base64 so it is JSON-safe).
            metadata = {
                "participant_id": key_pair.participant_id,
                "public_key": base64.b64encode(key_pair.public_key).decode(),
                "algorithm": key_pair.algorithm,
                "created_at": key_pair.created_at.isoformat(),
                "version": key_pair.version
            }

            with open(file_path, "w") as f:
                json.dump(metadata, f)

            return True

        except Exception as e:
            logger.error(f"Failed to store key pair: {e}")
            return False

    async def get_key_pair(self, participant_id: str) -> Optional[KeyPair]:
        """Get key pair from file (async facade over the sync loader)."""
        return self.get_key_pair_sync(participant_id)

    def get_key_pair_sync(self, participant_id: str) -> Optional[KeyPair]:
        """Load a key pair from disk.

        Returns None when either file is missing or unreadable; errors are
        logged rather than raised.
        """
        try:
            file_path = os.path.join(self.storage_path, f"{participant_id}.json")
            private_path = os.path.join(self.storage_path, f"{participant_id}.priv")

            # Both halves must exist; treat a partial pair as "not found".
            if not os.path.exists(file_path) or not os.path.exists(private_path):
                return None

            # Load metadata
            with open(file_path, "r") as f:
                metadata = json.load(f)

            # Load private key
            with open(private_path, "rb") as f:
                private_key = f.read()

            return KeyPair(
                participant_id=metadata["participant_id"],
                private_key=private_key,
                public_key=base64.b64decode(metadata["public_key"]),
                algorithm=metadata["algorithm"],
                created_at=datetime.fromisoformat(metadata["created_at"]),
                version=metadata["version"]
            )

        except Exception as e:
            logger.error(f"Failed to get key pair: {e}")
            return None

    async def store_audit_key(self, key_pair: KeyPair) -> bool:
        """Store the audit escrow key under the fixed ``audit.*`` filenames."""
        audit_path = os.path.join(self.storage_path, "audit.json")
        audit_priv_path = os.path.join(self.storage_path, "audit.priv")

        try:
            # Store private key
            with open(audit_priv_path, "wb") as f:
                f.write(key_pair.private_key)

            # Store metadata
            metadata = {
                "participant_id": "audit",
                "public_key": base64.b64encode(key_pair.public_key).decode(),
                "algorithm": key_pair.algorithm,
                "created_at": key_pair.created_at.isoformat(),
                "version": key_pair.version
            }

            with open(audit_path, "w") as f:
                json.dump(metadata, f)

            return True

        except Exception as e:
            logger.error(f"Failed to store audit key: {e}")
            return False

    async def get_audit_key(self) -> Optional[KeyPair]:
        """Get the audit key (stored under the reserved id "audit")."""
        return self.get_key_pair_sync("audit")

    async def list_participants(self) -> List[str]:
        """List all participants by scanning metadata files in the store root.

        The reserved ``audit.json`` entry is excluded. Files in ``revoked/``
        are not scanned, so revoked participants do not appear.
        """
        participants = []
        for file in os.listdir(self.storage_path):
            if file.endswith(".json") and file != "audit.json":
                participant_id = file[:-5]  # Remove .json
                participants.append(participant_id)
        return participants

    async def revoke_keys(self, participant_id: str, reason: str) -> bool:
        """Revoke keys by moving the files out of the active store.

        The files are moved to a ``revoked/`` subfolder (not deleted) so the
        material remains available for audits. Returns True on success.
        NOTE(review): *reason* is accepted but not recorded anywhere here.
        """
        try:
            file_path = os.path.join(self.storage_path, f"{participant_id}.json")
            private_path = os.path.join(self.storage_path, f"{participant_id}.priv")

            # Move to revoked folder instead of deleting
            revoked_path = os.path.join(self.storage_path, "revoked")
            os.makedirs(revoked_path, exist_ok=True)

            if os.path.exists(file_path):
                os.rename(file_path, os.path.join(revoked_path, f"{participant_id}.json"))
            if os.path.exists(private_path):
                os.rename(private_path, os.path.join(revoked_path, f"{participant_id}.priv"))

            return True

        except Exception as e:
            logger.error(f"Failed to revoke keys: {e}")
            return False

    async def log_rotation(self, rotation_log: KeyRotationLog) -> bool:
        """Append one JSON line per rotation event to ``rotations.log``."""
        log_path = os.path.join(self.storage_path, "rotations.log")

        try:
            with open(log_path, "a") as f:
                f.write(json.dumps({
                    "participant_id": rotation_log.participant_id,
                    "old_version": rotation_log.old_version,
                    "new_version": rotation_log.new_version,
                    "rotated_at": rotation_log.rotated_at.isoformat(),
                    "reason": rotation_log.reason
                }) + "\n")

            return True

        except Exception as e:
            logger.error(f"Failed to log rotation: {e}")
            return False
|
||||
|
||||
|
||||
class KeyManagementError(Exception):
    """Base class for all key-management failures."""
|
||||
|
||||
|
||||
class KeyNotFoundError(KeyManagementError):
    """Raised when a requested key does not exist in storage."""
|
||||
|
||||
|
||||
class AccessDeniedError(KeyManagementError):
    """Raised when a caller is not authorized for the requested key access."""
|
||||
526
apps/coordinator-api/src/app/services/quota_enforcement.py
Normal file
526
apps/coordinator-api/src/app/services/quota_enforcement.py
Normal file
@@ -0,0 +1,526 @@
|
||||
"""
|
||||
Resource quota enforcement service for multi-tenant AITBC coordinator
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, Optional, List
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, update, and_, func
|
||||
from contextlib import asynccontextmanager
|
||||
import redis
|
||||
import json
|
||||
|
||||
from ..models.multitenant import TenantQuota, UsageRecord, Tenant
|
||||
from ..exceptions import QuotaExceededError, TenantError
|
||||
from ..middleware.tenant_context import get_current_tenant_id
|
||||
|
||||
|
||||
class QuotaEnforcementService:
    """Service for enforcing tenant resource quotas.

    Quota definitions and usage live in the database (TenantQuota /
    UsageRecord); an optional Redis client acts as a read-through cache
    for quota rows and running usage counters.
    """

    def __init__(self, db: Session, redis_client: Optional[redis.Redis] = None):
        # SQLAlchemy session used for all quota reads and writes.
        self.db = db
        # Optional cache; every caching branch below is skipped when None.
        self.redis = redis_client
        # NOTE(review): __import__('logging') is an obscure way to get the
        # module — a plain top-level `import logging` would be clearer.
        self.logger = __import__('logging').getLogger(f"aitbc.{self.__class__.__name__}")

        # Cache for quota lookups
        # NOTE(review): _quota_cache is never read or written anywhere in
        # this class — apparently dead state.
        self._quota_cache = {}
        self._cache_ttl = 300  # 5 minutes

    async def check_quota(
        self,
        resource_type: str,
        quantity: float,
        tenant_id: Optional[str] = None
    ) -> bool:
        """Check if tenant has sufficient quota for a resource.

        Returns True when the request fits. Note it never returns False:
        an insufficient or missing quota raises QuotaExceededError, and a
        missing tenant context raises TenantError.
        """

        tenant_id = tenant_id or get_current_tenant_id()
        if not tenant_id:
            raise TenantError("No tenant context found")

        # Get current quota and usage
        quota = await self._get_current_quota(tenant_id, resource_type)

        if not quota:
            # No quota set, check if unlimited plan
            tenant = await self._get_tenant(tenant_id)
            if tenant and tenant.plan in ["enterprise", "unlimited"]:
                return True
            raise QuotaExceededError(f"No quota configured for {resource_type}")

        # Check if adding quantity would exceed limit
        current_usage = await self._get_current_usage(tenant_id, resource_type)

        if current_usage + quantity > quota.limit_value:
            # Log quota exceeded
            self.logger.warning(
                f"Quota exceeded for tenant {tenant_id}: "
                f"{resource_type} {current_usage + quantity}/{quota.limit_value}"
            )

            raise QuotaExceededError(
                f"Quota exceeded for {resource_type}: "
                f"{current_usage + quantity}/{quota.limit_value}"
            )

        return True

    async def consume_quota(
        self,
        resource_type: str,
        quantity: float,
        resource_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        tenant_id: Optional[str] = None
    ) -> UsageRecord:
        """Consume quota and record usage.

        Checks quota first, then writes a UsageRecord, bumps the quota's
        used_value, refreshes the Redis usage counter, and commits.

        NOTE(review): check and consume are not atomic — concurrent callers
        can race past the limit between check_quota and the commit below.
        """

        tenant_id = tenant_id or get_current_tenant_id()
        if not tenant_id:
            raise TenantError("No tenant context found")

        # Check quota first
        await self.check_quota(resource_type, quantity, tenant_id)

        # Create usage record (usage_start == usage_end: point-in-time event).
        usage_record = UsageRecord(
            tenant_id=tenant_id,
            resource_type=resource_type,
            resource_id=resource_id,
            quantity=quantity,
            unit=self._get_unit_for_resource(resource_type),
            unit_price=await self._get_unit_price(resource_type),
            total_cost=await self._calculate_cost(resource_type, quantity),
            currency="USD",
            usage_start=datetime.utcnow(),
            usage_end=datetime.utcnow(),
            metadata=metadata or {}
        )

        self.db.add(usage_record)

        # Update quota usage
        await self._update_quota_usage(tenant_id, resource_type, quantity)

        # Update cache (only incremented when the counter already exists;
        # otherwise the next _get_current_usage repopulates it from the DB).
        cache_key = f"quota_usage:{tenant_id}:{resource_type}"
        if self.redis:
            current = self.redis.get(cache_key)
            if current:
                self.redis.incrbyfloat(cache_key, quantity)
                self.redis.expire(cache_key, self._cache_ttl)

        self.db.commit()
        self.logger.info(
            f"Consumed quota: tenant={tenant_id}, "
            f"resource={resource_type}, quantity={quantity}"
        )

        return usage_record

    async def release_quota(
        self,
        resource_type: str,
        quantity: float,
        usage_record_id: str,
        tenant_id: Optional[str] = None
    ):
        """Release quota (e.g., when job completes early).

        Decrements the referenced UsageRecord and the quota's used_value by
        *quantity*; a no-op when the usage record does not belong to the
        tenant (rowcount 0).
        """

        tenant_id = tenant_id or get_current_tenant_id()
        if not tenant_id:
            raise TenantError("No tenant context found")

        # Update usage record (tenant_id in the WHERE clause prevents
        # releasing another tenant's usage).
        stmt = update(UsageRecord).where(
            and_(
                UsageRecord.id == usage_record_id,
                UsageRecord.tenant_id == tenant_id
            )
        ).values(
            quantity=UsageRecord.quantity - quantity,
            total_cost=UsageRecord.total_cost - await self._calculate_cost(resource_type, quantity)
        )

        result = self.db.execute(stmt)

        if result.rowcount > 0:
            # Update quota usage (negative delta = release).
            await self._update_quota_usage(tenant_id, resource_type, -quantity)

            # Update cache
            cache_key = f"quota_usage:{tenant_id}:{resource_type}"
            if self.redis:
                current = self.redis.get(cache_key)
                if current:
                    self.redis.incrbyfloat(cache_key, -quantity)
                    self.redis.expire(cache_key, self._cache_ttl)

            self.db.commit()
            self.logger.info(
                f"Released quota: tenant={tenant_id}, "
                f"resource={resource_type}, quantity={quantity}"
            )

    async def get_quota_status(
        self,
        resource_type: Optional[str] = None,
        tenant_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Get current quota status for a tenant.

        Returns a dict with per-resource usage plus a summary counting
        resources over their limit (>=100%) and near it (>=80%).
        """

        tenant_id = tenant_id or get_current_tenant_id()
        if not tenant_id:
            raise TenantError("No tenant context found")

        # Get all quotas for tenant
        stmt = select(TenantQuota).where(
            and_(
                TenantQuota.tenant_id == tenant_id,
                TenantQuota.is_active == True
            )
        )

        if resource_type:
            stmt = stmt.where(TenantQuota.resource_type == resource_type)

        quotas = self.db.execute(stmt).scalars().all()

        status = {
            "tenant_id": tenant_id,
            "quotas": {},
            "summary": {
                "total_resources": len(quotas),
                "over_limit": 0,
                "near_limit": 0
            }
        }

        for quota in quotas:
            current_usage = await self._get_current_usage(tenant_id, quota.resource_type)
            # Guard against zero-limit quotas to avoid division by zero.
            usage_percent = (current_usage / quota.limit_value) * 100 if quota.limit_value > 0 else 0

            quota_status = {
                "limit": float(quota.limit_value),
                "used": float(current_usage),
                "remaining": float(quota.limit_value - current_usage),
                "usage_percent": round(usage_percent, 2),
                "period": quota.period_type,
                "period_start": quota.period_start.isoformat(),
                "period_end": quota.period_end.isoformat()
            }

            status["quotas"][quota.resource_type] = quota_status

            # Update summary
            if usage_percent >= 100:
                status["summary"]["over_limit"] += 1
            elif usage_percent >= 80:
                status["summary"]["near_limit"] += 1

        return status

    @asynccontextmanager
    async def quota_reservation(
        self,
        resource_type: str,
        quantity: float,
        timeout: int = 300,  # 5 minutes
        tenant_id: Optional[str] = None
    ):
        """Context manager for temporary quota reservation.

        NOTE(review): the "reservation" only records intent in Redis with a
        TTL — it does not actually decrement quota, so it will not block
        concurrent consumers.
        """

        tenant_id = tenant_id or get_current_tenant_id()
        reservation_id = f"reserve:{tenant_id}:{resource_type}:{datetime.utcnow().timestamp()}"

        try:
            # Reserve quota (raises QuotaExceededError if it would not fit).
            await self.check_quota(resource_type, quantity, tenant_id)

            # Store reservation in Redis
            if self.redis:
                reservation_data = {
                    "tenant_id": tenant_id,
                    "resource_type": resource_type,
                    "quantity": quantity,
                    "created_at": datetime.utcnow().isoformat()
                }
                self.redis.setex(
                    f"reservation:{reservation_id}",
                    timeout,
                    json.dumps(reservation_data)
                )

            yield reservation_id

        finally:
            # Clean up reservation
            if self.redis:
                self.redis.delete(f"reservation:{reservation_id}")

    async def reset_quota_period(self, tenant_id: str, resource_type: str):
        """Reset quota for a new period.

        Recomputes period boundaries (monthly / weekly / anything-else =
        daily), zeroes used_value, commits, and drops the cached counter.
        Silently returns when no active quota exists.
        """

        # Get current quota
        stmt = select(TenantQuota).where(
            and_(
                TenantQuota.tenant_id == tenant_id,
                TenantQuota.resource_type == resource_type,
                TenantQuota.is_active == True
            )
        )

        quota = self.db.execute(stmt).scalar_one_or_none()

        if not quota:
            return

        # Calculate new period
        now = datetime.utcnow()
        if quota.period_type == "monthly":
            period_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
            # +32 days always lands in the next month; snapping to day 1 and
            # subtracting a day yields the last day of the current month.
            period_end = (period_start + timedelta(days=32)).replace(day=1) - timedelta(days=1)
        elif quota.period_type == "weekly":
            days_since_monday = now.weekday()
            period_start = (now - timedelta(days=days_since_monday)).replace(
                hour=0, minute=0, second=0, microsecond=0
            )
            period_end = period_start + timedelta(days=6)
        else:  # daily
            period_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
            period_end = period_start + timedelta(days=1)

        # Update quota
        quota.period_start = period_start
        quota.period_end = period_end
        quota.used_value = 0

        self.db.commit()

        # Clear cache
        cache_key = f"quota_usage:{tenant_id}:{resource_type}"
        if self.redis:
            self.redis.delete(cache_key)

        self.logger.info(
            f"Reset quota period: tenant={tenant_id}, "
            f"resource={resource_type}, period={quota.period_type}"
        )

    async def get_quota_alerts(self, tenant_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Get quota alerts for tenants approaching or exceeding limits.

        Thresholds: >=100% critical, >=90% warning, >=80% info.
        """

        tenant_id = tenant_id or get_current_tenant_id()
        if not tenant_id:
            raise TenantError("No tenant context found")

        alerts = []
        status = await self.get_quota_status(tenant_id=tenant_id)

        for resource_type, quota_status in status["quotas"].items():
            usage_percent = quota_status["usage_percent"]

            if usage_percent >= 100:
                alerts.append({
                    "severity": "critical",
                    "resource_type": resource_type,
                    "message": f"Quota exceeded for {resource_type}",
                    "usage_percent": usage_percent,
                    "used": quota_status["used"],
                    "limit": quota_status["limit"]
                })
            elif usage_percent >= 90:
                alerts.append({
                    "severity": "warning",
                    "resource_type": resource_type,
                    "message": f"Quota almost exceeded for {resource_type}",
                    "usage_percent": usage_percent,
                    "used": quota_status["used"],
                    "limit": quota_status["limit"]
                })
            elif usage_percent >= 80:
                alerts.append({
                    "severity": "info",
                    "resource_type": resource_type,
                    "message": f"Quota usage high for {resource_type}",
                    "usage_percent": usage_percent,
                    "used": quota_status["used"],
                    "limit": quota_status["limit"]
                })

        return alerts

    # Private methods

    async def _get_current_quota(self, tenant_id: str, resource_type: str) -> Optional[TenantQuota]:
        """Get the active quota row for (tenant, resource), Redis-cached.

        A cached entry is used only while its period_end is in the future;
        otherwise the database is queried and the result re-cached.
        """

        cache_key = f"quota:{tenant_id}:{resource_type}"

        # Check cache first
        if self.redis:
            cached = self.redis.get(cache_key)
            if cached:
                quota_data = json.loads(cached)
                # NOTE(review): this rebuilds a detached TenantQuota from the
                # cached JSON; period_start/period_end come back as strings,
                # so the datetime comparison below assumes the model coerces
                # them — confirm against the TenantQuota definition.
                quota = TenantQuota(**quota_data)
                # Check if still valid
                if quota.period_end >= datetime.utcnow():
                    return quota

        # Query database (only quotas whose period covers "now").
        stmt = select(TenantQuota).where(
            and_(
                TenantQuota.tenant_id == tenant_id,
                TenantQuota.resource_type == resource_type,
                TenantQuota.is_active == True,
                TenantQuota.period_start <= datetime.utcnow(),
                TenantQuota.period_end >= datetime.utcnow()
            )
        )

        quota = self.db.execute(stmt).scalar_one_or_none()

        # Cache result
        if quota and self.redis:
            quota_data = {
                "id": str(quota.id),
                "tenant_id": str(quota.tenant_id),
                "resource_type": quota.resource_type,
                "limit_value": float(quota.limit_value),
                "used_value": float(quota.used_value),
                "period_start": quota.period_start.isoformat(),
                "period_end": quota.period_end.isoformat()
            }
            self.redis.setex(
                cache_key,
                self._cache_ttl,
                json.dumps(quota_data)
            )

        return quota

    async def _get_current_usage(self, tenant_id: str, resource_type: str) -> float:
        """Get current usage for tenant and resource type (Redis-cached).

        NOTE(review): the DB fallback sums usage since the start of the
        calendar month regardless of the quota's actual period_type — weekly
        and daily quotas may over-count. Confirm intent.
        """

        cache_key = f"quota_usage:{tenant_id}:{resource_type}"

        # Check cache first
        if self.redis:
            cached = self.redis.get(cache_key)
            if cached:
                return float(cached)

        # Query database
        stmt = select(func.sum(UsageRecord.quantity)).where(
            and_(
                UsageRecord.tenant_id == tenant_id,
                UsageRecord.resource_type == resource_type,
                UsageRecord.usage_start >= func.date_trunc('month', func.current_date())
            )
        )

        result = self.db.execute(stmt).scalar()
        # SUM over zero rows yields NULL -> treat as 0.0.
        usage = float(result) if result else 0.0

        # Cache result
        if self.redis:
            self.redis.setex(cache_key, self._cache_ttl, str(usage))

        return usage

    async def _update_quota_usage(self, tenant_id: str, resource_type: str, quantity: float):
        """Atomically add *quantity* (may be negative) to the quota's used_value.

        Does not commit; callers commit as part of their larger unit of work.
        """

        stmt = update(TenantQuota).where(
            and_(
                TenantQuota.tenant_id == tenant_id,
                TenantQuota.resource_type == resource_type,
                TenantQuota.is_active == True
            )
        ).values(
            used_value=TenantQuota.used_value + quantity
        )

        self.db.execute(stmt)

    async def _get_tenant(self, tenant_id: str) -> Optional[Tenant]:
        """Get tenant by ID, or None when it does not exist."""
        stmt = select(Tenant).where(Tenant.id == tenant_id)
        return self.db.execute(stmt).scalar_one_or_none()

    def _get_unit_for_resource(self, resource_type: str) -> str:
        """Get the billing unit label for a resource type ("units" fallback)."""
        unit_map = {
            "gpu_hours": "hours",
            "storage_gb": "gb",
            "api_calls": "calls",
            "bandwidth_gb": "gb",
            "compute_hours": "hours"
        }
        return unit_map.get(resource_type, "units")

    async def _get_unit_price(self, resource_type: str) -> float:
        """Get unit price for resource type (0.0 for unknown types)."""
        # In a real implementation, this would come from a pricing table
        price_map = {
            "gpu_hours": 0.50,     # $0.50 per hour
            "storage_gb": 0.02,    # $0.02 per GB per month
            "api_calls": 0.0001,   # $0.0001 per call
            "bandwidth_gb": 0.01,  # $0.01 per GB
            "compute_hours": 0.30  # $0.30 per hour
        }
        return price_map.get(resource_type, 0.0)

    async def _calculate_cost(self, resource_type: str, quantity: float) -> float:
        """Calculate cost for resource usage as unit price x quantity."""
        unit_price = await self._get_unit_price(resource_type)
        return unit_price * quantity
|
||||
|
||||
|
||||
class QuotaMiddleware:
    """Middleware to enforce quotas on API endpoints.

    Wraps a :class:`QuotaEnforcementService` and maps each known endpoint
    to the resource type and base cost a single call consumes. Endpoints
    not listed in ``endpoint_costs`` are never quota-checked.
    """

    def __init__(self, quota_service: QuotaEnforcementService):
        # Local import keeps the fix self-contained; replaces the previous
        # __import__('logging') indirection, which defeated static analysis.
        import logging

        self.quota_service = quota_service
        self.logger = logging.getLogger(f"aitbc.{self.__class__.__name__}")

        # Resource costs per endpoint: base cost charged for every call,
        # on top of any estimated/actual cost supplied by the caller.
        self.endpoint_costs = {
            "/api/v1/jobs": {"resource": "compute_hours", "cost": 0.1},
            "/api/v1/models": {"resource": "storage_gb", "cost": 0.1},
            "/api/v1/data": {"resource": "storage_gb", "cost": 0.05},
            "/api/v1/analytics": {"resource": "api_calls", "cost": 1}
        }

    async def check_endpoint_quota(self, endpoint: str, estimated_cost: float = 0):
        """Verify a call to *endpoint* fits within the tenant's quota.

        Args:
            endpoint: Request path used to look up the resource mapping.
            estimated_cost: Extra cost expected beyond the base cost.

        Raises:
            QuotaExceededError: Propagated (after logging) when the quota
                service rejects the request.
        """
        resource_config = self.endpoint_costs.get(endpoint)
        if not resource_config:
            return  # No quota check for this endpoint

        try:
            await self.quota_service.check_quota(
                resource_config["resource"],
                resource_config["cost"] + estimated_cost
            )
        except QuotaExceededError as e:
            self.logger.warning(f"Quota exceeded for endpoint {endpoint}: {e}")
            raise

    async def consume_endpoint_quota(self, endpoint: str, actual_cost: float = 0):
        """Record quota consumption after the endpoint has executed.

        Accounting failures are logged but intentionally swallowed so that
        a billing hiccup never fails the user's already-completed request.

        Args:
            endpoint: Request path used to look up the resource mapping.
            actual_cost: Extra cost actually incurred beyond the base cost.
        """
        resource_config = self.endpoint_costs.get(endpoint)
        if not resource_config:
            return

        try:
            await self.quota_service.consume_quota(
                resource_config["resource"],
                resource_config["cost"] + actual_cost
            )
        except Exception as e:
            self.logger.error(f"Failed to consume quota for {endpoint}: {e}")
            # Don't fail the request, just log the error
|
||||
@@ -10,6 +10,7 @@ from sqlmodel import Session
|
||||
|
||||
from ..config import settings
|
||||
from ..domain import Job, JobReceipt
|
||||
from .zk_proofs import zk_proof_service
|
||||
|
||||
|
||||
class ReceiptService:
|
||||
@@ -24,12 +25,13 @@ class ReceiptService:
|
||||
attest_bytes = bytes.fromhex(settings.receipt_attestation_key_hex)
|
||||
self._attestation_signer = ReceiptSigner(attest_bytes)
|
||||
|
||||
def create_receipt(
|
||||
async def create_receipt(
|
||||
self,
|
||||
job: Job,
|
||||
miner_id: str,
|
||||
job_result: Dict[str, Any] | None,
|
||||
result_metrics: Dict[str, Any] | None,
|
||||
privacy_level: Optional[str] = None,
|
||||
) -> Dict[str, Any] | None:
|
||||
if self._signer is None:
|
||||
return None
|
||||
@@ -67,6 +69,32 @@ class ReceiptService:
|
||||
attestation_payload.pop("attestations", None)
|
||||
attestation_payload.pop("signature", None)
|
||||
payload["attestations"].append(self._attestation_signer.sign(attestation_payload))
|
||||
|
||||
# Generate ZK proof if privacy is requested
|
||||
if privacy_level and zk_proof_service.is_enabled():
|
||||
try:
|
||||
# Create receipt model for ZK proof generation
|
||||
receipt_model = JobReceipt(
|
||||
job_id=job.id,
|
||||
receipt_id=payload["receipt_id"],
|
||||
payload=payload
|
||||
)
|
||||
|
||||
# Generate ZK proof
|
||||
zk_proof = await zk_proof_service.generate_receipt_proof(
|
||||
receipt=receipt_model,
|
||||
job_result=job_result or {},
|
||||
privacy_level=privacy_level
|
||||
)
|
||||
|
||||
if zk_proof:
|
||||
payload["zk_proof"] = zk_proof
|
||||
payload["privacy_level"] = privacy_level
|
||||
|
||||
except Exception as e:
|
||||
# Log error but don't fail receipt creation
|
||||
print(f"Failed to generate ZK proof: {e}")
|
||||
|
||||
receipt_row = JobReceipt(job_id=job.id, receipt_id=payload["receipt_id"], payload=payload)
|
||||
self.session.add(receipt_row)
|
||||
return payload
|
||||
|
||||
690
apps/coordinator-api/src/app/services/tenant_management.py
Normal file
690
apps/coordinator-api/src/app/services/tenant_management.py
Normal file
@@ -0,0 +1,690 @@
|
||||
"""
|
||||
Tenant management service for multi-tenant AITBC coordinator
|
||||
"""
|
||||
|
||||
import secrets
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any, List
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, update, delete, and_, or_, func
|
||||
|
||||
from ..models.multitenant import (
|
||||
Tenant, TenantUser, TenantQuota, TenantApiKey,
|
||||
TenantAuditLog, TenantStatus
|
||||
)
|
||||
from ..database import get_db
|
||||
from ..exceptions import TenantError, QuotaExceededError
|
||||
|
||||
|
||||
class TenantManagementService:
|
||||
"""Service for managing tenants in multi-tenant environment"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.logger = __import__('logging').getLogger(f"aitbc.{self.__class__.__name__}")
|
||||
|
||||
    async def create_tenant(
        self,
        name: str,
        contact_email: str,
        plan: str = "trial",
        domain: Optional[str] = None,
        settings: Optional[Dict[str, Any]] = None,
        features: Optional[Dict[str, Any]] = None
    ) -> Tenant:
        """Create a new tenant.

        Creates the tenant row in PENDING status, provisions the plan's
        default quotas, and writes a "tenant_created" audit event, all in
        one transaction (committed at the end).

        Args:
            name: Display name; also the basis for the generated slug.
            contact_email: Primary contact address.
            plan: Billing plan key; drives default quota provisioning.
            domain: Optional custom domain; must be unique across tenants.
            settings: Optional tenant settings blob (defaults to {}).
            features: Optional feature-flag blob (defaults to {}).

        Returns:
            The persisted Tenant.

        Raises:
            TenantError: If the generated slug or the domain is taken.
        """

        # Generate unique slug; uniqueness is enforced here, not by retrying
        # with a different slug.
        slug = self._generate_slug(name)
        if await self._tenant_exists(slug=slug):
            raise TenantError(f"Tenant with slug '{slug}' already exists")

        # Check domain uniqueness if provided
        if domain and await self._tenant_exists(domain=domain):
            raise TenantError(f"Domain '{domain}' is already in use")

        # Create tenant (starts in PENDING until explicitly activated).
        tenant = Tenant(
            name=name,
            slug=slug,
            domain=domain,
            contact_email=contact_email,
            plan=plan,
            status=TenantStatus.PENDING.value,
            settings=settings or {},
            features=features or {}
        )

        self.db.add(tenant)
        # Flush (not commit) so tenant.id is assigned for the steps below.
        self.db.flush()

        # Create default quotas
        await self._create_default_quotas(tenant.id, plan)

        # Log creation
        await self._log_audit_event(
            tenant_id=tenant.id,
            event_type="tenant_created",
            event_category="lifecycle",
            actor_id="system",
            actor_type="system",
            resource_type="tenant",
            resource_id=str(tenant.id),
            new_values={"name": name, "plan": plan}
        )

        self.db.commit()
        self.logger.info(f"Created tenant: {tenant.id} ({name})")

        return tenant
|
||||
|
||||
async def get_tenant(self, tenant_id: str) -> Optional[Tenant]:
|
||||
"""Get tenant by ID"""
|
||||
stmt = select(Tenant).where(Tenant.id == tenant_id)
|
||||
return self.db.execute(stmt).scalar_one_or_none()
|
||||
|
||||
async def get_tenant_by_slug(self, slug: str) -> Optional[Tenant]:
|
||||
"""Get tenant by slug"""
|
||||
stmt = select(Tenant).where(Tenant.slug == slug)
|
||||
return self.db.execute(stmt).scalar_one_or_none()
|
||||
|
||||
async def get_tenant_by_domain(self, domain: str) -> Optional[Tenant]:
    """Look up a single tenant by its custom domain.

    Returns:
        The matching ``Tenant`` row, or ``None`` when no tenant owns the domain.
    """
    query = select(Tenant).where(Tenant.domain == domain)
    result = self.db.execute(query)
    return result.scalar_one_or_none()
|
||||
|
||||
async def update_tenant(
    self,
    tenant_id: str,
    updates: Dict[str, Any],
    actor_id: str,
    actor_type: str = "user"
) -> Tenant:
    """Update tenant attributes, audit-log the change, and commit.

    Args:
        tenant_id: Primary key of the tenant to update.
        updates: Mapping of attribute name -> new value. Keys that are not
            attributes of ``Tenant`` are silently ignored.
        actor_id: Identifier of the actor performing the update (for audit).
        actor_type: Kind of actor, defaults to "user".

    Returns:
        The updated ``Tenant`` row.

    Raises:
        TenantError: If no tenant exists with ``tenant_id``.
    """

    tenant = await self.get_tenant(tenant_id)
    if not tenant:
        raise TenantError(f"Tenant not found: {tenant_id}")

    # Store old values for audit
    # Only a fixed subset of fields is snapshotted, even though ``updates``
    # may touch other attributes as well.
    old_values = {
        "name": tenant.name,
        "contact_email": tenant.contact_email,
        "billing_email": tenant.billing_email,
        "settings": tenant.settings,
        "features": tenant.features
    }

    # Apply updates
    # NOTE(review): there is no whitelist here — any existing Tenant
    # attribute (including e.g. status or slug) can be set via ``updates``;
    # confirm callers restrict the allowed keys.
    for key, value in updates.items():
        if hasattr(tenant, key):
            setattr(tenant, key, value)

    tenant.updated_at = datetime.utcnow()

    # Log update
    await self._log_audit_event(
        tenant_id=tenant.id,
        event_type="tenant_updated",
        event_category="lifecycle",
        actor_id=actor_id,
        actor_type=actor_type,
        resource_type="tenant",
        resource_id=str(tenant.id),
        old_values=old_values,
        new_values=updates
    )

    self.db.commit()
    self.logger.info(f"Updated tenant: {tenant_id}")

    return tenant
|
||||
|
||||
async def activate_tenant(
    self,
    tenant_id: str,
    actor_id: str,
    actor_type: str = "user"
) -> Tenant:
    """Transition a tenant to the ACTIVE status.

    Idempotent: an already-active tenant is returned unchanged without a
    new audit entry.

    Args:
        tenant_id: Primary key of the tenant to activate.
        actor_id: Identifier of the actor performing the activation.
        actor_type: Kind of actor, defaults to "user".

    Returns:
        The activated ``Tenant`` row.

    Raises:
        TenantError: If no tenant exists with ``tenant_id``.
    """

    tenant = await self.get_tenant(tenant_id)
    if not tenant:
        raise TenantError(f"Tenant not found: {tenant_id}")

    if tenant.status == TenantStatus.ACTIVE.value:
        return tenant

    # Capture the actual previous status for the audit trail. The old code
    # hardcoded {"status": "pending"}, which mis-recorded activations of
    # suspended or inactive tenants (deactivate/suspend already capture
    # the real prior status).
    old_status = tenant.status

    tenant.status = TenantStatus.ACTIVE.value
    tenant.activated_at = datetime.utcnow()
    tenant.updated_at = datetime.utcnow()

    # Log activation
    await self._log_audit_event(
        tenant_id=tenant.id,
        event_type="tenant_activated",
        event_category="lifecycle",
        actor_id=actor_id,
        actor_type=actor_type,
        resource_type="tenant",
        resource_id=str(tenant.id),
        old_values={"status": old_status},
        new_values={"status": "active"}
    )

    self.db.commit()
    self.logger.info(f"Activated tenant: {tenant_id}")

    return tenant
|
||||
|
||||
async def deactivate_tenant(
    self,
    tenant_id: str,
    reason: Optional[str] = None,
    actor_id: str = "system",
    actor_type: str = "system"
) -> Tenant:
    """Transition a tenant to INACTIVE and revoke all of its API keys.

    Idempotent: an already-inactive tenant is returned unchanged, without
    revoking keys again or writing a new audit entry.

    Args:
        tenant_id: Primary key of the tenant to deactivate.
        reason: Optional free-text reason, recorded in the audit log.
        actor_id: Identifier of the actor performing the deactivation.
        actor_type: Kind of actor, defaults to "system".

    Returns:
        The deactivated ``Tenant`` row.

    Raises:
        TenantError: If no tenant exists with ``tenant_id``.
    """

    tenant = await self.get_tenant(tenant_id)
    if not tenant:
        raise TenantError(f"Tenant not found: {tenant_id}")

    if tenant.status == TenantStatus.INACTIVE.value:
        return tenant

    # Capture the previous status so the audit entry records the real
    # transition (may be active, suspended, or pending).
    old_status = tenant.status
    tenant.status = TenantStatus.INACTIVE.value
    tenant.deactivated_at = datetime.utcnow()
    tenant.updated_at = datetime.utcnow()

    # Revoke all API keys
    await self._revoke_all_api_keys(tenant_id)

    # Log deactivation
    await self._log_audit_event(
        tenant_id=tenant.id,
        event_type="tenant_deactivated",
        event_category="lifecycle",
        actor_id=actor_id,
        actor_type=actor_type,
        resource_type="tenant",
        resource_id=str(tenant.id),
        old_values={"status": old_status},
        new_values={"status": "inactive", "reason": reason}
    )

    self.db.commit()
    self.logger.info(f"Deactivated tenant: {tenant_id} (reason: {reason})")

    return tenant
|
||||
|
||||
async def suspend_tenant(
    self,
    tenant_id: str,
    reason: Optional[str] = None,
    actor_id: str = "system",
    actor_type: str = "system"
) -> Tenant:
    """Transition a tenant to SUSPENDED (a temporary lockout).

    NOTE(review): unlike activate/deactivate there is no idempotence guard
    here — suspending an already-suspended tenant writes another audit entry
    with old/new status both "suspended". API keys are also left active,
    unlike deactivation; confirm both behaviors are intentional.

    Args:
        tenant_id: Primary key of the tenant to suspend.
        reason: Optional free-text reason, recorded in the audit log.
        actor_id: Identifier of the actor performing the suspension.
        actor_type: Kind of actor, defaults to "system".

    Returns:
        The suspended ``Tenant`` row.

    Raises:
        TenantError: If no tenant exists with ``tenant_id``.
    """

    tenant = await self.get_tenant(tenant_id)
    if not tenant:
        raise TenantError(f"Tenant not found: {tenant_id}")

    old_status = tenant.status
    tenant.status = TenantStatus.SUSPENDED.value
    tenant.updated_at = datetime.utcnow()

    # Log suspension
    await self._log_audit_event(
        tenant_id=tenant.id,
        event_type="tenant_suspended",
        event_category="lifecycle",
        actor_id=actor_id,
        actor_type=actor_type,
        resource_type="tenant",
        resource_id=str(tenant.id),
        old_values={"status": old_status},
        new_values={"status": "suspended", "reason": reason}
    )

    self.db.commit()
    self.logger.warning(f"Suspended tenant: {tenant_id} (reason: {reason})")

    return tenant
|
||||
|
||||
async def add_user_to_tenant(
    self,
    tenant_id: str,
    user_id: str,
    role: str = "member",
    permissions: Optional[List[str]] = None,
    actor_id: str = "system"
) -> TenantUser:
    """Create a tenant-membership record for a user.

    Args:
        tenant_id: Tenant the user is being added to.
        user_id: Identifier of the user to add.
        role: Role within the tenant, defaults to "member".
        permissions: Optional explicit permission list; defaults to empty.
        actor_id: Actor recorded in the audit log.

    Returns:
        The newly created ``TenantUser`` row.

    Raises:
        TenantError: If the user already belongs to the tenant.
    """

    # Check if user already exists
    stmt = select(TenantUser).where(
        and_(TenantUser.tenant_id == tenant_id, TenantUser.user_id == user_id)
    )
    existing = self.db.execute(stmt).scalar_one_or_none()

    if existing:
        raise TenantError(f"User {user_id} already belongs to tenant {tenant_id}")

    # Create tenant user
    tenant_user = TenantUser(
        tenant_id=tenant_id,
        user_id=user_id,
        role=role,
        permissions=permissions or [],
        joined_at=datetime.utcnow()
    )

    self.db.add(tenant_user)

    # Log addition
    # NOTE(review): the session is not flushed before this point (compare
    # create_api_key, which flushes), so ``tenant_user.id`` may still be
    # unpopulated when recorded as resource_id — confirm.
    await self._log_audit_event(
        tenant_id=tenant_id,
        event_type="user_added",
        event_category="access",
        actor_id=actor_id,
        actor_type="system",
        resource_type="tenant_user",
        resource_id=str(tenant_user.id),
        new_values={"user_id": user_id, "role": role}
    )

    self.db.commit()
    self.logger.info(f"Added user {user_id} to tenant {tenant_id}")

    return tenant_user
|
||||
|
||||
async def remove_user_from_tenant(
    self,
    tenant_id: str,
    user_id: str,
    actor_id: str = "system"
) -> bool:
    """Delete a user's membership in a tenant.

    Args:
        tenant_id: Tenant the user is being removed from.
        user_id: Identifier of the user to remove.
        actor_id: Actor recorded in the audit log.

    Returns:
        True when a membership was found and deleted, False when the user
        did not belong to the tenant (no audit entry is written).
    """

    stmt = select(TenantUser).where(
        and_(TenantUser.tenant_id == tenant_id, TenantUser.user_id == user_id)
    )
    tenant_user = self.db.execute(stmt).scalar_one_or_none()

    if not tenant_user:
        return False

    # Store for audit
    # Snapshot the membership before deletion so the audit log keeps it.
    old_values = {
        "user_id": user_id,
        "role": tenant_user.role,
        "permissions": tenant_user.permissions
    }

    self.db.delete(tenant_user)

    # Log removal
    await self._log_audit_event(
        tenant_id=tenant_id,
        event_type="user_removed",
        event_category="access",
        actor_id=actor_id,
        actor_type="system",
        resource_type="tenant_user",
        resource_id=str(tenant_user.id),
        old_values=old_values
    )

    self.db.commit()
    self.logger.info(f"Removed user {user_id} from tenant {tenant_id}")

    return True
|
||||
|
||||
async def create_api_key(
    self,
    tenant_id: str,
    name: str,
    permissions: Optional[List[str]] = None,
    rate_limit: Optional[int] = None,
    allowed_ips: Optional[List[str]] = None,
    expires_at: Optional[datetime] = None,
    created_by: str = "system"
) -> TenantApiKey:
    """Create and persist a new API key for a tenant.

    Only the SHA-256 hash of the key is stored; the plaintext key is
    attached to the returned record's ``api_key`` attribute and is the
    only time it is available.

    Args:
        tenant_id: Tenant the key belongs to.
        name: Human-readable label for the key.
        permissions: Optional permission list; defaults to empty.
        rate_limit: Optional per-key rate limit.
        allowed_ips: Optional IP allowlist for the key.
        expires_at: Optional expiry timestamp.
        created_by: Actor recorded on the key and in the audit log.

    Returns:
        The persisted ``TenantApiKey`` row with the plaintext ``api_key``
        attribute set transiently.
    """

    # Generate secure key
    # key_id is a public identifier; api_key is the secret credential.
    key_id = f"ak_{secrets.token_urlsafe(16)}"
    api_key = f"ask_{secrets.token_urlsafe(32)}"
    key_hash = hashlib.sha256(api_key.encode()).hexdigest()
    # Prefix is stored so keys can be recognized in UIs without the secret.
    key_prefix = api_key[:8]

    # Create API key record
    api_key_record = TenantApiKey(
        tenant_id=tenant_id,
        key_id=key_id,
        key_hash=key_hash,
        key_prefix=key_prefix,
        name=name,
        permissions=permissions or [],
        rate_limit=rate_limit,
        allowed_ips=allowed_ips,
        expires_at=expires_at,
        created_by=created_by
    )

    self.db.add(api_key_record)
    # Flush so the generated primary key is available for the audit entry.
    self.db.flush()

    # Log creation
    await self._log_audit_event(
        tenant_id=tenant_id,
        event_type="api_key_created",
        event_category="security",
        actor_id=created_by,
        actor_type="user",
        resource_type="api_key",
        resource_id=str(api_key_record.id),
        new_values={
            "key_id": key_id,
            "name": name,
            "permissions": permissions,
            "rate_limit": rate_limit
        }
    )

    self.db.commit()
    self.logger.info(f"Created API key {key_id} for tenant {tenant_id}")

    # Return the key (only time it's shown)
    api_key_record.api_key = api_key
    return api_key_record
|
||||
|
||||
async def revoke_api_key(
    self,
    tenant_id: str,
    key_id: str,
    actor_id: str = "system"
) -> bool:
    """Revoke a single active API key belonging to a tenant.

    Args:
        tenant_id: Tenant that owns the key.
        key_id: Public key identifier (the ``ak_...`` value).
        actor_id: Actor recorded in the audit log.

    Returns:
        True when an active key was found and revoked; False when no active
        key matched (already-revoked keys are treated as not found).
    """

    stmt = select(TenantApiKey).where(
        and_(
            TenantApiKey.tenant_id == tenant_id,
            TenantApiKey.key_id == key_id,
            TenantApiKey.is_active == True
        )
    )
    api_key = self.db.execute(stmt).scalar_one_or_none()

    if not api_key:
        return False

    # Soft-revoke: keep the row for the audit trail, just mark it inactive.
    api_key.is_active = False
    api_key.revoked_at = datetime.utcnow()

    # Log revocation
    await self._log_audit_event(
        tenant_id=tenant_id,
        event_type="api_key_revoked",
        event_category="security",
        actor_id=actor_id,
        actor_type="user",
        resource_type="api_key",
        resource_id=str(api_key.id),
        old_values={"key_id": key_id, "is_active": True}
    )

    self.db.commit()
    self.logger.info(f"Revoked API key {key_id} for tenant {tenant_id}")

    return True
|
||||
|
||||
async def get_tenant_usage(
    self,
    tenant_id: str,
    resource_type: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None
) -> Dict[str, Any]:
    """Aggregate a tenant's usage records, grouped by resource type.

    Args:
        tenant_id: Tenant whose usage is summarized.
        resource_type: Optional filter to a single resource type.
        start_date: Window start; defaults to 30 days before ``end_date``.
        end_date: Window end; defaults to now (UTC).

    Returns:
        Dict with a "period" block (ISO timestamps) and a "by_resource"
        mapping of resource type -> {quantity, cost, records}.
    """

    from ..models.multitenant import UsageRecord

    # Default to last 30 days
    if not end_date:
        end_date = datetime.utcnow()
    if not start_date:
        start_date = end_date - timedelta(days=30)

    # Build query
    # Aggregates per resource type; only records fully inside the window
    # (usage_start >= start AND usage_end <= end) are counted.
    stmt = select(
        UsageRecord.resource_type,
        func.sum(UsageRecord.quantity).label("total_quantity"),
        func.sum(UsageRecord.total_cost).label("total_cost"),
        func.count(UsageRecord.id).label("record_count")
    ).where(
        and_(
            UsageRecord.tenant_id == tenant_id,
            UsageRecord.usage_start >= start_date,
            UsageRecord.usage_end <= end_date
        )
    )

    if resource_type:
        stmt = stmt.where(UsageRecord.resource_type == resource_type)

    stmt = stmt.group_by(UsageRecord.resource_type)

    results = self.db.execute(stmt).all()

    # Format results
    usage = {
        "period": {
            "start": start_date.isoformat(),
            "end": end_date.isoformat()
        },
        "by_resource": {}
    }

    for result in results:
        usage["by_resource"][result.resource_type] = {
            "quantity": float(result.total_quantity),
            "cost": float(result.total_cost),
            "records": result.record_count
        }

    return usage
|
||||
|
||||
async def get_tenant_quotas(self, tenant_id: str) -> List[TenantQuota]:
    """Return every active quota row configured for the given tenant."""
    query = select(TenantQuota).where(
        and_(
            TenantQuota.tenant_id == tenant_id,
            TenantQuota.is_active == True
        )
    )
    result = self.db.execute(query)
    return result.scalars().all()
|
||||
|
||||
async def check_quota(
    self,
    tenant_id: str,
    resource_type: str,
    quantity: float
) -> bool:
    """Check whether a tenant may consume ``quantity`` of a resource.

    Behavior is asymmetric by design of the current code:
    - no active quota for the current period -> returns False (deny by
      default, no exception);
    - quota present but the request would exceed it -> raises
      QuotaExceededError;
    - otherwise -> returns True.

    NOTE(review): callers must therefore handle both a False return and
    the exception; confirm this mixed contract is intended.

    Args:
        tenant_id: Tenant whose quota is checked.
        resource_type: Resource type the quota applies to.
        quantity: Amount the caller wants to consume.

    Returns:
        True if the consumption fits within the active quota.

    Raises:
        QuotaExceededError: If used + requested exceeds the quota limit.
    """

    # Get current quota
    # Only quotas whose period covers "now" are considered.
    stmt = select(TenantQuota).where(
        and_(
            TenantQuota.tenant_id == tenant_id,
            TenantQuota.resource_type == resource_type,
            TenantQuota.is_active == True,
            TenantQuota.period_start <= datetime.utcnow(),
            TenantQuota.period_end >= datetime.utcnow()
        )
    )

    quota = self.db.execute(stmt).scalar_one_or_none()

    if not quota:
        # No quota set, deny by default
        return False

    # Check if usage + quantity exceeds limit
    if quota.used_value + quantity > quota.limit_value:
        raise QuotaExceededError(
            f"Quota exceeded for {resource_type}: "
            f"{quota.used_value + quantity}/{quota.limit_value}"
        )

    return True
|
||||
|
||||
async def update_quota_usage(
    self,
    tenant_id: str,
    resource_type: str,
    quantity: float
):
    """Add ``quantity`` to the tenant's current-period quota usage.

    Silently does nothing when no active quota row covers the current
    moment for this resource type.

    Args:
        tenant_id: Tenant whose quota usage is incremented.
        resource_type: Resource type the quota applies to.
        quantity: Amount consumed (added to ``used_value``).
    """

    # Get current quota
    # Same period-covering-now selection as check_quota.
    stmt = select(TenantQuota).where(
        and_(
            TenantQuota.tenant_id == tenant_id,
            TenantQuota.resource_type == resource_type,
            TenantQuota.is_active == True,
            TenantQuota.period_start <= datetime.utcnow(),
            TenantQuota.period_end >= datetime.utcnow()
        )
    )

    quota = self.db.execute(stmt).scalar_one_or_none()

    if quota:
        # NOTE(review): read-modify-write without row locking — concurrent
        # updates can lose increments; confirm isolation is acceptable.
        quota.used_value += quantity
        self.db.commit()
|
||||
|
||||
# Private methods
|
||||
|
||||
def _generate_slug(self, name: str) -> str:
|
||||
"""Generate a unique slug from name"""
|
||||
import re
|
||||
# Convert to lowercase and replace spaces with hyphens
|
||||
base = re.sub(r'[^a-z0-9]+', '-', name.lower()).strip('-')
|
||||
# Add random suffix for uniqueness
|
||||
suffix = secrets.token_urlsafe(4)
|
||||
return f"{base}-{suffix}"
|
||||
|
||||
async def _tenant_exists(self, slug: Optional[str] = None, domain: Optional[str] = None) -> bool:
    """Return True when any tenant already uses the given slug and/or domain.

    With neither argument supplied the answer is vacuously False.
    """
    filters = []
    if slug:
        filters.append(Tenant.slug == slug)
    if domain:
        filters.append(Tenant.domain == domain)

    if not filters:
        return False

    count_query = select(func.count(Tenant.id)).where(or_(*filters))
    match_count = self.db.execute(count_query).scalar()
    return match_count > 0
|
||||
|
||||
async def _create_default_quotas(self, tenant_id: str, plan: str):
    """Create the plan's default monthly quota rows for a new tenant.

    Unknown plans fall back to the "trial" template. Rows are added to the
    session but not committed; the caller commits.

    Args:
        tenant_id: Tenant the quotas belong to.
        plan: Plan name selecting the quota template.
    """

    # Define quota templates by plan
    quota_templates = {
        "trial": {
            "gpu_hours": {"limit": 100, "period": "monthly"},
            "storage_gb": {"limit": 10, "period": "monthly"},
            "api_calls": {"limit": 10000, "period": "monthly"}
        },
        "basic": {
            "gpu_hours": {"limit": 500, "period": "monthly"},
            "storage_gb": {"limit": 100, "period": "monthly"},
            "api_calls": {"limit": 100000, "period": "monthly"}
        },
        "pro": {
            "gpu_hours": {"limit": 2000, "period": "monthly"},
            "storage_gb": {"limit": 1000, "period": "monthly"},
            "api_calls": {"limit": 1000000, "period": "monthly"}
        },
        "enterprise": {
            "gpu_hours": {"limit": 10000, "period": "monthly"},
            "storage_gb": {"limit": 10000, "period": "monthly"},
            "api_calls": {"limit": 10000000, "period": "monthly"}
        }
    }

    quotas = quota_templates.get(plan, quota_templates["trial"])

    # Create quota records
    # Jump at least one month ahead of the 1st, then step back to get the
    # last day of the current month.
    now = datetime.utcnow()
    period_end = now.replace(day=1) + timedelta(days=32)  # Next month
    period_end = period_end.replace(day=1) - timedelta(days=1)  # Last day of current month
    # NOTE(review): period_end keeps "now"'s time-of-day, so usage later in
    # the day on the month's last day falls outside the window — confirm.

    for resource_type, config in quotas.items():
        quota = TenantQuota(
            tenant_id=tenant_id,
            resource_type=resource_type,
            limit_value=config["limit"],
            used_value=0,
            period_type=config["period"],
            period_start=now,
            period_end=period_end
        )
        self.db.add(quota)
|
||||
|
||||
async def _revoke_all_api_keys(self, tenant_id: str):
    """Bulk-revoke every active API key belonging to a tenant.

    Marks each row inactive and stamps the revocation time in one UPDATE.
    The session is not committed here; the caller is responsible for that.
    """
    revoke_stmt = (
        update(TenantApiKey)
        .where(
            and_(
                TenantApiKey.tenant_id == tenant_id,
                TenantApiKey.is_active == True
            )
        )
        .values(is_active=False, revoked_at=datetime.utcnow())
    )
    self.db.execute(revoke_stmt)
|
||||
|
||||
async def _log_audit_event(
    self,
    tenant_id: str,
    event_type: str,
    event_category: str,
    actor_id: str,
    actor_type: str,
    resource_type: str,
    resource_id: Optional[str] = None,
    old_values: Optional[Dict[str, Any]] = None,
    new_values: Optional[Dict[str, Any]] = None,
    metadata: Optional[Dict[str, Any]] = None
):
    """Record an audit-log row for a tenant lifecycle/security/access event.

    The row is only added to the session — it is committed together with
    the calling operation's transaction, so the audit entry and the change
    it describes succeed or fail as one unit.

    Args:
        tenant_id: Tenant the event belongs to.
        event_type: Specific event name (e.g. "tenant_created").
        event_category: Coarse grouping ("lifecycle", "security", "access").
        actor_id: Who performed the action.
        actor_type: Kind of actor ("user", "system", ...).
        resource_type: Kind of resource affected.
        resource_id: Optional identifier of the affected resource.
        old_values: Optional snapshot of values before the change.
        new_values: Optional snapshot of values after the change.
        metadata: Optional extra context for the event.
    """

    audit_log = TenantAuditLog(
        tenant_id=tenant_id,
        event_type=event_type,
        event_category=event_category,
        actor_id=actor_id,
        actor_type=actor_type,
        resource_type=resource_type,
        resource_id=resource_id,
        old_values=old_values,
        new_values=new_values,
        metadata=metadata
    )

    self.db.add(audit_log)
|
||||
654
apps/coordinator-api/src/app/services/usage_tracking.py
Normal file
654
apps/coordinator-api/src/app/services/usage_tracking.py
Normal file
@@ -0,0 +1,654 @@
|
||||
"""
|
||||
Usage tracking and billing metrics service for multi-tenant AITBC coordinator
|
||||
"""
|
||||
|
||||
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, asdict
from datetime import datetime, timedelta
from decimal import Decimal
from typing import Dict, Any, Optional, List, Tuple

from sqlalchemy.orm import Session
from sqlalchemy import select, update, and_, or_, func, desc

from ..models.multitenant import (
    UsageRecord, Invoice, Tenant, TenantQuota,
    TenantMetric
)
from ..exceptions import BillingError, TenantError
from ..middleware.tenant_context import get_current_tenant_id
|
||||
|
||||
|
||||
@dataclass
class UsageSummary:
    """Usage summary for billing period.

    Aggregated, read-only view of a tenant's consumption over one window.
    """
    tenant_id: str  # tenant the summary covers
    period_start: datetime  # start of the billing window
    period_end: datetime  # end of the billing window
    resources: Dict[str, Dict[str, Any]]  # resource type -> quantity/cost/records breakdown
    total_cost: Decimal  # sum of all per-resource costs
    currency: str  # currency code, e.g. "USD"
|
||||
|
||||
|
||||
@dataclass
class BillingEvent:
    """Billing event for processing.

    Emitted when usage is recorded and consumed in batches by
    billing-event processing.
    """
    tenant_id: str  # tenant the event applies to
    event_type: str  # usage, quota_adjustment, credit, charge
    resource_type: Optional[str]  # resource involved, if any
    quantity: Decimal  # amount of the resource
    unit_price: Decimal  # price per unit at event time
    total_amount: Decimal  # quantity * unit_price
    currency: str  # currency code, e.g. "USD"
    timestamp: datetime  # when the event occurred
    metadata: Dict[str, Any]  # free-form extra context
|
||||
|
||||
|
||||
class UsageTrackingService:
    """Service for tracking usage and generating billing metrics"""

    def __init__(self, db: Session):
        """Bind the service to a database session and load pricing tables.

        Args:
            db: SQLAlchemy session used for all reads and writes.
        """
        self.db = db
        # Plain logging.getLogger instead of the previous
        # __import__('logging') hack, which defeats static analysis and
        # import-dependency tooling.
        self.logger = logging.getLogger(f"aitbc.{self.__class__.__name__}")
        # NOTE(review): this executor is never shut down in this class —
        # confirm the service's lifecycle closes it.
        self.executor = ThreadPoolExecutor(max_workers=4)

        # Pricing configuration: base price per unit, and whether tiered
        # volume discounts apply to the resource type.
        self.pricing_config = {
            "gpu_hours": {"unit_price": Decimal("0.50"), "tiered": True},
            "storage_gb": {"unit_price": Decimal("0.02"), "tiered": True},
            "api_calls": {"unit_price": Decimal("0.0001"), "tiered": False},
            "bandwidth_gb": {"unit_price": Decimal("0.01"), "tiered": False},
            "compute_hours": {"unit_price": Decimal("0.30"), "tiered": True}
        }

        # Tier pricing thresholds: multiplier applied to the base unit
        # price by volume bracket ("max": None means unbounded).
        self.tier_thresholds = {
            "gpu_hours": [
                {"min": 0, "max": 100, "multiplier": 1.0},
                {"min": 101, "max": 500, "multiplier": 0.9},
                {"min": 501, "max": 2000, "multiplier": 0.8},
                {"min": 2001, "max": None, "multiplier": 0.7}
            ],
            "storage_gb": [
                {"min": 0, "max": 100, "multiplier": 1.0},
                {"min": 101, "max": 1000, "multiplier": 0.85},
                {"min": 1001, "max": 10000, "multiplier": 0.75},
                {"min": 10001, "max": None, "multiplier": 0.65}
            ],
            "compute_hours": [
                {"min": 0, "max": 200, "multiplier": 1.0},
                {"min": 201, "max": 1000, "multiplier": 0.9},
                {"min": 1001, "max": 5000, "multiplier": 0.8},
                {"min": 5001, "max": None, "multiplier": 0.7}
            ]
        }
|
||||
|
||||
async def record_usage(
    self,
    tenant_id: str,
    resource_type: str,
    quantity: Decimal,
    unit_price: Optional[Decimal] = None,
    job_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None
) -> UsageRecord:
    """Persist a usage record for billing and emit a billing event.

    Args:
        tenant_id: Tenant the usage belongs to.
        resource_type: Resource consumed (e.g. "gpu_hours").
        quantity: Amount consumed.
        unit_price: Price per unit; derived from the tiered pricing
            tables when not supplied.
        job_id: Optional job that generated the usage.
        metadata: Optional extra context stored on the record and event.

    Returns:
        The committed ``UsageRecord`` row.
    """

    # Calculate unit price if not provided
    if not unit_price:
        unit_price = await self._calculate_unit_price(resource_type, quantity)

    # Calculate total cost
    total_cost = unit_price * quantity

    # Create usage record
    # NOTE(review): usage_start and usage_end are both set to "now", so
    # every record is a zero-length window stamped at recording time —
    # confirm consumers of these columns expect that.
    usage_record = UsageRecord(
        tenant_id=tenant_id,
        resource_type=resource_type,
        quantity=quantity,
        unit=self._get_unit_for_resource(resource_type),
        unit_price=unit_price,
        total_cost=total_cost,
        currency="USD",
        usage_start=datetime.utcnow(),
        usage_end=datetime.utcnow(),
        job_id=job_id,
        metadata=metadata or {}
    )

    self.db.add(usage_record)
    self.db.commit()

    # Emit billing event
    await self._emit_billing_event(BillingEvent(
        tenant_id=tenant_id,
        event_type="usage",
        resource_type=resource_type,
        quantity=quantity,
        unit_price=unit_price,
        total_amount=total_cost,
        currency="USD",
        timestamp=datetime.utcnow(),
        metadata=metadata or {}
    ))

    self.logger.info(
        f"Recorded usage: tenant={tenant_id}, "
        f"resource={resource_type}, quantity={quantity}, cost={total_cost}"
    )

    return usage_record
|
||||
|
||||
async def get_usage_summary(
    self,
    tenant_id: str,
    start_date: datetime,
    end_date: datetime,
    resource_type: Optional[str] = None
) -> UsageSummary:
    """Aggregate a tenant's usage records into a per-resource summary.

    Args:
        tenant_id: Tenant whose usage is summarized.
        start_date: Window start (records must start at or after this).
        end_date: Window end (records must end at or before this).
        resource_type: Optional filter to a single resource type.

    Returns:
        A ``UsageSummary`` with per-resource quantity/cost/record counts
        and the grand total cost.
    """

    # Build query
    stmt = select(
        UsageRecord.resource_type,
        func.sum(UsageRecord.quantity).label("total_quantity"),
        func.sum(UsageRecord.total_cost).label("total_cost"),
        func.count(UsageRecord.id).label("record_count"),
        func.avg(UsageRecord.unit_price).label("avg_unit_price")
    ).where(
        and_(
            UsageRecord.tenant_id == tenant_id,
            UsageRecord.usage_start >= start_date,
            UsageRecord.usage_end <= end_date
        )
    )

    if resource_type:
        stmt = stmt.where(UsageRecord.resource_type == resource_type)

    stmt = stmt.group_by(UsageRecord.resource_type)

    results = self.db.execute(stmt).all()

    # Build summary
    resources = {}
    total_cost = Decimal("0")

    for result in results:
        resources[result.resource_type] = {
            "quantity": float(result.total_quantity),
            "cost": float(result.total_cost),
            "records": result.record_count,
            "avg_unit_price": float(result.avg_unit_price)
        }
        # Route through str to keep Decimal addition exact regardless of
        # what numeric type the DB driver returns for the aggregate.
        total_cost += Decimal(str(result.total_cost))

    return UsageSummary(
        tenant_id=tenant_id,
        period_start=start_date,
        period_end=end_date,
        resources=resources,
        total_cost=total_cost,
        currency="USD"
    )
|
||||
|
||||
async def generate_invoice(
    self,
    tenant_id: str,
    period_start: datetime,
    period_end: datetime,
    due_days: int = 30
) -> Invoice:
    """Generate and persist a draft invoice for a billing period.

    Args:
        tenant_id: Tenant to invoice.
        period_start: Billing period start.
        period_end: Billing period end.
        due_days: Days after ``period_end`` until the invoice is due.

    Returns:
        The committed draft ``Invoice`` row.

    Raises:
        BillingError: If an invoice already exists for this exact period.
    """

    # Check if invoice already exists
    existing = await self._get_existing_invoice(tenant_id, period_start, period_end)
    if existing:
        raise BillingError(f"Invoice already exists for period {period_start} to {period_end}")

    # Get usage summary
    summary = await self.get_usage_summary(tenant_id, period_start, period_end)

    # Generate invoice number
    invoice_number = await self._generate_invoice_number(tenant_id)

    # Calculate line items
    line_items = []
    subtotal = Decimal("0")

    for resource_type, usage in summary.resources.items():
        line_item = {
            "description": f"{resource_type.replace('_', ' ').title()} Usage",
            "quantity": usage["quantity"],
            "unit_price": usage["avg_unit_price"],
            "amount": usage["cost"]
        }
        line_items.append(line_item)
        subtotal += Decimal(str(usage["cost"]))

    # Calculate tax (example: 10% for digital services)
    # NOTE(review): the tax rate is hardcoded here; confirm it should not
    # depend on tenant jurisdiction or configuration.
    tax_rate = Decimal("0.10")
    tax_amount = subtotal * tax_rate
    total_amount = subtotal + tax_amount

    # Create invoice
    invoice = Invoice(
        tenant_id=tenant_id,
        invoice_number=invoice_number,
        status="draft",
        period_start=period_start,
        period_end=period_end,
        due_date=period_end + timedelta(days=due_days),
        subtotal=subtotal,
        tax_amount=tax_amount,
        total_amount=total_amount,
        currency="USD",
        line_items=line_items
    )

    self.db.add(invoice)
    self.db.commit()

    self.logger.info(
        f"Generated invoice {invoice_number} for tenant {tenant_id}: "
        f"${total_amount}"
    )

    return invoice
|
||||
|
||||
async def get_billing_metrics(
    self,
    tenant_id: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None
) -> Dict[str, Any]:
    """Compute billing analytics over a window, optionally for one tenant.

    Args:
        tenant_id: Restrict metrics to one tenant; platform-wide when None.
        start_date: Window start; defaults to 30 days before ``end_date``.
        end_date: Window end; defaults to now (UTC).

    Returns:
        Dict with "period", "totals" (quantity/cost/records/active_tenants),
        "by_resource" breakdown, "top_tenants" (top 10 by cost, only in
        platform-wide mode), and a per-day "daily_trend" of cost.
    """

    # Default to last 30 days
    if not end_date:
        end_date = datetime.utcnow()
    if not start_date:
        start_date = end_date - timedelta(days=30)

    # Build base query
    base_conditions = [
        UsageRecord.usage_start >= start_date,
        UsageRecord.usage_end <= end_date
    ]

    if tenant_id:
        base_conditions.append(UsageRecord.tenant_id == tenant_id)

    # Total usage and cost
    stmt = select(
        func.sum(UsageRecord.quantity).label("total_quantity"),
        func.sum(UsageRecord.total_cost).label("total_cost"),
        func.count(UsageRecord.id).label("total_records"),
        func.count(func.distinct(UsageRecord.tenant_id)).label("active_tenants")
    ).where(and_(*base_conditions))

    totals = self.db.execute(stmt).first()

    # Usage by resource type
    stmt = select(
        UsageRecord.resource_type,
        func.sum(UsageRecord.quantity).label("quantity"),
        func.sum(UsageRecord.total_cost).label("cost")
    ).where(and_(*base_conditions)).group_by(UsageRecord.resource_type)

    by_resource = self.db.execute(stmt).all()

    # Top tenants by usage
    # Only meaningful in platform-wide mode; a single-tenant query would
    # trivially return that tenant.
    if not tenant_id:
        stmt = select(
            UsageRecord.tenant_id,
            func.sum(UsageRecord.total_cost).label("total_cost")
        ).where(and_(*base_conditions)).group_by(
            UsageRecord.tenant_id
        ).order_by(desc("total_cost")).limit(10)

        top_tenants = self.db.execute(stmt).all()
    else:
        top_tenants = []

    # Daily usage trend
    stmt = select(
        func.date(UsageRecord.usage_start).label("date"),
        func.sum(UsageRecord.total_cost).label("daily_cost")
    ).where(and_(*base_conditions)).group_by(
        func.date(UsageRecord.usage_start)
    ).order_by("date")

    daily_trend = self.db.execute(stmt).all()

    # Assemble metrics
    # SQL SUM over zero rows yields NULL, hence the "or 0" fallbacks.
    metrics = {
        "period": {
            "start": start_date.isoformat(),
            "end": end_date.isoformat()
        },
        "totals": {
            "quantity": float(totals.total_quantity or 0),
            "cost": float(totals.total_cost or 0),
            "records": totals.total_records or 0,
            "active_tenants": totals.active_tenants or 0
        },
        "by_resource": {
            r.resource_type: {
                "quantity": float(r.quantity),
                "cost": float(r.cost)
            }
            for r in by_resource
        },
        "top_tenants": [
            {
                "tenant_id": str(t.tenant_id),
                "cost": float(t.total_cost)
            }
            for t in top_tenants
        ],
        "daily_trend": [
            {
                "date": d.date.isoformat(),
                "cost": float(d.daily_cost)
            }
            for d in daily_trend
        ]
    }

    return metrics
|
||||
|
||||
async def process_billing_events(self, events: List[BillingEvent]) -> bool:
    """Process a batch of billing events.

    "usage" events are skipped (record_usage already persisted them); the
    remaining known event types are dispatched to their handlers. Returns
    True on success, False when any handler raises (the error is logged).
    """
    handlers = {
        "credit": self._apply_credit,
        "charge": self._apply_charge,
        "quota_adjustment": self._adjust_quota,
    }

    try:
        for event in events:
            if event.event_type == "usage":
                # Already recorded in record_usage
                continue
            handler = handlers.get(event.event_type)
            if handler is not None:
                await handler(event)
        return True
    except Exception as e:
        self.logger.error(f"Failed to process billing events: {e}")
        return False
|
||||
|
||||
async def export_usage_data(
    self,
    tenant_id: str,
    start_date: datetime,
    end_date: datetime,
    format: str = "csv"
) -> str:
    """Export a tenant's usage records for a period as CSV or JSON text.

    Args:
        tenant_id: Tenant whose records are exported.
        start_date: Window start.
        end_date: Window end.
        format: "csv" or "json".

    Returns:
        The serialized export as a string.

    Raises:
        BillingError: If ``format`` is not "csv" or "json".
    """

    # Get usage records
    stmt = select(UsageRecord).where(
        and_(
            UsageRecord.tenant_id == tenant_id,
            UsageRecord.usage_start >= start_date,
            UsageRecord.usage_end <= end_date
        )
    ).order_by(UsageRecord.usage_start)

    records = self.db.execute(stmt).scalars().all()

    if format == "csv":
        return await self._export_csv(records)
    elif format == "json":
        return await self._export_json(records)
    else:
        raise BillingError(f"Unsupported export format: {format}")
|
||||
|
||||
# Private methods
|
||||
|
||||
async def _calculate_unit_price(
    self,
    resource_type: str,
    quantity: Decimal
) -> Decimal:
    """Calculate the effective unit price with tiered volume discounts.

    Args:
        resource_type: Resource being priced; unknown types price at zero.
        quantity: Volume used to select the discount tier.

    Returns:
        The base unit price multiplied by the matching tier's multiplier;
        the plain base price for non-tiered resources or when the resource
        has no tier table configured.
    """

    config = self.pricing_config.get(resource_type)
    if not config:
        # Unknown resource types are billed at zero rather than erroring.
        return Decimal("0")

    base_price = config["unit_price"]

    if not config.get("tiered", False):
        return base_price

    # Find applicable tier
    tiers = self.tier_thresholds.get(resource_type, [])
    quantity_float = float(quantity)

    # Tiers are ordered by ascending volume, so matching on the upper
    # bound alone is sufficient. The previous "min <= q <= max" test left
    # gaps between integer tier edges (e.g. gpu_hours q=100.5 matched no
    # tier) and silently fell through to an arbitrary 0.5 multiplier; the
    # "min is None" branch was dead code (no tier has a None min).
    for tier in tiers:
        if tier["max"] is None or quantity_float <= tier["max"]:
            return base_price * Decimal(str(tier["multiplier"]))

    # A tiered resource without a configured tier table gets no discount.
    return base_price
|
||||
|
||||
def _get_unit_for_resource(self, resource_type: str) -> str:
    """Return the billing unit label for a resource type.

    Unknown resource types fall back to the generic "units".
    """
    known_units = {
        "gpu_hours": "hours",
        "storage_gb": "gb",
        "api_calls": "calls",
        "bandwidth_gb": "gb",
        "compute_hours": "hours"
    }
    return known_units.get(resource_type, "units")
|
||||
|
||||
async def _emit_billing_event(self, event: BillingEvent):
    """Emit a billing event for downstream processing.

    Placeholder implementation: the event is only logged at debug level.
    """
    # In a real implementation, this would publish to a message queue
    # For now, we'll just log it
    self.logger.debug(f"Emitting billing event: {event}")
|
||||
|
||||
async def _get_existing_invoice(
    self,
    tenant_id: str,
    period_start: datetime,
    period_end: datetime
) -> Optional[Invoice]:
    """Return the tenant's invoice covering exactly this period, or None.

    Matches on tenant id plus the exact period start/end timestamps.
    """
    # Successive .where() calls are AND-ed together by SQLAlchemy, so this
    # is equivalent to a single and_(...) clause.
    stmt = (
        select(Invoice)
        .where(Invoice.tenant_id == tenant_id)
        .where(Invoice.period_start == period_start)
        .where(Invoice.period_end == period_end)
    )
    return self.db.execute(stmt).scalar_one_or_none()
|
||||
|
||||
async def _generate_invoice_number(self, tenant_id: str) -> str:
    """Generate a unique, human-readable invoice number.

    Format: ``INV-{tenant.slug}-{YYYYMMDD}-{seq}`` where ``seq`` is a
    zero-padded per-tenant, per-day counter derived from the number of
    invoices the tenant has already created today.

    Raises:
        TenantError: if no tenant exists for *tenant_id*.

    NOTE(review): the counter is a COUNT(*) query rather than an atomic
    sequence (Redis counter / DB sequence), so two concurrent invoice
    creations could observe the same count and collide.
    """
    # Look up the tenant to obtain its slug for the invoice prefix.
    stmt = select(Tenant).where(Tenant.id == tenant_id)
    tenant = self.db.execute(stmt).scalar_one_or_none()

    if not tenant:
        raise TenantError(f"Tenant not found: {tenant_id}")

    date_str = datetime.utcnow().strftime("%Y%m%d")

    # Per-day sequence: count this tenant's invoices created today.
    # NOTE(review): the date string comes from utcnow() while the filter
    # uses the database's current_date(); if the DB session is not in UTC
    # these can disagree around midnight -- confirm the DB timezone.
    stmt = select(func.count(Invoice.id)).where(
        and_(
            Invoice.tenant_id == tenant_id,
            func.date(Invoice.created_at) == func.current_date()
        )
    )
    seq = self.db.execute(stmt).scalar() + 1

    return f"INV-{tenant.slug}-{date_str}-{seq:04d}"
|
||||
|
||||
async def _apply_credit(self, event: BillingEvent):
|
||||
"""Apply credit to tenant account"""
|
||||
# TODO: Implement credit application
|
||||
pass
|
||||
|
||||
async def _apply_charge(self, event: BillingEvent):
|
||||
"""Apply charge to tenant account"""
|
||||
# TODO: Implement charge application
|
||||
pass
|
||||
|
||||
async def _adjust_quota(self, event: BillingEvent):
|
||||
"""Adjust quota based on billing event"""
|
||||
# TODO: Implement quota adjustment
|
||||
pass
|
||||
|
||||
async def _export_csv(self, records: List[UsageRecord]) -> str:
|
||||
"""Export records to CSV"""
|
||||
import csv
|
||||
import io
|
||||
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
|
||||
# Header
|
||||
writer.writerow([
|
||||
"Timestamp", "Resource Type", "Quantity", "Unit",
|
||||
"Unit Price", "Total Cost", "Currency", "Job ID"
|
||||
])
|
||||
|
||||
# Data rows
|
||||
for record in records:
|
||||
writer.writerow([
|
||||
record.usage_start.isoformat(),
|
||||
record.resource_type,
|
||||
record.quantity,
|
||||
record.unit,
|
||||
record.unit_price,
|
||||
record.total_cost,
|
||||
record.currency,
|
||||
record.job_id or ""
|
||||
])
|
||||
|
||||
return output.getvalue()
|
||||
|
||||
async def _export_json(self, records: List[UsageRecord]) -> str:
|
||||
"""Export records to JSON"""
|
||||
import json
|
||||
|
||||
data = []
|
||||
for record in records:
|
||||
data.append({
|
||||
"timestamp": record.usage_start.isoformat(),
|
||||
"resource_type": record.resource_type,
|
||||
"quantity": float(record.quantity),
|
||||
"unit": record.unit,
|
||||
"unit_price": float(record.unit_price),
|
||||
"total_cost": float(record.total_cost),
|
||||
"currency": record.currency,
|
||||
"job_id": record.job_id,
|
||||
"metadata": record.metadata
|
||||
})
|
||||
|
||||
return json.dumps(data, indent=2)
|
||||
|
||||
|
||||
class BillingScheduler:
    """Scheduler for automated billing processes.

    While started, runs two background loops: a daily loop (quota resets
    and pending-event processing) and a monthly invoicing loop. Both loops
    log-and-retry on unexpected errors rather than dying.
    """

    def __init__(self, usage_service: "UsageTrackingService"):
        import logging  # local import: keeps the original lazy logging setup

        self.usage_service = usage_service
        self.logger = logging.getLogger(f"aitbc.{self.__class__.__name__}")
        self.running = False
        # Keep strong references to background tasks: asyncio holds only
        # weak references, so unreferenced tasks can be garbage-collected
        # mid-flight (the original discarded the create_task results).
        self._tasks: list = []

    async def start(self):
        """Start the background billing loops (idempotent)."""
        if self.running:
            return

        self.running = True
        self.logger.info("Billing scheduler started")

        self._tasks = [
            # Schedule daily tasks
            asyncio.create_task(self._daily_tasks()),
            # Schedule monthly invoicing
            asyncio.create_task(self._monthly_invoicing()),
        ]

    async def stop(self):
        """Stop the scheduler and cancel its background loops."""
        self.running = False
        for task in self._tasks:
            task.cancel()
        if self._tasks:
            # Wait for cancellation to land; swallow the CancelledErrors.
            await asyncio.gather(*self._tasks, return_exceptions=True)
        self._tasks = []
        self.logger.info("Billing scheduler stopped")

    async def _daily_tasks(self):
        """Run daily billing tasks, then sleep until the next UTC midnight."""
        while self.running:
            try:
                # Reset quotas for new periods
                await self._reset_daily_quotas()

                # Process pending billing events
                await self._process_pending_events()

                # Sleep until the start of the next UTC day.
                now = datetime.utcnow()
                next_day = (now + timedelta(days=1)).replace(
                    hour=0, minute=0, second=0, microsecond=0
                )
                await asyncio.sleep((next_day - now).total_seconds())

            except asyncio.CancelledError:
                raise  # let stop() cancellation terminate the loop
            except Exception as e:
                self.logger.error(f"Error in daily tasks: {e}")
                await asyncio.sleep(3600)  # Retry in 1 hour

    async def _monthly_invoicing(self):
        """Generate invoices on the first day of each month."""
        while self.running:
            try:
                now = datetime.utcnow()
                if now.day != 1:
                    # Not the 1st: sleep until the first day of next month.
                    # day=1 + 32 days always lands inside the next month.
                    next_month = (now.replace(day=1) + timedelta(days=32)).replace(day=1)
                    await asyncio.sleep((next_month - now).total_seconds())
                    continue

                # Generate invoices for all active tenants
                await self._generate_monthly_invoices()

                # Wait until next month
                next_month = (now.replace(day=1) + timedelta(days=32)).replace(day=1)
                await asyncio.sleep((next_month - now).total_seconds())

            except asyncio.CancelledError:
                raise  # let stop() cancellation terminate the loop
            except Exception as e:
                self.logger.error(f"Error in monthly invoicing: {e}")
                await asyncio.sleep(86400)  # Retry in 1 day

    async def _reset_daily_quotas(self):
        """Reset daily usage quotas (stub)."""
        # TODO: Implement daily quota reset
        pass

    async def _process_pending_events(self):
        """Process queued billing events (stub)."""
        # TODO: Implement event processing
        pass

    async def _generate_monthly_invoices(self):
        """Generate invoices for all tenants (stub)."""
        # TODO: Implement monthly invoice generation
        pass
|
||||
269
apps/coordinator-api/src/app/services/zk_proofs.py
Normal file
269
apps/coordinator-api/src/app/services/zk_proofs.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
ZK Proof generation service for privacy-preserving receipt attestation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, List
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
from ..models import Receipt, JobResult
|
||||
from ..settings import settings
|
||||
from ..logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ZKProofService:
    """Service for generating and verifying Groth16 zero-knowledge proofs
    for receipts, by shelling out to Node.js + snarkjs.

    Proof generation is disabled (``enabled`` is False) when the compiled
    circuit artifacts are missing on disk.
    """

    def __init__(self):
        # NOTE(review): four ``.parent`` hops from this file resolve to
        # apps/coordinator-api/, so this points at
        # apps/coordinator-api/apps/zk-circuits -- confirm the circuits
        # actually live there rather than at the repo root's apps/zk-circuits.
        self.circuits_dir = Path(__file__).parent.parent.parent.parent / "apps" / "zk-circuits"
        self.zkey_path = self.circuits_dir / "receipt_0001.zkey"
        self.wasm_path = self.circuits_dir / "receipt.wasm"
        self.vkey_path = self.circuits_dir / "verification_key.json"

        # Verify circuit files exist; without them, disable proving.
        if not all(p.exists() for p in [self.zkey_path, self.wasm_path, self.vkey_path]):
            logger.warning("ZK circuit files not found. Proof generation disabled.")
            self.enabled = False
        else:
            self.enabled = True

    async def generate_receipt_proof(
        self,
        receipt: Receipt,
        job_result: JobResult,
        privacy_level: str = "basic"
    ) -> Optional[Dict[str, Any]]:
        """Generate a ZK proof for a receipt.

        Returns a dict with ``proof``, ``public_signals``, ``privacy_level``
        and ``circuit_hash``, or None when proving is disabled or fails.
        """
        if not self.enabled:
            logger.warning("ZK proof generation not available")
            return None

        try:
            # Prepare circuit inputs based on privacy level
            inputs = await self._prepare_inputs(receipt, job_result, privacy_level)

            # Generate proof using snarkjs
            proof_data = await self._generate_proof(inputs)

            # Return proof with verification data
            return {
                "proof": proof_data["proof"],
                "public_signals": proof_data["publicSignals"],
                "privacy_level": privacy_level,
                "circuit_hash": await self._get_circuit_hash()
            }

        except Exception as e:
            # Proving is best-effort: callers get None instead of an error.
            logger.error(f"Failed to generate ZK proof: {e}")
            return None

    async def _prepare_inputs(
        self,
        receipt: Receipt,
        job_result: JobResult,
        privacy_level: str
    ) -> Dict[str, Any]:
        """Build circuit inputs for the requested privacy level.

        Raises:
            ValueError: for an unrecognised *privacy_level*.
        """
        if privacy_level == "basic":
            # Hide computation details, reveal settlement amount
            return {
                "data": [
                    str(receipt.job_id),
                    str(receipt.miner_id),
                    str(job_result.result_hash),
                    str(receipt.pricing.rate)
                ],
                "hash": await self._hash_receipt(receipt)
            }

        elif privacy_level == "enhanced":
            # Hide all amounts, prove correctness
            return {
                "settlementAmount": receipt.settlement_amount,
                "timestamp": receipt.timestamp,
                "receipt": self._serialize_receipt(receipt),
                "computationResult": job_result.result_hash,
                "pricingRate": receipt.pricing.rate,
                "minerReward": receipt.miner_reward,
                "coordinatorFee": receipt.coordinator_fee
            }

        else:
            raise ValueError(f"Unknown privacy level: {privacy_level}")

    async def _hash_receipt(self, receipt: Receipt) -> str:
        """Return a SHA-256 hex digest of the receipt's public fields.

        NOTE(review): a real implementation should use Poseidon (or whatever
        hash the circuit uses) so the public signal matches in-circuit.
        """
        import hashlib

        receipt_data = {
            "job_id": receipt.job_id,
            "miner_id": receipt.miner_id,
            "timestamp": receipt.timestamp,
            "pricing": receipt.pricing.dict()
        }

        # sort_keys makes the digest independent of dict insertion order.
        receipt_str = json.dumps(receipt_data, sort_keys=True)
        return hashlib.sha256(receipt_str.encode()).hexdigest()

    def _serialize_receipt(self, receipt: Receipt) -> List[str]:
        """Serialize receipt fields into 8 strings for circuit input.

        Each value is truncated to 32 characters to fit the field size;
        two "0" entries pad the vector to the circuit's expected width.
        """
        return [
            str(receipt.job_id)[:32],  # Truncate for field size
            str(receipt.miner_id)[:32],
            str(receipt.timestamp)[:32],
            str(receipt.settlement_amount)[:32],
            str(receipt.miner_reward)[:32],
            str(receipt.coordinator_fee)[:32],
            "0", "0"  # Padding
        ]

    async def _generate_proof(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a Groth16 proof via a temporary Node.js script.

        Returns the parsed ``{"proof": ..., "publicSignals": ...}`` object.

        Raises:
            RuntimeError: if the Node.js proving script exits non-zero.
        """
        # Write inputs to a temporary file the Node script can read.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            json.dump(inputs, f)
            inputs_file = f.name

        try:
            # groth16.fullProve computes witness + proof in one call and
            # takes the wasm/zkey *paths* directly.  (The previous script
            # called snarkjs.wtns.calculate(inputs, wasm, wasm), passing the
            # wasm buffer where the witness output belongs and destructuring
            # a `witness` key that API does not return.)
            script = f"""
const snarkjs = require('snarkjs');
const fs = require('fs');

async function main() {{
    try {{
        // Load inputs
        const inputs = JSON.parse(fs.readFileSync('{inputs_file}', 'utf8'));

        // Compute witness and proof in one step
        const {{ proof, publicSignals }} = await snarkjs.groth16.fullProve(
            inputs, '{self.wasm_path}', '{self.zkey_path}');

        // Output result
        console.log(JSON.stringify({{ proof, publicSignals }}));
        process.exit(0);
    }} catch (error) {{
        console.error('Error:', error);
        process.exit(1);
    }}
}}

main();
"""

            # Write script to temporary file
            with tempfile.NamedTemporaryFile(mode='w', suffix='.js', delete=False) as f:
                f.write(script)
                script_file = f.name

            try:
                # subprocess.run blocks; run it in a worker thread so the
                # event loop stays responsive while node proves.
                result = await asyncio.to_thread(
                    subprocess.run,
                    ["node", script_file],
                    capture_output=True,
                    text=True,
                    cwd=str(self.circuits_dir),
                )

                if result.returncode != 0:
                    raise RuntimeError(f"Proof generation failed: {result.stderr}")

                # Parse result
                return json.loads(result.stdout)

            finally:
                os.unlink(script_file)

        finally:
            os.unlink(inputs_file)

    async def _get_circuit_hash(self) -> str:
        """Return the circuit version hash bound into proofs.

        NOTE(review): placeholder constant -- a real implementation must
        return the hash of the deployed circuit so verifiers can check
        the proof targets the correct circuit version.
        """
        return "0x1234567890abcdef"

    async def verify_proof(
        self,
        proof: Dict[str, Any],
        public_signals: List[str]
    ) -> bool:
        """Verify a ZK proof against the stored verification key.

        Returns False when verification fails, the service is disabled,
        or the Node.js verifier errors out.
        """
        if not self.enabled:
            return False

        try:
            # Load verification key
            with open(self.vkey_path) as f:
                vkey = json.load(f)

            # Create verification script (inputs inlined as JSON literals).
            script = f"""
const snarkjs = require('snarkjs');

async function main() {{
    try {{
        const vKey = {json.dumps(vkey)};
        const proof = {json.dumps(proof)};
        const publicSignals = {json.dumps(public_signals)};

        const verified = await snarkjs.groth16.verify(vKey, publicSignals, proof);
        console.log(verified);
        process.exit(0);
    }} catch (error) {{
        console.error('Error:', error);
        process.exit(1);
    }}
}}

main();
"""

            with tempfile.NamedTemporaryFile(mode='w', suffix='.js', delete=False) as f:
                f.write(script)
                script_file = f.name

            try:
                # Blocking call moved off the event loop (see _generate_proof).
                result = await asyncio.to_thread(
                    subprocess.run,
                    ["node", script_file],
                    capture_output=True,
                    text=True,
                    cwd=str(self.circuits_dir),
                )

                if result.returncode != 0:
                    logger.error(f"Proof verification failed: {result.stderr}")
                    return False

                return result.stdout.strip() == "true"

            finally:
                os.unlink(script_file)

        except Exception as e:
            logger.error(f"Failed to verify proof: {e}")
            return False

    def is_enabled(self) -> bool:
        """Check if ZK proof generation is available"""
        return self.enabled
|
||||
|
||||
|
||||
# Global instance: module-level singleton; import this rather than
# constructing a new ZKProofService (construction re-checks circuit files).
zk_proof_service = ZKProofService()
|
||||
Reference in New Issue
Block a user