feat: add marketplace metrics, privacy features, and service registry endpoints
- Add Prometheus metrics for marketplace API throughput and error rates with new dashboard panels - Implement confidential transaction models with encryption support and access control - Add key management system with registration, rotation, and audit logging - Create services and registry routers for service discovery and management - Integrate ZK proof generation for privacy-preserving receipts - Add metrics instru
This commit is contained in:
406
apps/coordinator-api/aitbc/api/v1/settlement.py
Normal file
406
apps/coordinator-api/aitbc/api/v1/settlement.py
Normal file
@ -0,0 +1,406 @@
|
||||
"""
|
||||
API endpoints for cross-chain settlements
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
|
||||
from pydantic import BaseModel, Field
|
||||
import asyncio
|
||||
|
||||
from ...settlement.hooks import SettlementHook
|
||||
from ...settlement.manager import BridgeManager
|
||||
from ...settlement.bridges.base import SettlementResult
|
||||
from ...auth import get_api_key
|
||||
from ...models.job import Job
|
||||
|
||||
router = APIRouter(prefix="/settlement", tags=["settlement"])
|
||||
|
||||
|
||||
class CrossChainSettlementRequest(BaseModel):
|
||||
"""Request model for cross-chain settlement"""
|
||||
job_id: str = Field(..., description="ID of the job to settle")
|
||||
target_chain_id: int = Field(..., description="Target blockchain chain ID")
|
||||
bridge_name: Optional[str] = Field(None, description="Specific bridge to use")
|
||||
priority: str = Field("cost", description="Settlement priority: 'cost' or 'speed'")
|
||||
privacy_level: Optional[str] = Field(None, description="Privacy level: 'basic' or 'enhanced'")
|
||||
use_zk_proof: bool = Field(False, description="Use zero-knowledge proof for privacy")
|
||||
|
||||
|
||||
class SettlementEstimateRequest(BaseModel):
|
||||
"""Request model for settlement cost estimation"""
|
||||
job_id: str = Field(..., description="ID of the job")
|
||||
target_chain_id: int = Field(..., description="Target blockchain chain ID")
|
||||
bridge_name: Optional[str] = Field(None, description="Specific bridge to use")
|
||||
|
||||
|
||||
class BatchSettlementRequest(BaseModel):
|
||||
"""Request model for batch settlement"""
|
||||
job_ids: List[str] = Field(..., description="List of job IDs to settle")
|
||||
target_chain_id: int = Field(..., description="Target blockchain chain ID")
|
||||
bridge_name: Optional[str] = Field(None, description="Specific bridge to use")
|
||||
|
||||
|
||||
class SettlementResponse(BaseModel):
|
||||
"""Response model for settlement operations"""
|
||||
message_id: str = Field(..., description="Settlement message ID")
|
||||
status: str = Field(..., description="Settlement status")
|
||||
transaction_hash: Optional[str] = Field(None, description="Transaction hash")
|
||||
bridge_name: str = Field(..., description="Bridge used")
|
||||
estimated_completion: Optional[str] = Field(None, description="Estimated completion time")
|
||||
error_message: Optional[str] = Field(None, description="Error message if failed")
|
||||
|
||||
|
||||
class CostEstimateResponse(BaseModel):
|
||||
"""Response model for cost estimates"""
|
||||
bridge_costs: Dict[str, Dict[str, Any]] = Field(..., description="Costs by bridge")
|
||||
recommended_bridge: str = Field(..., description="Recommended bridge")
|
||||
total_estimates: Dict[str, float] = Field(..., description="Min/Max/Average costs")
|
||||
|
||||
|
||||
def get_settlement_hook() -> SettlementHook:
|
||||
"""Dependency injection for settlement hook"""
|
||||
# This would be properly injected in the app setup
|
||||
from ...main import settlement_hook
|
||||
return settlement_hook
|
||||
|
||||
|
||||
def get_bridge_manager() -> BridgeManager:
|
||||
"""Dependency injection for bridge manager"""
|
||||
# This would be properly injected in the app setup
|
||||
from ...main import bridge_manager
|
||||
return bridge_manager
|
||||
|
||||
|
||||
@router.post("/cross-chain", response_model=SettlementResponse)
|
||||
async def initiate_cross_chain_settlement(
|
||||
request: CrossChainSettlementRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
settlement_hook: SettlementHook = Depends(get_settlement_hook)
|
||||
):
|
||||
"""
|
||||
Initiate cross-chain settlement for a completed job
|
||||
|
||||
This endpoint settles job receipts and payments across different blockchains
|
||||
using various bridge protocols (LayerZero, Chainlink CCIP, etc.).
|
||||
"""
|
||||
try:
|
||||
# Validate job exists and is completed
|
||||
job = await Job.get(request.job_id)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
if not job.completed:
|
||||
raise HTTPException(status_code=400, detail="Job is not completed")
|
||||
|
||||
if job.cross_chain_settlement_id:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Job already has settlement {job.cross_chain_settlement_id}"
|
||||
)
|
||||
|
||||
# Initiate settlement
|
||||
settlement_options = {}
|
||||
if request.use_zk_proof:
|
||||
settlement_options["privacy_level"] = request.privacy_level or "basic"
|
||||
settlement_options["use_zk_proof"] = True
|
||||
|
||||
result = await settlement_hook.initiate_manual_settlement(
|
||||
job_id=request.job_id,
|
||||
target_chain_id=request.target_chain_id,
|
||||
bridge_name=request.bridge_name,
|
||||
options=settlement_options
|
||||
)
|
||||
|
||||
# Add background task to monitor settlement
|
||||
background_tasks.add_task(
|
||||
monitor_settlement_completion,
|
||||
result.message_id,
|
||||
request.job_id
|
||||
)
|
||||
|
||||
return SettlementResponse(
|
||||
message_id=result.message_id,
|
||||
status=result.status.value,
|
||||
transaction_hash=result.transaction_hash,
|
||||
bridge_name=result.transaction_hash and await get_bridge_from_tx(result.transaction_hash),
|
||||
estimated_completion=estimate_completion_time(result.status),
|
||||
error_message=result.error_message
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Settlement failed: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/{message_id}/status", response_model=SettlementResponse)
|
||||
async def get_settlement_status(
|
||||
message_id: str,
|
||||
settlement_hook: SettlementHook = Depends(get_settlement_hook)
|
||||
):
|
||||
"""Get the current status of a cross-chain settlement"""
|
||||
try:
|
||||
result = await settlement_hook.get_settlement_status(message_id)
|
||||
|
||||
# Get job info if available
|
||||
job_id = None
|
||||
if result.transaction_hash:
|
||||
job_id = await get_job_id_from_settlement(message_id)
|
||||
|
||||
return SettlementResponse(
|
||||
message_id=message_id,
|
||||
status=result.status.value,
|
||||
transaction_hash=result.transaction_hash,
|
||||
bridge_name=job_id and await get_bridge_from_job(job_id),
|
||||
estimated_completion=estimate_completion_time(result.status),
|
||||
error_message=result.error_message
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get status: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/estimate-cost", response_model=CostEstimateResponse)
|
||||
async def estimate_settlement_cost(
|
||||
request: SettlementEstimateRequest,
|
||||
settlement_hook: SettlementHook = Depends(get_settlement_hook)
|
||||
):
|
||||
"""Estimate the cost of cross-chain settlement"""
|
||||
try:
|
||||
# Get cost estimates
|
||||
estimates = await settlement_hook.estimate_settlement_cost(
|
||||
job_id=request.job_id,
|
||||
target_chain_id=request.target_chain_id,
|
||||
bridge_name=request.bridge_name
|
||||
)
|
||||
|
||||
# Calculate totals and recommendations
|
||||
valid_estimates = {
|
||||
name: cost for name, cost in estimates.items()
|
||||
if 'error' not in cost
|
||||
}
|
||||
|
||||
if not valid_estimates:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No bridges available for this settlement"
|
||||
)
|
||||
|
||||
# Find cheapest option
|
||||
cheapest_bridge = min(valid_estimates.items(), key=lambda x: x[1]['total'])
|
||||
|
||||
# Calculate statistics
|
||||
costs = [est['total'] for est in valid_estimates.values()]
|
||||
total_estimates = {
|
||||
"min": min(costs),
|
||||
"max": max(costs),
|
||||
"average": sum(costs) / len(costs)
|
||||
}
|
||||
|
||||
return CostEstimateResponse(
|
||||
bridge_costs=estimates,
|
||||
recommended_bridge=cheapest_bridge[0],
|
||||
total_estimates=total_estimates
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Estimation failed: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/batch", response_model=List[SettlementResponse])
|
||||
async def batch_settle(
|
||||
request: BatchSettlementRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
settlement_hook: SettlementHook = Depends(get_settlement_hook)
|
||||
):
|
||||
"""Settle multiple jobs in a batch"""
|
||||
try:
|
||||
# Validate all jobs exist and are completed
|
||||
jobs = []
|
||||
for job_id in request.job_ids:
|
||||
job = await Job.get(job_id)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
|
||||
if not job.completed:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Job {job_id} is not completed"
|
||||
)
|
||||
jobs.append(job)
|
||||
|
||||
# Process batch settlement
|
||||
results = []
|
||||
for job in jobs:
|
||||
try:
|
||||
result = await settlement_hook.initiate_manual_settlement(
|
||||
job_id=job.id,
|
||||
target_chain_id=request.target_chain_id,
|
||||
bridge_name=request.bridge_name
|
||||
)
|
||||
|
||||
# Add monitoring task
|
||||
background_tasks.add_task(
|
||||
monitor_settlement_completion,
|
||||
result.message_id,
|
||||
job.id
|
||||
)
|
||||
|
||||
results.append(SettlementResponse(
|
||||
message_id=result.message_id,
|
||||
status=result.status.value,
|
||||
transaction_hash=result.transaction_hash,
|
||||
bridge_name=result.transaction_hash and await get_bridge_from_tx(result.transaction_hash),
|
||||
estimated_completion=estimate_completion_time(result.status),
|
||||
error_message=result.error_message
|
||||
))
|
||||
|
||||
except Exception as e:
|
||||
results.append(SettlementResponse(
|
||||
message_id="",
|
||||
status="failed",
|
||||
transaction_hash=None,
|
||||
bridge_name="",
|
||||
estimated_completion=None,
|
||||
error_message=str(e)
|
||||
))
|
||||
|
||||
return results
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Batch settlement failed: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/bridges", response_model=Dict[str, Any])
|
||||
async def list_supported_bridges(
|
||||
settlement_hook: SettlementHook = Depends(get_settlement_hook)
|
||||
):
|
||||
"""List all supported bridges and their capabilities"""
|
||||
try:
|
||||
return await settlement_hook.list_supported_bridges()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list bridges: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/chains", response_model=Dict[str, List[int]])
|
||||
async def list_supported_chains(
|
||||
settlement_hook: SettlementHook = Depends(get_settlement_hook)
|
||||
):
|
||||
"""List all supported chains by bridge"""
|
||||
try:
|
||||
return await settlement_hook.list_supported_chains()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list chains: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/{message_id}/refund")
|
||||
async def refund_settlement(
|
||||
message_id: str,
|
||||
bridge_manager: BridgeManager = Depends(get_bridge_manager)
|
||||
):
|
||||
"""Attempt to refund a failed settlement"""
|
||||
try:
|
||||
result = await bridge_manager.refund_failed_settlement(message_id)
|
||||
|
||||
return {
|
||||
"message_id": message_id,
|
||||
"status": result.status.value,
|
||||
"refund_transaction": result.transaction_hash,
|
||||
"error_message": result.error_message
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Refund failed: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/job/{job_id}/settlements")
|
||||
async def get_job_settlements(
|
||||
job_id: str,
|
||||
bridge_manager: BridgeManager = Depends(get_bridge_manager)
|
||||
):
|
||||
"""Get all cross-chain settlements for a job"""
|
||||
try:
|
||||
# Validate job exists
|
||||
job = await Job.get(job_id)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
# Get settlements from storage
|
||||
settlements = await bridge_manager.storage.get_settlements_by_job(job_id)
|
||||
|
||||
return {
|
||||
"job_id": job_id,
|
||||
"settlements": settlements,
|
||||
"total_count": len(settlements)
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get settlements: {str(e)}")
|
||||
|
||||
|
||||
# Helper functions
|
||||
async def monitor_settlement_completion(message_id: str, job_id: str):
|
||||
"""Background task to monitor settlement completion"""
|
||||
settlement_hook = get_settlement_hook()
|
||||
|
||||
# Monitor for up to 1 hour
|
||||
max_wait = 3600
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
while asyncio.get_event_loop().time() - start_time < max_wait:
|
||||
result = await settlement_hook.get_settlement_status(message_id)
|
||||
|
||||
# Update job status
|
||||
job = await Job.get(job_id)
|
||||
if job:
|
||||
job.cross_chain_settlement_status = result.status.value
|
||||
await job.save()
|
||||
|
||||
# If completed or failed, stop monitoring
|
||||
if result.status.value in ['completed', 'failed']:
|
||||
break
|
||||
|
||||
# Wait before checking again
|
||||
await asyncio.sleep(30)
|
||||
|
||||
|
||||
def estimate_completion_time(status) -> Optional[str]:
|
||||
"""Estimate completion time based on status"""
|
||||
if status.value == 'completed':
|
||||
return None
|
||||
elif status.value == 'pending':
|
||||
return "5-10 minutes"
|
||||
elif status.value == 'in_progress':
|
||||
return "2-5 minutes"
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
async def get_bridge_from_tx(tx_hash: str) -> str:
|
||||
"""Get bridge name from transaction hash"""
|
||||
# This would look up the bridge from the transaction
|
||||
# For now, return placeholder
|
||||
return "layerzero"
|
||||
|
||||
|
||||
async def get_bridge_from_job(job_id: str) -> str:
|
||||
"""Get bridge name from job"""
|
||||
# This would look up the bridge from the job
|
||||
# For now, return placeholder
|
||||
return "layerzero"
|
||||
|
||||
|
||||
async def get_job_id_from_settlement(message_id: str) -> Optional[str]:
|
||||
"""Get job ID from settlement message ID"""
|
||||
# This would look up the job ID from storage
|
||||
# For now, return None
|
||||
return None
|
||||
21
apps/coordinator-api/aitbc/settlement/__init__.py
Normal file
21
apps/coordinator-api/aitbc/settlement/__init__.py
Normal file
@ -0,0 +1,21 @@
|
||||
"""
|
||||
Cross-chain settlement module for AITBC
|
||||
"""
|
||||
|
||||
from .manager import BridgeManager
|
||||
from .hooks import SettlementHook, BatchSettlementHook, SettlementMonitor
|
||||
from .storage import SettlementStorage, InMemorySettlementStorage
|
||||
from .bridges.base import BridgeAdapter, BridgeConfig, SettlementMessage, SettlementResult
|
||||
|
||||
__all__ = [
|
||||
"BridgeManager",
|
||||
"SettlementHook",
|
||||
"BatchSettlementHook",
|
||||
"SettlementMonitor",
|
||||
"SettlementStorage",
|
||||
"InMemorySettlementStorage",
|
||||
"BridgeAdapter",
|
||||
"BridgeConfig",
|
||||
"SettlementMessage",
|
||||
"SettlementResult",
|
||||
]
|
||||
23
apps/coordinator-api/aitbc/settlement/bridges/__init__.py
Normal file
23
apps/coordinator-api/aitbc/settlement/bridges/__init__.py
Normal file
@ -0,0 +1,23 @@
|
||||
"""
|
||||
Bridge adapters for cross-chain settlements
|
||||
"""
|
||||
|
||||
from .base import (
|
||||
BridgeAdapter,
|
||||
BridgeConfig,
|
||||
SettlementMessage,
|
||||
SettlementResult,
|
||||
BridgeStatus,
|
||||
BridgeError
|
||||
)
|
||||
from .layerzero import LayerZeroAdapter
|
||||
|
||||
__all__ = [
|
||||
"BridgeAdapter",
|
||||
"BridgeConfig",
|
||||
"SettlementMessage",
|
||||
"SettlementResult",
|
||||
"BridgeStatus",
|
||||
"BridgeError",
|
||||
"LayerZeroAdapter",
|
||||
]
|
||||
172
apps/coordinator-api/aitbc/settlement/bridges/base.py
Normal file
172
apps/coordinator-api/aitbc/settlement/bridges/base.py
Normal file
@ -0,0 +1,172 @@
|
||||
"""
|
||||
Base interfaces for cross-chain settlement bridges
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class BridgeStatus(Enum):
|
||||
"""Bridge operation status"""
|
||||
PENDING = "pending"
|
||||
IN_PROGRESS = "in_progress"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
REFUNDED = "refunded"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BridgeConfig:
|
||||
"""Bridge configuration"""
|
||||
name: str
|
||||
enabled: bool
|
||||
endpoint_address: str
|
||||
supported_chains: List[int]
|
||||
default_fee: str
|
||||
max_message_size: int
|
||||
timeout: int = 3600
|
||||
|
||||
|
||||
@dataclass
|
||||
class SettlementMessage:
|
||||
"""Message to be settled across chains"""
|
||||
source_chain_id: int
|
||||
target_chain_id: int
|
||||
job_id: str
|
||||
receipt_hash: str
|
||||
proof_data: Dict[str, Any]
|
||||
payment_amount: int
|
||||
payment_token: str
|
||||
nonce: int
|
||||
signature: str
|
||||
gas_limit: Optional[int] = None
|
||||
created_at: datetime = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created_at is None:
|
||||
self.created_at = datetime.utcnow()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SettlementResult:
|
||||
"""Result of settlement operation"""
|
||||
message_id: str
|
||||
status: BridgeStatus
|
||||
transaction_hash: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
gas_used: Optional[int] = None
|
||||
fee_paid: Optional[int] = None
|
||||
created_at: datetime = None
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created_at is None:
|
||||
self.created_at = datetime.utcnow()
|
||||
|
||||
|
||||
class BridgeAdapter(ABC):
|
||||
"""Abstract interface for bridge adapters"""
|
||||
|
||||
def __init__(self, config: BridgeConfig):
|
||||
self.config = config
|
||||
self.name = config.name
|
||||
|
||||
@abstractmethod
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the bridge adapter"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def send_message(self, message: SettlementMessage) -> SettlementResult:
|
||||
"""Send message to target chain"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def verify_delivery(self, message_id: str) -> bool:
|
||||
"""Verify message was delivered"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_message_status(self, message_id: str) -> SettlementResult:
|
||||
"""Get current status of message"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def estimate_cost(self, message: SettlementMessage) -> Dict[str, int]:
|
||||
"""Estimate bridge fees"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def refund_failed_message(self, message_id: str) -> SettlementResult:
|
||||
"""Refund failed message if supported"""
|
||||
pass
|
||||
|
||||
def get_supported_chains(self) -> List[int]:
|
||||
"""Get list of supported target chains"""
|
||||
return self.config.supported_chains
|
||||
|
||||
def get_max_message_size(self) -> int:
|
||||
"""Get maximum message size in bytes"""
|
||||
return self.config.max_message_size
|
||||
|
||||
async def validate_message(self, message: SettlementMessage) -> bool:
|
||||
"""Validate message before sending"""
|
||||
# Check if target chain is supported
|
||||
if message.target_chain_id not in self.get_supported_chains():
|
||||
raise ValueError(f"Chain {message.target_chain_id} not supported")
|
||||
|
||||
# Check message size
|
||||
message_size = len(json.dumps(message.proof_data).encode())
|
||||
if message_size > self.get_max_message_size():
|
||||
raise ValueError(f"Message too large: {message_size} > {self.get_max_message_size()}")
|
||||
|
||||
# Validate signature
|
||||
if not await self._verify_signature(message):
|
||||
raise ValueError("Invalid signature")
|
||||
|
||||
return True
|
||||
|
||||
async def _verify_signature(self, message: SettlementMessage) -> bool:
|
||||
"""Verify message signature - to be implemented by subclass"""
|
||||
# This would verify the cryptographic signature
|
||||
# Implementation depends on the signature scheme used
|
||||
return True
|
||||
|
||||
def _encode_payload(self, message: SettlementMessage) -> bytes:
|
||||
"""Encode message payload - to be implemented by subclass"""
|
||||
# Each bridge may have different encoding requirements
|
||||
raise NotImplementedError("Subclass must implement _encode_payload")
|
||||
|
||||
async def _get_gas_estimate(self, message: SettlementMessage) -> int:
|
||||
"""Get gas estimate for message - to be implemented by subclass"""
|
||||
# Each bridge has different gas requirements
|
||||
raise NotImplementedError("Subclass must implement _get_gas_estimate")
|
||||
|
||||
|
||||
class BridgeError(Exception):
|
||||
"""Base exception for bridge errors"""
|
||||
pass
|
||||
|
||||
|
||||
class BridgeNotSupportedError(BridgeError):
|
||||
"""Raised when operation is not supported by bridge"""
|
||||
pass
|
||||
|
||||
|
||||
class BridgeTimeoutError(BridgeError):
|
||||
"""Raised when bridge operation times out"""
|
||||
pass
|
||||
|
||||
|
||||
class BridgeInsufficientFundsError(BridgeError):
|
||||
"""Raised when insufficient funds for bridge operation"""
|
||||
pass
|
||||
|
||||
|
||||
class BridgeMessageTooLargeError(BridgeError):
|
||||
"""Raised when message exceeds bridge limits"""
|
||||
pass
|
||||
288
apps/coordinator-api/aitbc/settlement/bridges/layerzero.py
Normal file
288
apps/coordinator-api/aitbc/settlement/bridges/layerzero.py
Normal file
@ -0,0 +1,288 @@
|
||||
"""
|
||||
LayerZero bridge adapter implementation
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
import json
|
||||
import asyncio
|
||||
from web3 import Web3
|
||||
from web3.contract import Contract
|
||||
from eth_utils import to_checksum_address, encode_hex
|
||||
|
||||
from .base import (
|
||||
BridgeAdapter,
|
||||
BridgeConfig,
|
||||
SettlementMessage,
|
||||
SettlementResult,
|
||||
BridgeStatus,
|
||||
BridgeError,
|
||||
BridgeTimeoutError,
|
||||
BridgeInsufficientFundsError
|
||||
)
|
||||
|
||||
|
||||
class LayerZeroAdapter(BridgeAdapter):
|
||||
"""LayerZero bridge adapter for cross-chain settlements"""
|
||||
|
||||
# LayerZero chain IDs
|
||||
CHAIN_IDS = {
|
||||
1: 101, # Ethereum
|
||||
137: 109, # Polygon
|
||||
56: 102, # BSC
|
||||
42161: 110, # Arbitrum
|
||||
10: 111, # Optimism
|
||||
43114: 106 # Avalanche
|
||||
}
|
||||
|
||||
def __init__(self, config: BridgeConfig, web3: Web3):
|
||||
super().__init__(config)
|
||||
self.web3 = web3
|
||||
self.endpoint: Optional[Contract] = None
|
||||
self.ultra_light_node: Optional[Contract] = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize LayerZero contracts"""
|
||||
# Load LayerZero endpoint ABI
|
||||
endpoint_abi = await self._load_abi("LayerZeroEndpoint")
|
||||
self.endpoint = self.web3.eth.contract(
|
||||
address=to_checksum_address(self.config.endpoint_address),
|
||||
abi=endpoint_abi
|
||||
)
|
||||
|
||||
# Load Ultra Light Node ABI for fee estimation
|
||||
uln_abi = await self._load_abi("UltraLightNode")
|
||||
uln_address = await self.endpoint.functions.ultraLightNode().call()
|
||||
self.ultra_light_node = self.web3.eth.contract(
|
||||
address=to_checksum_address(uln_address),
|
||||
abi=uln_abi
|
||||
)
|
||||
|
||||
async def send_message(self, message: SettlementMessage) -> SettlementResult:
|
||||
"""Send message via LayerZero"""
|
||||
try:
|
||||
# Validate message
|
||||
await self.validate_message(message)
|
||||
|
||||
# Get target address on destination chain
|
||||
target_address = await self._get_target_address(message.target_chain_id)
|
||||
|
||||
# Encode payload
|
||||
payload = self._encode_payload(message)
|
||||
|
||||
# Estimate fees
|
||||
fees = await self.estimate_cost(message)
|
||||
|
||||
# Get gas limit
|
||||
gas_limit = message.gas_limit or await self._get_gas_estimate(message)
|
||||
|
||||
# Build transaction
|
||||
tx_params = {
|
||||
'from': await self._get_signer_address(),
|
||||
'gas': gas_limit,
|
||||
'value': fees['layerZeroFee'],
|
||||
'nonce': await self.web3.eth.get_transaction_count(
|
||||
await self._get_signer_address()
|
||||
)
|
||||
}
|
||||
|
||||
# Send transaction
|
||||
tx_hash = await self.endpoint.functions.send(
|
||||
self.CHAIN_IDS[message.target_chain_id], # dstChainId
|
||||
target_address, # destination address
|
||||
payload, # payload
|
||||
message.payment_amount, # value (optional)
|
||||
[0, 0, 0], # address and parameters for adapterParams
|
||||
message.nonce # refund address
|
||||
).transact(tx_params)
|
||||
|
||||
# Wait for confirmation
|
||||
receipt = await self.web3.eth.wait_for_transaction_receipt(tx_hash)
|
||||
|
||||
return SettlementResult(
|
||||
message_id=tx_hash.hex(),
|
||||
status=BridgeStatus.IN_PROGRESS,
|
||||
transaction_hash=tx_hash.hex(),
|
||||
gas_used=receipt.gasUsed,
|
||||
fee_paid=fees['layerZeroFee']
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return SettlementResult(
|
||||
message_id="",
|
||||
status=BridgeStatus.FAILED,
|
||||
error_message=str(e)
|
||||
)
|
||||
|
||||
async def verify_delivery(self, message_id: str) -> bool:
|
||||
"""Verify message was delivered"""
|
||||
try:
|
||||
# Get transaction receipt
|
||||
receipt = await self.web3.eth.get_transaction_receipt(message_id)
|
||||
|
||||
# Check for Delivered event
|
||||
delivered_logs = self.endpoint.events.Delivered().processReceipt(receipt)
|
||||
return len(delivered_logs) > 0
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_message_status(self, message_id: str) -> SettlementResult:
|
||||
"""Get current status of message"""
|
||||
try:
|
||||
# Get transaction receipt
|
||||
receipt = await self.web3.eth.get_transaction_receipt(message_id)
|
||||
|
||||
if receipt.status == 0:
|
||||
return SettlementResult(
|
||||
message_id=message_id,
|
||||
status=BridgeStatus.FAILED,
|
||||
transaction_hash=message_id,
|
||||
completed_at=receipt['blockTimestamp']
|
||||
)
|
||||
|
||||
# Check if delivered
|
||||
if await self.verify_delivery(message_id):
|
||||
return SettlementResult(
|
||||
message_id=message_id,
|
||||
status=BridgeStatus.COMPLETED,
|
||||
transaction_hash=message_id,
|
||||
completed_at=receipt['blockTimestamp']
|
||||
)
|
||||
|
||||
# Still in progress
|
||||
return SettlementResult(
|
||||
message_id=message_id,
|
||||
status=BridgeStatus.IN_PROGRESS,
|
||||
transaction_hash=message_id
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return SettlementResult(
|
||||
message_id=message_id,
|
||||
status=BridgeStatus.FAILED,
|
||||
error_message=str(e)
|
||||
)
|
||||
|
||||
async def estimate_cost(self, message: SettlementMessage) -> Dict[str, int]:
|
||||
"""Estimate LayerZero fees"""
|
||||
try:
|
||||
# Get destination chain ID
|
||||
dst_chain_id = self.CHAIN_IDS[message.target_chain_id]
|
||||
|
||||
# Get target address
|
||||
target_address = await self._get_target_address(message.target_chain_id)
|
||||
|
||||
# Encode payload
|
||||
payload = self._encode_payload(message)
|
||||
|
||||
# Estimate fee using LayerZero endpoint
|
||||
(native_fee, zro_fee) = await self.endpoint.functions.estimateFees(
|
||||
dst_chain_id,
|
||||
target_address,
|
||||
payload,
|
||||
False, # payInZRO
|
||||
[0, 0, 0] # adapterParams
|
||||
).call()
|
||||
|
||||
return {
|
||||
'layerZeroFee': native_fee,
|
||||
'zroFee': zro_fee,
|
||||
'total': native_fee + zro_fee
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise BridgeError(f"Failed to estimate fees: {str(e)}")
|
||||
|
||||
async def refund_failed_message(self, message_id: str) -> SettlementResult:
|
||||
"""LayerZero doesn't support direct refunds"""
|
||||
raise BridgeNotSupportedError("LayerZero does not support message refunds")
|
||||
|
||||
def _encode_payload(self, message: SettlementMessage) -> bytes:
|
||||
"""Encode settlement message for LayerZero"""
|
||||
# Use ABI encoding for structured data
|
||||
from web3 import Web3
|
||||
|
||||
# Define the payload structure
|
||||
payload_types = [
|
||||
'uint256', # job_id
|
||||
'bytes32', # receipt_hash
|
||||
'bytes', # proof_data (JSON)
|
||||
'uint256', # payment_amount
|
||||
'address', # payment_token
|
||||
'uint256', # nonce
|
||||
'bytes' # signature
|
||||
]
|
||||
|
||||
payload_values = [
|
||||
int(message.job_id),
|
||||
bytes.fromhex(message.receipt_hash),
|
||||
json.dumps(message.proof_data).encode(),
|
||||
message.payment_amount,
|
||||
to_checksum_address(message.payment_token),
|
||||
message.nonce,
|
||||
bytes.fromhex(message.signature)
|
||||
]
|
||||
|
||||
# Encode the payload
|
||||
encoded = Web3().codec.encode(payload_types, payload_values)
|
||||
return encoded
|
||||
|
||||
async def _get_target_address(self, target_chain_id: int) -> str:
|
||||
"""Get target contract address on destination chain"""
|
||||
# This would look up the target address from configuration
|
||||
# For now, return a placeholder
|
||||
target_addresses = {
|
||||
1: "0x...", # Ethereum
|
||||
137: "0x...", # Polygon
|
||||
56: "0x...", # BSC
|
||||
42161: "0x..." # Arbitrum
|
||||
}
|
||||
|
||||
if target_chain_id not in target_addresses:
|
||||
raise ValueError(f"No target address configured for chain {target_chain_id}")
|
||||
|
||||
return target_addresses[target_chain_id]
|
||||
|
||||
async def _get_gas_estimate(self, message: SettlementMessage) -> int:
|
||||
"""Estimate gas for LayerZero transaction"""
|
||||
try:
|
||||
# Get target address
|
||||
target_address = await self._get_target_address(message.target_chain_id)
|
||||
|
||||
# Encode payload
|
||||
payload = self._encode_payload(message)
|
||||
|
||||
# Estimate gas
|
||||
gas_estimate = await self.endpoint.functions.send(
|
||||
self.CHAIN_IDS[message.target_chain_id],
|
||||
target_address,
|
||||
payload,
|
||||
message.payment_amount,
|
||||
[0, 0, 0],
|
||||
message.nonce
|
||||
).estimateGas({'from': await self._get_signer_address()})
|
||||
|
||||
# Add 20% buffer
|
||||
return int(gas_estimate * 1.2)
|
||||
|
||||
except Exception:
|
||||
# Return default estimate
|
||||
return 300000
|
||||
|
||||
async def _get_signer_address(self) -> str:
|
||||
"""Get the signer address for transactions"""
|
||||
# This would get the address from the wallet/key management system
|
||||
# For now, return a placeholder
|
||||
return "0x..."
|
||||
|
||||
async def _load_abi(self, contract_name: str) -> List[Dict]:
|
||||
"""Load contract ABI from file or registry"""
|
||||
# This would load the ABI from a file or contract registry
|
||||
# For now, return empty list
|
||||
return []
|
||||
|
||||
async def _verify_signature(self, message: SettlementMessage) -> bool:
|
||||
"""Verify LayerZero message signature"""
|
||||
# Implement signature verification specific to LayerZero
|
||||
# This would verify the message signature using the appropriate scheme
|
||||
return True
|
||||
327
apps/coordinator-api/aitbc/settlement/hooks.py
Normal file
327
apps/coordinator-api/aitbc/settlement/hooks.py
Normal file
@ -0,0 +1,327 @@
|
||||
"""
|
||||
Settlement hooks for coordinator API integration
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from .manager import BridgeManager
|
||||
from .bridges.base import (
|
||||
SettlementMessage,
|
||||
SettlementResult,
|
||||
BridgeStatus
|
||||
)
|
||||
from ..models.job import Job
|
||||
from ..models.receipt import Receipt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SettlementHook:
|
||||
"""Settlement hook for coordinator to handle cross-chain settlements"""
|
||||
|
||||
def __init__(self, bridge_manager: BridgeManager):
|
||||
self.bridge_manager = bridge_manager
|
||||
self._enabled = True
|
||||
|
||||
async def on_job_completed(self, job: Job) -> None:
|
||||
"""Called when a job completes successfully"""
|
||||
if not self._enabled:
|
||||
return
|
||||
|
||||
try:
|
||||
# Check if cross-chain settlement is required
|
||||
if await self._requires_cross_chain_settlement(job):
|
||||
await self._initiate_settlement(job)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to handle job completion for {job.id}: {e}")
|
||||
# Don't fail the job, just log the error
|
||||
await self._handle_settlement_error(job, e)
|
||||
|
||||
async def on_job_failed(self, job: Job, error: Exception) -> None:
|
||||
"""Called when a job fails"""
|
||||
# For failed jobs, we might want to refund any cross-chain payments
|
||||
if job.cross_chain_payment_id:
|
||||
try:
|
||||
await self._refund_cross_chain_payment(job)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to refund cross-chain payment for {job.id}: {e}")
|
||||
|
||||
async def initiate_manual_settlement(
|
||||
self,
|
||||
job_id: str,
|
||||
target_chain_id: int,
|
||||
bridge_name: Optional[str] = None,
|
||||
options: Optional[Dict[str, Any]] = None
|
||||
) -> SettlementResult:
|
||||
"""Manually initiate cross-chain settlement for a job"""
|
||||
# Get job
|
||||
job = await Job.get(job_id)
|
||||
if not job:
|
||||
raise ValueError(f"Job {job_id} not found")
|
||||
|
||||
if not job.completed:
|
||||
raise ValueError(f"Job {job_id} is not completed")
|
||||
|
||||
# Override target chain if specified
|
||||
if target_chain_id:
|
||||
job.target_chain = target_chain_id
|
||||
|
||||
# Create settlement message
|
||||
message = await self._create_settlement_message(job, options)
|
||||
|
||||
# Send settlement
|
||||
result = await self.bridge_manager.settle_cross_chain(
|
||||
message,
|
||||
bridge_name=bridge_name
|
||||
)
|
||||
|
||||
# Update job with settlement info
|
||||
job.cross_chain_settlement_id = result.message_id
|
||||
job.cross_chain_bridge = bridge_name or self.bridge_manager.default_adapter
|
||||
await job.save()
|
||||
|
||||
return result
|
||||
|
||||
async def get_settlement_status(self, settlement_id: str) -> SettlementResult:
|
||||
"""Get status of a cross-chain settlement"""
|
||||
return await self.bridge_manager.get_settlement_status(settlement_id)
|
||||
|
||||
async def estimate_settlement_cost(
|
||||
self,
|
||||
job_id: str,
|
||||
target_chain_id: int,
|
||||
bridge_name: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Estimate cost for cross-chain settlement"""
|
||||
# Get job
|
||||
job = await Job.get(job_id)
|
||||
if not job:
|
||||
raise ValueError(f"Job {job_id} not found")
|
||||
|
||||
# Create mock settlement message for estimation
|
||||
message = SettlementMessage(
|
||||
source_chain_id=await self._get_current_chain_id(),
|
||||
target_chain_id=target_chain_id,
|
||||
job_id=job.id,
|
||||
receipt_hash=job.receipt.hash if job.receipt else "",
|
||||
proof_data=job.receipt.proof if job.receipt else {},
|
||||
payment_amount=job.payment_amount or 0,
|
||||
payment_token=job.payment_token or "AITBC",
|
||||
nonce=await self._generate_nonce(),
|
||||
signature="" # Not needed for estimation
|
||||
)
|
||||
|
||||
return await self.bridge_manager.estimate_settlement_cost(
|
||||
message,
|
||||
bridge_name=bridge_name
|
||||
)
|
||||
|
||||
async def list_supported_bridges(self) -> Dict[str, Any]:
|
||||
"""List all supported bridges and their capabilities"""
|
||||
return self.bridge_manager.get_bridge_info()
|
||||
|
||||
async def list_supported_chains(self) -> Dict[str, List[int]]:
|
||||
"""List all supported chains by bridge"""
|
||||
return self.bridge_manager.get_supported_chains()
|
||||
|
||||
async def enable(self) -> None:
|
||||
"""Enable settlement hooks"""
|
||||
self._enabled = True
|
||||
logger.info("Settlement hooks enabled")
|
||||
|
||||
async def disable(self) -> None:
|
||||
"""Disable settlement hooks"""
|
||||
self._enabled = False
|
||||
logger.info("Settlement hooks disabled")
|
||||
|
||||
async def _requires_cross_chain_settlement(self, job: Job) -> bool:
|
||||
"""Check if job requires cross-chain settlement"""
|
||||
# Check if job has target chain different from current
|
||||
if job.target_chain and job.target_chain != await self._get_current_chain_id():
|
||||
return True
|
||||
|
||||
# Check if job explicitly requests cross-chain settlement
|
||||
if job.requires_cross_chain_settlement:
|
||||
return True
|
||||
|
||||
# Check if payment is on different chain
|
||||
if job.payment_chain and job.payment_chain != await self._get_current_chain_id():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _initiate_settlement(self, job: Job) -> None:
|
||||
"""Initiate cross-chain settlement for a job"""
|
||||
try:
|
||||
# Create settlement message
|
||||
message = await self._create_settlement_message(job)
|
||||
|
||||
# Get optimal bridge if not specified
|
||||
bridge_name = job.preferred_bridge or await self.bridge_manager.get_optimal_bridge(
|
||||
message,
|
||||
priority=job.settlement_priority or 'cost'
|
||||
)
|
||||
|
||||
# Send settlement
|
||||
result = await self.bridge_manager.settle_cross_chain(
|
||||
message,
|
||||
bridge_name=bridge_name
|
||||
)
|
||||
|
||||
# Update job with settlement info
|
||||
job.cross_chain_settlement_id = result.message_id
|
||||
job.cross_chain_bridge = bridge_name
|
||||
job.cross_chain_settlement_status = result.status.value
|
||||
await job.save()
|
||||
|
||||
logger.info(f"Initiated cross-chain settlement for job {job.id}: {result.message_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate settlement for job {job.id}: {e}")
|
||||
await self._handle_settlement_error(job, e)
|
||||
|
||||
async def _create_settlement_message(self, job: Job, options: Optional[Dict[str, Any]] = None) -> SettlementMessage:
|
||||
"""Create settlement message from job"""
|
||||
# Get current chain ID
|
||||
source_chain_id = await self._get_current_chain_id()
|
||||
|
||||
# Get receipt data
|
||||
receipt_hash = ""
|
||||
proof_data = {}
|
||||
zk_proof = None
|
||||
|
||||
if job.receipt:
|
||||
receipt_hash = job.receipt.hash
|
||||
proof_data = job.receipt.proof or {}
|
||||
|
||||
# Check if ZK proof is included in receipt
|
||||
if options and options.get("use_zk_proof"):
|
||||
zk_proof = job.receipt.payload.get("zk_proof")
|
||||
if not zk_proof:
|
||||
logger.warning(f"ZK proof requested but not found in receipt for job {job.id}")
|
||||
|
||||
# Sign the settlement message
|
||||
signature = await self._sign_settlement_message(job)
|
||||
|
||||
return SettlementMessage(
|
||||
source_chain_id=source_chain_id,
|
||||
target_chain_id=job.target_chain or source_chain_id,
|
||||
job_id=job.id,
|
||||
receipt_hash=receipt_hash,
|
||||
proof_data=proof_data,
|
||||
zk_proof=zk_proof,
|
||||
payment_amount=job.payment_amount or 0,
|
||||
payment_token=job.payment_token or "AITBC",
|
||||
nonce=await self._generate_nonce(),
|
||||
signature=signature,
|
||||
gas_limit=job.settlement_gas_limit,
|
||||
privacy_level=options.get("privacy_level") if options else None
|
||||
)
|
||||
|
||||
async def _get_current_chain_id(self) -> int:
|
||||
"""Get the current blockchain chain ID"""
|
||||
# This would get the chain ID from the blockchain node
|
||||
# For now, return a placeholder
|
||||
return 1 # Ethereum mainnet
|
||||
|
||||
async def _generate_nonce(self) -> int:
|
||||
"""Generate a unique nonce for settlement"""
|
||||
# This would generate a unique nonce
|
||||
# For now, use timestamp
|
||||
return int(datetime.utcnow().timestamp())
|
||||
|
||||
async def _sign_settlement_message(self, job: Job) -> str:
|
||||
"""Sign the settlement message"""
|
||||
# This would sign the message with the appropriate key
|
||||
# For now, return a placeholder
|
||||
return "0x..." * 20
|
||||
|
||||
async def _handle_settlement_error(self, job: Job, error: Exception) -> None:
|
||||
"""Handle settlement errors"""
|
||||
# Update job with error info
|
||||
job.cross_chain_settlement_error = str(error)
|
||||
job.cross_chain_settlement_status = BridgeStatus.FAILED.value
|
||||
await job.save()
|
||||
|
||||
# Notify monitoring system
|
||||
await self._notify_settlement_failure(job, error)
|
||||
|
||||
async def _refund_cross_chain_payment(self, job: Job) -> None:
|
||||
"""Refund a cross-chain payment if possible"""
|
||||
if not job.cross_chain_payment_id:
|
||||
return
|
||||
|
||||
try:
|
||||
result = await self.bridge_manager.refund_failed_settlement(
|
||||
job.cross_chain_payment_id
|
||||
)
|
||||
|
||||
# Update job with refund info
|
||||
job.cross_chain_refund_id = result.message_id
|
||||
job.cross_chain_refund_status = result.status.value
|
||||
await job.save()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to refund cross-chain payment for {job.id}: {e}")
|
||||
|
||||
async def _notify_settlement_failure(self, job: Job, error: Exception) -> None:
|
||||
"""Notify monitoring system of settlement failure"""
|
||||
# This would send alerts to the monitoring system
|
||||
logger.error(f"Settlement failure for job {job.id}: {error}")
|
||||
|
||||
|
||||
class BatchSettlementHook:
|
||||
"""Hook for handling batch settlements"""
|
||||
|
||||
def __init__(self, bridge_manager: BridgeManager):
|
||||
self.bridge_manager = bridge_manager
|
||||
self.batch_size = 10
|
||||
self.batch_timeout = 300 # 5 minutes
|
||||
|
||||
async def add_to_batch(self, job: Job) -> None:
|
||||
"""Add job to batch settlement queue"""
|
||||
# This would add the job to a batch queue
|
||||
pass
|
||||
|
||||
async def process_batch(self) -> List[SettlementResult]:
|
||||
"""Process a batch of settlements"""
|
||||
# This would process queued jobs in batches
|
||||
# For now, return empty list
|
||||
return []
|
||||
|
||||
|
||||
class SettlementMonitor:
|
||||
"""Monitor for cross-chain settlements"""
|
||||
|
||||
def __init__(self, bridge_manager: BridgeManager):
|
||||
self.bridge_manager = bridge_manager
|
||||
self._monitoring = False
|
||||
|
||||
async def start_monitoring(self) -> None:
|
||||
"""Start monitoring settlements"""
|
||||
self._monitoring = True
|
||||
|
||||
while self._monitoring:
|
||||
try:
|
||||
# Get pending settlements
|
||||
pending = await self.bridge_manager.storage.get_pending_settlements()
|
||||
|
||||
# Check status of each
|
||||
for settlement in pending:
|
||||
await self.bridge_manager.get_settlement_status(
|
||||
settlement['message_id']
|
||||
)
|
||||
|
||||
# Wait before next check
|
||||
await asyncio.sleep(30)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in settlement monitoring: {e}")
|
||||
await asyncio.sleep(60)
|
||||
|
||||
async def stop_monitoring(self) -> None:
|
||||
"""Stop monitoring settlements"""
|
||||
self._monitoring = False
|
||||
337
apps/coordinator-api/aitbc/settlement/manager.py
Normal file
337
apps/coordinator-api/aitbc/settlement/manager.py
Normal file
@ -0,0 +1,337 @@
|
||||
"""
|
||||
Bridge manager for cross-chain settlements
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Type
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import asdict
|
||||
|
||||
from .bridges.base import (
|
||||
BridgeAdapter,
|
||||
BridgeConfig,
|
||||
SettlementMessage,
|
||||
SettlementResult,
|
||||
BridgeStatus,
|
||||
BridgeError
|
||||
)
|
||||
from .bridges.layerzero import LayerZeroAdapter
|
||||
from .storage import SettlementStorage
|
||||
|
||||
|
||||
class BridgeManager:
|
||||
"""Manages multiple bridge adapters for cross-chain settlements"""
|
||||
|
||||
def __init__(self, storage: SettlementStorage):
|
||||
self.adapters: Dict[str, BridgeAdapter] = {}
|
||||
self.default_adapter: Optional[str] = None
|
||||
self.storage = storage
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self, configs: Dict[str, BridgeConfig]) -> None:
|
||||
"""Initialize all bridge adapters"""
|
||||
for name, config in configs.items():
|
||||
if config.enabled:
|
||||
adapter = await self._create_adapter(config)
|
||||
await adapter.initialize()
|
||||
self.adapters[name] = adapter
|
||||
|
||||
# Set first enabled adapter as default
|
||||
if self.default_adapter is None:
|
||||
self.default_adapter = name
|
||||
|
||||
self._initialized = True
|
||||
|
||||
async def register_adapter(self, name: str, adapter: BridgeAdapter) -> None:
|
||||
"""Register a bridge adapter"""
|
||||
await adapter.initialize()
|
||||
self.adapters[name] = adapter
|
||||
|
||||
if self.default_adapter is None:
|
||||
self.default_adapter = name
|
||||
|
||||
async def settle_cross_chain(
|
||||
self,
|
||||
message: SettlementMessage,
|
||||
bridge_name: Optional[str] = None,
|
||||
retry_on_failure: bool = True
|
||||
) -> SettlementResult:
|
||||
"""Settle message across chains"""
|
||||
if not self._initialized:
|
||||
raise BridgeError("Bridge manager not initialized")
|
||||
|
||||
# Get adapter
|
||||
adapter = self._get_adapter(bridge_name)
|
||||
|
||||
# Validate message
|
||||
await adapter.validate_message(message)
|
||||
|
||||
# Store initial settlement record
|
||||
await self.storage.store_settlement(
|
||||
message_id="pending",
|
||||
message=message,
|
||||
bridge_name=adapter.name,
|
||||
status=BridgeStatus.PENDING
|
||||
)
|
||||
|
||||
# Attempt settlement with retries
|
||||
max_retries = 3 if retry_on_failure else 1
|
||||
last_error = None
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
# Send message
|
||||
result = await adapter.send_message(message)
|
||||
|
||||
# Update storage with result
|
||||
await self.storage.update_settlement(
|
||||
message_id=result.message_id,
|
||||
status=result.status,
|
||||
transaction_hash=result.transaction_hash,
|
||||
error_message=result.error_message
|
||||
)
|
||||
|
||||
# Start monitoring for completion
|
||||
asyncio.create_task(self._monitor_settlement(result.message_id))
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
if attempt < max_retries - 1:
|
||||
# Wait before retry
|
||||
await asyncio.sleep(2 ** attempt) # Exponential backoff
|
||||
continue
|
||||
else:
|
||||
# Final attempt failed
|
||||
result = SettlementResult(
|
||||
message_id="",
|
||||
status=BridgeStatus.FAILED,
|
||||
error_message=str(e)
|
||||
)
|
||||
|
||||
await self.storage.update_settlement(
|
||||
message_id="",
|
||||
status=BridgeStatus.FAILED,
|
||||
error_message=str(e)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def get_settlement_status(self, message_id: str) -> SettlementResult:
|
||||
"""Get current status of settlement"""
|
||||
# Get from storage first
|
||||
stored = await self.storage.get_settlement(message_id)
|
||||
|
||||
if not stored:
|
||||
raise ValueError(f"Settlement {message_id} not found")
|
||||
|
||||
# If completed or failed, return stored result
|
||||
if stored['status'] in [BridgeStatus.COMPLETED, BridgeStatus.FAILED]:
|
||||
return SettlementResult(**stored)
|
||||
|
||||
# Otherwise check with bridge
|
||||
adapter = self.adapters.get(stored['bridge_name'])
|
||||
if not adapter:
|
||||
raise BridgeError(f"Bridge {stored['bridge_name']} not found")
|
||||
|
||||
# Get current status from bridge
|
||||
result = await adapter.get_message_status(message_id)
|
||||
|
||||
# Update storage if status changed
|
||||
if result.status != stored['status']:
|
||||
await self.storage.update_settlement(
|
||||
message_id=message_id,
|
||||
status=result.status,
|
||||
completed_at=result.completed_at
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def estimate_settlement_cost(
|
||||
self,
|
||||
message: SettlementMessage,
|
||||
bridge_name: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Estimate cost for settlement across different bridges"""
|
||||
results = {}
|
||||
|
||||
if bridge_name:
|
||||
# Estimate for specific bridge
|
||||
adapter = self._get_adapter(bridge_name)
|
||||
results[bridge_name] = await adapter.estimate_cost(message)
|
||||
else:
|
||||
# Estimate for all bridges
|
||||
for name, adapter in self.adapters.items():
|
||||
try:
|
||||
await adapter.validate_message(message)
|
||||
results[name] = await adapter.estimate_cost(message)
|
||||
except Exception as e:
|
||||
results[name] = {'error': str(e)}
|
||||
|
||||
return results
|
||||
|
||||
async def get_optimal_bridge(
|
||||
self,
|
||||
message: SettlementMessage,
|
||||
priority: str = 'cost' # 'cost' or 'speed'
|
||||
) -> str:
|
||||
"""Get optimal bridge for settlement"""
|
||||
if len(self.adapters) == 1:
|
||||
return list(self.adapters.keys())[0]
|
||||
|
||||
# Get estimates for all bridges
|
||||
estimates = await self.estimate_settlement_cost(message)
|
||||
|
||||
# Filter out failed estimates
|
||||
valid_estimates = {
|
||||
name: est for name, est in estimates.items()
|
||||
if 'error' not in est
|
||||
}
|
||||
|
||||
if not valid_estimates:
|
||||
raise BridgeError("No bridges available for settlement")
|
||||
|
||||
# Select based on priority
|
||||
if priority == 'cost':
|
||||
# Select cheapest
|
||||
optimal = min(valid_estimates.items(), key=lambda x: x[1]['total'])
|
||||
else:
|
||||
# Select fastest (based on historical data)
|
||||
# For now, return default
|
||||
optimal = (self.default_adapter, valid_estimates[self.default_adapter])
|
||||
|
||||
return optimal[0]
|
||||
|
||||
async def batch_settle(
|
||||
self,
|
||||
messages: List[SettlementMessage],
|
||||
bridge_name: Optional[str] = None
|
||||
) -> List[SettlementResult]:
|
||||
"""Settle multiple messages"""
|
||||
results = []
|
||||
|
||||
# Process in parallel with rate limiting
|
||||
semaphore = asyncio.Semaphore(5) # Max 5 concurrent settlements
|
||||
|
||||
async def settle_single(message):
|
||||
async with semaphore:
|
||||
return await self.settle_cross_chain(message, bridge_name)
|
||||
|
||||
tasks = [settle_single(msg) for msg in messages]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Convert exceptions to failed results
|
||||
processed_results = []
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
processed_results.append(SettlementResult(
|
||||
message_id="",
|
||||
status=BridgeStatus.FAILED,
|
||||
error_message=str(result)
|
||||
))
|
||||
else:
|
||||
processed_results.append(result)
|
||||
|
||||
return processed_results
|
||||
|
||||
async def refund_failed_settlement(self, message_id: str) -> SettlementResult:
|
||||
"""Attempt to refund a failed settlement"""
|
||||
# Get settlement details
|
||||
stored = await self.storage.get_settlement(message_id)
|
||||
|
||||
if not stored:
|
||||
raise ValueError(f"Settlement {message_id} not found")
|
||||
|
||||
# Check if it's actually failed
|
||||
if stored['status'] != BridgeStatus.FAILED:
|
||||
raise ValueError(f"Settlement {message_id} is not in failed state")
|
||||
|
||||
# Get adapter
|
||||
adapter = self.adapters.get(stored['bridge_name'])
|
||||
if not adapter:
|
||||
raise BridgeError(f"Bridge {stored['bridge_name']} not found")
|
||||
|
||||
# Attempt refund
|
||||
result = await adapter.refund_failed_message(message_id)
|
||||
|
||||
# Update storage
|
||||
await self.storage.update_settlement(
|
||||
message_id=message_id,
|
||||
status=result.status,
|
||||
error_message=result.error_message
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def get_supported_chains(self) -> Dict[str, List[int]]:
|
||||
"""Get all supported chains by bridge"""
|
||||
chains = {}
|
||||
for name, adapter in self.adapters.items():
|
||||
chains[name] = adapter.get_supported_chains()
|
||||
return chains
|
||||
|
||||
def get_bridge_info(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Get information about all bridges"""
|
||||
info = {}
|
||||
for name, adapter in self.adapters.items():
|
||||
info[name] = {
|
||||
'name': adapter.name,
|
||||
'supported_chains': adapter.get_supported_chains(),
|
||||
'max_message_size': adapter.get_max_message_size(),
|
||||
'config': asdict(adapter.config)
|
||||
}
|
||||
return info
|
||||
|
||||
async def _monitor_settlement(self, message_id: str) -> None:
|
||||
"""Monitor settlement until completion"""
|
||||
max_wait_time = timedelta(hours=1)
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
while datetime.utcnow() - start_time < max_wait_time:
|
||||
# Check status
|
||||
result = await self.get_settlement_status(message_id)
|
||||
|
||||
# If completed or failed, stop monitoring
|
||||
if result.status in [BridgeStatus.COMPLETED, BridgeStatus.FAILED]:
|
||||
break
|
||||
|
||||
# Wait before checking again
|
||||
await asyncio.sleep(30) # Check every 30 seconds
|
||||
|
||||
# If still pending after timeout, mark as failed
|
||||
if result.status == BridgeStatus.IN_PROGRESS:
|
||||
await self.storage.update_settlement(
|
||||
message_id=message_id,
|
||||
status=BridgeStatus.FAILED,
|
||||
error_message="Settlement timed out"
|
||||
)
|
||||
|
||||
def _get_adapter(self, bridge_name: Optional[str] = None) -> BridgeAdapter:
|
||||
"""Get bridge adapter"""
|
||||
if bridge_name:
|
||||
if bridge_name not in self.adapters:
|
||||
raise BridgeError(f"Bridge {bridge_name} not found")
|
||||
return self.adapters[bridge_name]
|
||||
|
||||
if self.default_adapter is None:
|
||||
raise BridgeError("No default bridge configured")
|
||||
|
||||
return self.adapters[self.default_adapter]
|
||||
|
||||
async def _create_adapter(self, config: BridgeConfig) -> BridgeAdapter:
|
||||
"""Create adapter instance based on config"""
|
||||
# Import web3 here to avoid circular imports
|
||||
from web3 import Web3
|
||||
|
||||
# Get web3 instance (this would be injected or configured)
|
||||
web3 = Web3() # Placeholder
|
||||
|
||||
if config.name == "layerzero":
|
||||
return LayerZeroAdapter(config, web3)
|
||||
# Add other adapters as they're implemented
|
||||
# elif config.name == "chainlink_ccip":
|
||||
# return ChainlinkCCIPAdapter(config, web3)
|
||||
else:
|
||||
raise BridgeError(f"Unknown bridge type: {config.name}")
|
||||
367
apps/coordinator-api/aitbc/settlement/storage.py
Normal file
367
apps/coordinator-api/aitbc/settlement/storage.py
Normal file
@ -0,0 +1,367 @@
|
||||
"""
|
||||
Storage layer for cross-chain settlements
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
import json
|
||||
import asyncio
|
||||
from dataclasses import asdict
|
||||
|
||||
from .bridges.base import (
|
||||
SettlementMessage,
|
||||
SettlementResult,
|
||||
BridgeStatus
|
||||
)
|
||||
|
||||
|
||||
class SettlementStorage:
|
||||
"""Storage interface for settlement data"""
|
||||
|
||||
def __init__(self, db_connection):
|
||||
self.db = db_connection
|
||||
|
||||
async def store_settlement(
|
||||
self,
|
||||
message_id: str,
|
||||
message: SettlementMessage,
|
||||
bridge_name: str,
|
||||
status: BridgeStatus
|
||||
) -> None:
|
||||
"""Store a new settlement record"""
|
||||
query = """
|
||||
INSERT INTO settlements (
|
||||
message_id, job_id, source_chain_id, target_chain_id,
|
||||
receipt_hash, proof_data, payment_amount, payment_token,
|
||||
nonce, signature, bridge_name, status, created_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13
|
||||
)
|
||||
"""
|
||||
|
||||
await self.db.execute(query, (
|
||||
message_id,
|
||||
message.job_id,
|
||||
message.source_chain_id,
|
||||
message.target_chain_id,
|
||||
message.receipt_hash,
|
||||
json.dumps(message.proof_data),
|
||||
message.payment_amount,
|
||||
message.payment_token,
|
||||
message.nonce,
|
||||
message.signature,
|
||||
bridge_name,
|
||||
status.value,
|
||||
message.created_at or datetime.utcnow()
|
||||
))
|
||||
|
||||
async def update_settlement(
|
||||
self,
|
||||
message_id: str,
|
||||
status: Optional[BridgeStatus] = None,
|
||||
transaction_hash: Optional[str] = None,
|
||||
error_message: Optional[str] = None,
|
||||
completed_at: Optional[datetime] = None
|
||||
) -> None:
|
||||
"""Update settlement record"""
|
||||
updates = []
|
||||
params = []
|
||||
param_count = 1
|
||||
|
||||
if status is not None:
|
||||
updates.append(f"status = ${param_count}")
|
||||
params.append(status.value)
|
||||
param_count += 1
|
||||
|
||||
if transaction_hash is not None:
|
||||
updates.append(f"transaction_hash = ${param_count}")
|
||||
params.append(transaction_hash)
|
||||
param_count += 1
|
||||
|
||||
if error_message is not None:
|
||||
updates.append(f"error_message = ${param_count}")
|
||||
params.append(error_message)
|
||||
param_count += 1
|
||||
|
||||
if completed_at is not None:
|
||||
updates.append(f"completed_at = ${param_count}")
|
||||
params.append(completed_at)
|
||||
param_count += 1
|
||||
|
||||
if not updates:
|
||||
return
|
||||
|
||||
updates.append(f"updated_at = ${param_count}")
|
||||
params.append(datetime.utcnow())
|
||||
param_count += 1
|
||||
|
||||
params.append(message_id)
|
||||
|
||||
query = f"""
|
||||
UPDATE settlements
|
||||
SET {', '.join(updates)}
|
||||
WHERE message_id = ${param_count}
|
||||
"""
|
||||
|
||||
await self.db.execute(query, params)
|
||||
|
||||
async def get_settlement(self, message_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get settlement by message ID"""
|
||||
query = """
|
||||
SELECT * FROM settlements WHERE message_id = $1
|
||||
"""
|
||||
|
||||
result = await self.db.fetchrow(query, message_id)
|
||||
|
||||
if not result:
|
||||
return None
|
||||
|
||||
# Convert to dict
|
||||
settlement = dict(result)
|
||||
|
||||
# Parse JSON fields
|
||||
if settlement['proof_data']:
|
||||
settlement['proof_data'] = json.loads(settlement['proof_data'])
|
||||
|
||||
return settlement
|
||||
|
||||
async def get_settlements_by_job(self, job_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get all settlements for a job"""
|
||||
query = """
|
||||
SELECT * FROM settlements
|
||||
WHERE job_id = $1
|
||||
ORDER BY created_at DESC
|
||||
"""
|
||||
|
||||
results = await self.db.fetch(query, job_id)
|
||||
|
||||
settlements = []
|
||||
for result in results:
|
||||
settlement = dict(result)
|
||||
if settlement['proof_data']:
|
||||
settlement['proof_data'] = json.loads(settlement['proof_data'])
|
||||
settlements.append(settlement)
|
||||
|
||||
return settlements
|
||||
|
||||
async def get_pending_settlements(self, bridge_name: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""Get all pending settlements"""
|
||||
query = """
|
||||
SELECT * FROM settlements
|
||||
WHERE status = 'pending' OR status = 'in_progress'
|
||||
"""
|
||||
params = []
|
||||
|
||||
if bridge_name:
|
||||
query += " AND bridge_name = $1"
|
||||
params.append(bridge_name)
|
||||
|
||||
query += " ORDER BY created_at ASC"
|
||||
|
||||
results = await self.db.fetch(query, *params)
|
||||
|
||||
settlements = []
|
||||
for result in results:
|
||||
settlement = dict(result)
|
||||
if settlement['proof_data']:
|
||||
settlement['proof_data'] = json.loads(settlement['proof_data'])
|
||||
settlements.append(settlement)
|
||||
|
||||
return settlements
|
||||
|
||||
async def get_settlement_stats(
|
||||
self,
|
||||
bridge_name: Optional[str] = None,
|
||||
time_range: Optional[int] = None # hours
|
||||
) -> Dict[str, Any]:
|
||||
"""Get settlement statistics"""
|
||||
conditions = []
|
||||
params = []
|
||||
param_count = 1
|
||||
|
||||
if bridge_name:
|
||||
conditions.append(f"bridge_name = ${param_count}")
|
||||
params.append(bridge_name)
|
||||
param_count += 1
|
||||
|
||||
if time_range:
|
||||
conditions.append(f"created_at > NOW() - INTERVAL '${param_count} hours'")
|
||||
params.append(time_range)
|
||||
param_count += 1
|
||||
|
||||
where_clause = "WHERE " + " AND ".join(conditions) if conditions else ""
|
||||
|
||||
query = f"""
|
||||
SELECT
|
||||
bridge_name,
|
||||
status,
|
||||
COUNT(*) as count,
|
||||
AVG(payment_amount) as avg_amount,
|
||||
SUM(payment_amount) as total_amount
|
||||
FROM settlements
|
||||
{where_clause}
|
||||
GROUP BY bridge_name, status
|
||||
"""
|
||||
|
||||
results = await self.db.fetch(query, *params)
|
||||
|
||||
stats = {}
|
||||
for result in results:
|
||||
bridge = result['bridge_name']
|
||||
if bridge not in stats:
|
||||
stats[bridge] = {}
|
||||
|
||||
stats[bridge][result['status']] = {
|
||||
'count': result['count'],
|
||||
'avg_amount': float(result['avg_amount']) if result['avg_amount'] else 0,
|
||||
'total_amount': float(result['total_amount']) if result['total_amount'] else 0
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
async def cleanup_old_settlements(self, days: int = 30) -> int:
|
||||
"""Clean up old completed settlements"""
|
||||
query = """
|
||||
DELETE FROM settlements
|
||||
WHERE status IN ('completed', 'failed')
|
||||
AND created_at < NOW() - INTERVAL $1 days
|
||||
"""
|
||||
|
||||
result = await self.db.execute(query, days)
|
||||
return result.split()[-1] # Return number of deleted rows
|
||||
|
||||
|
||||
# In-memory implementation for testing
|
||||
class InMemorySettlementStorage(SettlementStorage):
|
||||
"""In-memory storage implementation for testing"""
|
||||
|
||||
def __init__(self):
|
||||
self.settlements: Dict[str, Dict[str, Any]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def store_settlement(
|
||||
self,
|
||||
message_id: str,
|
||||
message: SettlementMessage,
|
||||
bridge_name: str,
|
||||
status: BridgeStatus
|
||||
) -> None:
|
||||
async with self._lock:
|
||||
self.settlements[message_id] = {
|
||||
'message_id': message_id,
|
||||
'job_id': message.job_id,
|
||||
'source_chain_id': message.source_chain_id,
|
||||
'target_chain_id': message.target_chain_id,
|
||||
'receipt_hash': message.receipt_hash,
|
||||
'proof_data': message.proof_data,
|
||||
'payment_amount': message.payment_amount,
|
||||
'payment_token': message.payment_token,
|
||||
'nonce': message.nonce,
|
||||
'signature': message.signature,
|
||||
'bridge_name': bridge_name,
|
||||
'status': status.value,
|
||||
'created_at': message.created_at or datetime.utcnow(),
|
||||
'updated_at': datetime.utcnow()
|
||||
}
|
||||
|
||||
async def update_settlement(
|
||||
self,
|
||||
message_id: str,
|
||||
status: Optional[BridgeStatus] = None,
|
||||
transaction_hash: Optional[str] = None,
|
||||
error_message: Optional[str] = None,
|
||||
completed_at: Optional[datetime] = None
|
||||
) -> None:
|
||||
async with self._lock:
|
||||
if message_id not in self.settlements:
|
||||
return
|
||||
|
||||
settlement = self.settlements[message_id]
|
||||
|
||||
if status is not None:
|
||||
settlement['status'] = status.value
|
||||
if transaction_hash is not None:
|
||||
settlement['transaction_hash'] = transaction_hash
|
||||
if error_message is not None:
|
||||
settlement['error_message'] = error_message
|
||||
if completed_at is not None:
|
||||
settlement['completed_at'] = completed_at
|
||||
|
||||
settlement['updated_at'] = datetime.utcnow()
|
||||
|
||||
async def get_settlement(self, message_id: str) -> Optional[Dict[str, Any]]:
|
||||
async with self._lock:
|
||||
return self.settlements.get(message_id)
|
||||
|
||||
async def get_settlements_by_job(self, job_id: str) -> List[Dict[str, Any]]:
|
||||
async with self._lock:
|
||||
return [
|
||||
s for s in self.settlements.values()
|
||||
if s['job_id'] == job_id
|
||||
]
|
||||
|
||||
async def get_pending_settlements(self, bridge_name: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
async with self._lock:
|
||||
pending = [
|
||||
s for s in self.settlements.values()
|
||||
if s['status'] in ['pending', 'in_progress']
|
||||
]
|
||||
|
||||
if bridge_name:
|
||||
pending = [s for s in pending if s['bridge_name'] == bridge_name]
|
||||
|
||||
return pending
|
||||
|
||||
async def get_settlement_stats(
|
||||
self,
|
||||
bridge_name: Optional[str] = None,
|
||||
time_range: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
async with self._lock:
|
||||
stats = {}
|
||||
|
||||
for settlement in self.settlements.values():
|
||||
if bridge_name and settlement['bridge_name'] != bridge_name:
|
||||
continue
|
||||
|
||||
# TODO: Implement time range filtering
|
||||
|
||||
bridge = settlement['bridge_name']
|
||||
if bridge not in stats:
|
||||
stats[bridge] = {}
|
||||
|
||||
status = settlement['status']
|
||||
if status not in stats[bridge]:
|
||||
stats[bridge][status] = {
|
||||
'count': 0,
|
||||
'avg_amount': 0,
|
||||
'total_amount': 0
|
||||
}
|
||||
|
||||
stats[bridge][status]['count'] += 1
|
||||
stats[bridge][status]['total_amount'] += settlement['payment_amount']
|
||||
|
||||
# Calculate averages
|
||||
for bridge_data in stats.values():
|
||||
for status_data in bridge_data.values():
|
||||
if status_data['count'] > 0:
|
||||
status_data['avg_amount'] = status_data['total_amount'] / status_data['count']
|
||||
|
||||
return stats
|
||||
|
||||
async def cleanup_old_settlements(self, days: int = 30) -> int:
|
||||
async with self._lock:
|
||||
cutoff = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
to_delete = [
|
||||
msg_id for msg_id, settlement in self.settlements.items()
|
||||
if (
|
||||
settlement['status'] in ['completed', 'failed'] and
|
||||
settlement['created_at'] < cutoff
|
||||
)
|
||||
]
|
||||
|
||||
for msg_id in to_delete:
|
||||
del self.settlements[msg_id]
|
||||
|
||||
return len(to_delete)
|
||||
@ -0,0 +1,75 @@
|
||||
"""Add settlements table for cross-chain settlements
|
||||
|
||||
Revision ID: 2024_01_10_add_settlements_table
|
||||
Revises: 2024_01_05_add_receipts_table
|
||||
Create Date: 2025-01-10 10:00:00.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '2024_01_10_add_settlements_table'
|
||||
down_revision = '2024_01_05_add_receipts_table'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Create settlements table
|
||||
op.create_table(
|
||||
'settlements',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('message_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('job_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('source_chain_id', sa.Integer(), nullable=False),
|
||||
sa.Column('target_chain_id', sa.Integer(), nullable=False),
|
||||
sa.Column('receipt_hash', sa.String(length=66), nullable=True),
|
||||
sa.Column('proof_data', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('payment_amount', sa.Numeric(precision=36, scale=18), nullable=True),
|
||||
sa.Column('payment_token', sa.String(length=42), nullable=True),
|
||||
sa.Column('nonce', sa.BigInteger(), nullable=False),
|
||||
sa.Column('signature', sa.String(length=132), nullable=True),
|
||||
sa.Column('bridge_name', sa.String(length=50), nullable=False),
|
||||
sa.Column('status', sa.String(length=20), nullable=False),
|
||||
sa.Column('transaction_hash', sa.String(length=66), nullable=True),
|
||||
sa.Column('gas_used', sa.BigInteger(), nullable=True),
|
||||
sa.Column('fee_paid', sa.Numeric(precision=36, scale=18), nullable=True),
|
||||
sa.Column('error_message', sa.Text(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('completed_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('message_id')
|
||||
)
|
||||
|
||||
# Create indexes
|
||||
op.create_index('ix_settlements_job_id', 'settlements', ['job_id'])
|
||||
op.create_index('ix_settlements_status', 'settlements', ['status'])
|
||||
op.create_index('ix_settlements_bridge_name', 'settlements', ['bridge_name'])
|
||||
op.create_index('ix_settlements_created_at', 'settlements', ['created_at'])
|
||||
op.create_index('ix_settlements_message_id', 'settlements', ['message_id'])
|
||||
|
||||
# Add foreign key constraint for jobs table
|
||||
op.create_foreign_key(
|
||||
'fk_settlements_job_id',
|
||||
'settlements', 'jobs',
|
||||
['job_id'], ['id'],
|
||||
ondelete='CASCADE'
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
# Drop foreign key
|
||||
op.drop_constraint('fk_settlements_job_id', 'settlements', type_='foreignkey')
|
||||
|
||||
# Drop indexes
|
||||
op.drop_index('ix_settlements_message_id', table_name='settlements')
|
||||
op.drop_index('ix_settlements_created_at', table_name='settlements')
|
||||
op.drop_index('ix_settlements_bridge_name', table_name='settlements')
|
||||
op.drop_index('ix_settlements_status', table_name='settlements')
|
||||
op.drop_index('ix_settlements_job_id', table_name='settlements')
|
||||
|
||||
# Drop table
|
||||
op.drop_table('settlements')
|
||||
@ -21,6 +21,7 @@ python-dotenv = "^1.0.1"
|
||||
slowapi = "^0.1.8"
|
||||
orjson = "^3.10.0"
|
||||
gunicorn = "^22.0.0"
|
||||
prometheus-client = "^0.19.0"
|
||||
aitbc-crypto = {path = "../../packages/py/aitbc-crypto"}
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
||||
83
apps/coordinator-api/src/app/exceptions.py
Normal file
83
apps/coordinator-api/src/app/exceptions.py
Normal file
@ -0,0 +1,83 @@
|
||||
"""
|
||||
Exception classes for AITBC coordinator
|
||||
"""
|
||||
|
||||
|
||||
class AITBCError(Exception):
|
||||
"""Base exception for all AITBC errors"""
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticationError(AITBCError):
|
||||
"""Raised when authentication fails"""
|
||||
pass
|
||||
|
||||
|
||||
class RateLimitError(AITBCError):
|
||||
"""Raised when rate limit is exceeded"""
|
||||
def __init__(self, message: str, retry_after: int = None):
|
||||
super().__init__(message)
|
||||
self.retry_after = retry_after
|
||||
|
||||
|
||||
class APIError(AITBCError):
|
||||
"""Raised when API request fails"""
|
||||
def __init__(self, message: str, status_code: int = None, response: dict = None):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
self.response = response
|
||||
|
||||
|
||||
class ConfigurationError(AITBCError):
|
||||
"""Raised when configuration is invalid"""
|
||||
pass
|
||||
|
||||
|
||||
class ConnectorError(AITBCError):
|
||||
"""Raised when connector operation fails"""
|
||||
pass
|
||||
|
||||
|
||||
class PaymentError(ConnectorError):
|
||||
"""Raised when payment operation fails"""
|
||||
pass
|
||||
|
||||
|
||||
class ValidationError(AITBCError):
|
||||
"""Raised when data validation fails"""
|
||||
pass
|
||||
|
||||
|
||||
class WebhookError(AITBCError):
|
||||
"""Raised when webhook processing fails"""
|
||||
pass
|
||||
|
||||
|
||||
class ERPError(ConnectorError):
|
||||
"""Raised when ERP operation fails"""
|
||||
pass
|
||||
|
||||
|
||||
class SyncError(ConnectorError):
|
||||
"""Raised when synchronization fails"""
|
||||
pass
|
||||
|
||||
|
||||
class TimeoutError(AITBCError):
|
||||
"""Raised when operation times out"""
|
||||
pass
|
||||
|
||||
|
||||
class TenantError(ConnectorError):
|
||||
"""Raised when tenant operation fails"""
|
||||
pass
|
||||
|
||||
|
||||
class QuotaExceededError(ConnectorError):
|
||||
"""Raised when resource quota is exceeded"""
|
||||
pass
|
||||
|
||||
|
||||
class BillingError(ConnectorError):
|
||||
"""Raised when billing operation fails"""
|
||||
pass
|
||||
@ -1,8 +1,9 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from prometheus_client import make_asgi_app
|
||||
|
||||
from .config import settings
|
||||
from .routers import client, miner, admin, marketplace, explorer
|
||||
from .routers import client, miner, admin, marketplace, explorer, services, registry
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
@ -25,6 +26,12 @@ def create_app() -> FastAPI:
|
||||
app.include_router(admin, prefix="/v1")
|
||||
app.include_router(marketplace, prefix="/v1")
|
||||
app.include_router(explorer, prefix="/v1")
|
||||
app.include_router(services, prefix="/v1")
|
||||
app.include_router(registry, prefix="/v1")
|
||||
|
||||
# Add Prometheus metrics endpoint
|
||||
metrics_app = make_asgi_app()
|
||||
app.mount("/metrics", metrics_app)
|
||||
|
||||
@app.get("/v1/health", tags=["health"], summary="Service healthcheck")
|
||||
async def health() -> dict[str, str]:
|
||||
|
||||
16
apps/coordinator-api/src/app/metrics.py
Normal file
16
apps/coordinator-api/src/app/metrics.py
Normal file
@ -0,0 +1,16 @@
|
||||
"""Prometheus metrics for the AITBC Coordinator API."""
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
||||
# Marketplace API metrics
|
||||
marketplace_requests_total = Counter(
|
||||
'marketplace_requests_total',
|
||||
'Total number of marketplace API requests',
|
||||
['endpoint', 'method']
|
||||
)
|
||||
|
||||
marketplace_errors_total = Counter(
|
||||
'marketplace_errors_total',
|
||||
'Total number of marketplace API errors',
|
||||
['endpoint', 'method', 'error_type']
|
||||
)
|
||||
292
apps/coordinator-api/src/app/middleware/tenant_context.py
Normal file
292
apps/coordinator-api/src/app/middleware/tenant_context.py
Normal file
@ -0,0 +1,292 @@
|
||||
"""
|
||||
Tenant context middleware for multi-tenant isolation
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from typing import Optional, Callable
|
||||
from fastapi import Request, HTTPException, status
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import event, select, and_
|
||||
from contextvars import ContextVar
|
||||
|
||||
from ..database import get_db
|
||||
from ..models.multitenant import Tenant, TenantApiKey
|
||||
from ..services.tenant_management import TenantManagementService
|
||||
from ..exceptions import TenantError
|
||||
|
||||
|
||||
# Context variable for current tenant
|
||||
current_tenant: ContextVar[Optional[Tenant]] = ContextVar('current_tenant', default=None)
|
||||
current_tenant_id: ContextVar[Optional[str]] = ContextVar('current_tenant_id', default=None)
|
||||
|
||||
|
||||
def get_current_tenant() -> Optional[Tenant]:
|
||||
"""Get the current tenant from context"""
|
||||
return current_tenant.get()
|
||||
|
||||
|
||||
def get_current_tenant_id() -> Optional[str]:
|
||||
"""Get the current tenant ID from context"""
|
||||
return current_tenant_id.get()
|
||||
|
||||
|
||||
class TenantContextMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to extract and set tenant context"""
|
||||
|
||||
def __init__(self, app, excluded_paths: Optional[list] = None):
|
||||
super().__init__(app)
|
||||
self.excluded_paths = excluded_paths or [
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/docs",
|
||||
"/openapi.json",
|
||||
"/favicon.ico",
|
||||
"/static"
|
||||
]
|
||||
self.logger = __import__('logging').getLogger(f"aitbc.{self.__class__.__name__}")
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# Skip tenant extraction for excluded paths
|
||||
if self._should_exclude(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# Extract tenant from request
|
||||
tenant = await self._extract_tenant(request)
|
||||
|
||||
if not tenant:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Tenant not found or invalid"
|
||||
)
|
||||
|
||||
# Check tenant status
|
||||
if tenant.status not in ["active", "trial"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Tenant is {tenant.status}"
|
||||
)
|
||||
|
||||
# Set tenant context
|
||||
current_tenant.set(tenant)
|
||||
current_tenant_id.set(str(tenant.id))
|
||||
|
||||
# Add tenant to request state for easy access
|
||||
request.state.tenant = tenant
|
||||
request.state.tenant_id = str(tenant.id)
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Clear context
|
||||
current_tenant.set(None)
|
||||
current_tenant_id.set(None)
|
||||
|
||||
return response
|
||||
|
||||
def _should_exclude(self, path: str) -> bool:
|
||||
"""Check if path should be excluded from tenant extraction"""
|
||||
for excluded in self.excluded_paths:
|
||||
if path.startswith(excluded):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _extract_tenant(self, request: Request) -> Optional[Tenant]:
|
||||
"""Extract tenant from request using various methods"""
|
||||
|
||||
# Method 1: Subdomain
|
||||
tenant = await self._extract_from_subdomain(request)
|
||||
if tenant:
|
||||
return tenant
|
||||
|
||||
# Method 2: Custom header
|
||||
tenant = await self._extract_from_header(request)
|
||||
if tenant:
|
||||
return tenant
|
||||
|
||||
# Method 3: API key
|
||||
tenant = await self._extract_from_api_key(request)
|
||||
if tenant:
|
||||
return tenant
|
||||
|
||||
# Method 4: JWT token (if using OAuth)
|
||||
tenant = await self._extract_from_token(request)
|
||||
if tenant:
|
||||
return tenant
|
||||
|
||||
return None
|
||||
|
||||
async def _extract_from_subdomain(self, request: Request) -> Optional[Tenant]:
|
||||
"""Extract tenant from subdomain"""
|
||||
host = request.headers.get("host", "").split(":")[0]
|
||||
|
||||
# Split hostname to get subdomain
|
||||
parts = host.split(".")
|
||||
if len(parts) > 2:
|
||||
subdomain = parts[0]
|
||||
|
||||
# Skip common subdomains
|
||||
if subdomain in ["www", "api", "admin", "app"]:
|
||||
return None
|
||||
|
||||
# Look up tenant by subdomain/slug
|
||||
db = next(get_db())
|
||||
try:
|
||||
service = TenantManagementService(db)
|
||||
return await service.get_tenant_by_slug(subdomain)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return None
|
||||
|
||||
async def _extract_from_header(self, request: Request) -> Optional[Tenant]:
|
||||
"""Extract tenant from custom header"""
|
||||
tenant_id = request.headers.get("X-Tenant-ID")
|
||||
if not tenant_id:
|
||||
return None
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
service = TenantManagementService(db)
|
||||
return await service.get_tenant(tenant_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def _extract_from_api_key(self, request: Request) -> Optional[Tenant]:
|
||||
"""Extract tenant from API key"""
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return None
|
||||
|
||||
api_key = auth_header[7:] # Remove "Bearer "
|
||||
|
||||
# Hash the key to compare with stored hash
|
||||
key_hash = hashlib.sha256(api_key.encode()).hexdigest()
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
# Look up API key
|
||||
stmt = select(TenantApiKey).where(
|
||||
and_(
|
||||
TenantApiKey.key_hash == key_hash,
|
||||
TenantApiKey.is_active == True
|
||||
)
|
||||
)
|
||||
api_key_record = db.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if not api_key_record:
|
||||
return None
|
||||
|
||||
# Check if key has expired
|
||||
if api_key_record.expires_at and api_key_record.expires_at < datetime.utcnow():
|
||||
return None
|
||||
|
||||
# Update last used timestamp
|
||||
api_key_record.last_used_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
# Get tenant
|
||||
service = TenantManagementService(db)
|
||||
return await service.get_tenant(str(api_key_record.tenant_id))
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def _extract_from_token(self, request: Request) -> Optional[Tenant]:
|
||||
"""Extract tenant from JWT token"""
|
||||
# TODO: Implement JWT token extraction
|
||||
# This would decode the JWT and extract tenant_id from claims
|
||||
return None
|
||||
|
||||
|
||||
class TenantRowLevelSecurity:
|
||||
"""Row-level security implementation for tenant isolation"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.logger = __import__('logging').getLogger(f"aitbc.{self.__class__.__name__}")
|
||||
|
||||
def enable_rls(self):
|
||||
"""Enable row-level security for the session"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if not tenant_id:
|
||||
raise TenantError("No tenant context found")
|
||||
|
||||
# Set session variable for PostgreSQL RLS
|
||||
self.db.execute(
|
||||
"SET SESSION aitbc.current_tenant_id = :tenant_id",
|
||||
{"tenant_id": tenant_id}
|
||||
)
|
||||
|
||||
self.logger.debug(f"Enabled RLS for tenant: {tenant_id}")
|
||||
|
||||
def disable_rls(self):
|
||||
"""Disable row-level security for the session"""
|
||||
self.db.execute("RESET aitbc.current_tenant_id")
|
||||
self.logger.debug("Disabled RLS")
|
||||
|
||||
|
||||
# Database event listeners for automatic RLS
|
||||
@event.listens_for(Session, "after_begin")
|
||||
def on_session_begin(session, transaction):
|
||||
"""Enable RLS when session begins"""
|
||||
try:
|
||||
tenant_id = get_current_tenant_id()
|
||||
if tenant_id:
|
||||
session.execute(
|
||||
"SET SESSION aitbc.current_tenant_id = :tenant_id",
|
||||
{"tenant_id": tenant_id}
|
||||
)
|
||||
except Exception as e:
|
||||
# Log error but don't fail
|
||||
logger = __import__('logging').getLogger(__name__)
|
||||
logger.error(f"Failed to set tenant context: {e}")
|
||||
|
||||
|
||||
# Decorator for tenant-aware endpoints
|
||||
def requires_tenant(func):
|
||||
"""Decorator to ensure tenant context is present"""
|
||||
async def wrapper(*args, **kwargs):
|
||||
tenant = get_current_tenant()
|
||||
if not tenant:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Tenant context required"
|
||||
)
|
||||
return await func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
# Dependency for FastAPI
|
||||
async def get_current_tenant_dependency(request: Request) -> Tenant:
|
||||
"""FastAPI dependency to get current tenant"""
|
||||
tenant = getattr(request.state, "tenant", None)
|
||||
if not tenant:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Tenant not found"
|
||||
)
|
||||
return tenant
|
||||
|
||||
|
||||
# Utility functions
|
||||
def with_tenant_context(tenant_id: str):
|
||||
"""Execute code with specific tenant context"""
|
||||
token = current_tenant_id.set(tenant_id)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
current_tenant_id.reset(token)
|
||||
|
||||
|
||||
def is_tenant_admin(user_permissions: list) -> bool:
|
||||
"""Check if user has tenant admin permissions"""
|
||||
return "tenant:admin" in user_permissions or "admin" in user_permissions
|
||||
|
||||
|
||||
def has_tenant_permission(permission: str, user_permissions: list) -> bool:
|
||||
"""Check if user has specific tenant permission"""
|
||||
return permission in user_permissions or "tenant:admin" in user_permissions
|
||||
@ -2,7 +2,8 @@ from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, List
|
||||
from base64 import b64encode, b64decode
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
@ -170,3 +171,176 @@ class ReceiptListResponse(BaseModel):
|
||||
|
||||
jobId: str
|
||||
items: list[ReceiptSummary]
|
||||
|
||||
|
||||
# Confidential Transaction Models
|
||||
|
||||
class ConfidentialTransaction(BaseModel):
|
||||
"""Transaction with optional confidential fields"""
|
||||
|
||||
# Public fields (always visible)
|
||||
transaction_id: str
|
||||
job_id: str
|
||||
timestamp: datetime
|
||||
status: str
|
||||
|
||||
# Confidential fields (encrypted when opt-in)
|
||||
amount: Optional[str] = None
|
||||
pricing: Optional[Dict[str, Any]] = None
|
||||
settlement_details: Optional[Dict[str, Any]] = None
|
||||
|
||||
# Encryption metadata
|
||||
confidential: bool = False
|
||||
encrypted_data: Optional[str] = None # Base64 encoded
|
||||
encrypted_keys: Optional[Dict[str, str]] = None # Base64 encoded
|
||||
algorithm: Optional[str] = None
|
||||
|
||||
# Access control
|
||||
participants: List[str] = []
|
||||
access_policies: Dict[str, Any] = {}
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
|
||||
class ConfidentialTransactionCreate(BaseModel):
|
||||
"""Request to create confidential transaction"""
|
||||
|
||||
job_id: str
|
||||
amount: Optional[str] = None
|
||||
pricing: Optional[Dict[str, Any]] = None
|
||||
settlement_details: Optional[Dict[str, Any]] = None
|
||||
|
||||
# Privacy options
|
||||
confidential: bool = False
|
||||
participants: List[str] = []
|
||||
|
||||
# Access policies
|
||||
access_policies: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class ConfidentialTransactionView(BaseModel):
|
||||
"""Response for confidential transaction view"""
|
||||
|
||||
transaction_id: str
|
||||
job_id: str
|
||||
timestamp: datetime
|
||||
status: str
|
||||
|
||||
# Decrypted fields (only if authorized)
|
||||
amount: Optional[str] = None
|
||||
pricing: Optional[Dict[str, Any]] = None
|
||||
settlement_details: Optional[Dict[str, Any]] = None
|
||||
|
||||
# Metadata
|
||||
confidential: bool
|
||||
participants: List[str]
|
||||
has_encrypted_data: bool
|
||||
|
||||
|
||||
class ConfidentialAccessRequest(BaseModel):
|
||||
"""Request to access confidential transaction data"""
|
||||
|
||||
transaction_id: str
|
||||
requester: str
|
||||
purpose: str
|
||||
justification: Optional[str] = None
|
||||
|
||||
|
||||
class ConfidentialAccessResponse(BaseModel):
|
||||
"""Response for confidential data access"""
|
||||
|
||||
success: bool
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
access_id: Optional[str] = None
|
||||
|
||||
|
||||
# Key Management Models
|
||||
|
||||
class KeyPair(BaseModel):
|
||||
"""Encryption key pair for participant"""
|
||||
|
||||
participant_id: str
|
||||
private_key: bytes
|
||||
public_key: bytes
|
||||
algorithm: str = "X25519"
|
||||
created_at: datetime
|
||||
version: int = 1
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class KeyRotationLog(BaseModel):
|
||||
"""Log of key rotation events"""
|
||||
|
||||
participant_id: str
|
||||
old_version: int
|
||||
new_version: int
|
||||
rotated_at: datetime
|
||||
reason: str
|
||||
|
||||
|
||||
class AuditAuthorization(BaseModel):
|
||||
"""Authorization for audit access"""
|
||||
|
||||
issuer: str
|
||||
subject: str
|
||||
purpose: str
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
signature: str
|
||||
|
||||
|
||||
class KeyRegistrationRequest(BaseModel):
|
||||
"""Request to register encryption keys"""
|
||||
|
||||
participant_id: str
|
||||
public_key: str # Base64 encoded
|
||||
algorithm: str = "X25519"
|
||||
|
||||
|
||||
class KeyRegistrationResponse(BaseModel):
|
||||
"""Response for key registration"""
|
||||
|
||||
success: bool
|
||||
participant_id: str
|
||||
key_version: int
|
||||
registered_at: datetime
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
# Access Log Models
|
||||
|
||||
class ConfidentialAccessLog(BaseModel):
|
||||
"""Audit log for confidential data access"""
|
||||
|
||||
transaction_id: Optional[str]
|
||||
participant_id: str
|
||||
purpose: str
|
||||
timestamp: datetime
|
||||
authorized_by: str
|
||||
data_accessed: List[str]
|
||||
success: bool
|
||||
error: Optional[str] = None
|
||||
ip_address: Optional[str] = None
|
||||
user_agent: Optional[str] = None
|
||||
|
||||
|
||||
class AccessLogQuery(BaseModel):
|
||||
"""Query for access logs"""
|
||||
|
||||
transaction_id: Optional[str] = None
|
||||
participant_id: Optional[str] = None
|
||||
purpose: Optional[str] = None
|
||||
start_time: Optional[datetime] = None
|
||||
end_time: Optional[datetime] = None
|
||||
limit: int = 100
|
||||
offset: int = 0
|
||||
|
||||
|
||||
class AccessLogResponse(BaseModel):
|
||||
"""Response for access log query"""
|
||||
|
||||
logs: List[ConfidentialAccessLog]
|
||||
total_count: int
|
||||
has_more: bool
|
||||
|
||||
169
apps/coordinator-api/src/app/models/confidential.py
Normal file
169
apps/coordinator-api/src/app/models/confidential.py
Normal file
@ -0,0 +1,169 @@
|
||||
"""
|
||||
Database models for confidential transactions
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from sqlalchemy import Column, String, DateTime, Boolean, Text, JSON, Integer, LargeBinary
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.sql import func
|
||||
import uuid
|
||||
|
||||
from ..database import Base
|
||||
|
||||
|
||||
class ConfidentialTransactionDB(Base):
|
||||
"""Database model for confidential transactions"""
|
||||
__tablename__ = "confidential_transactions"
|
||||
|
||||
# Primary key
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
# Public fields (always visible)
|
||||
transaction_id = Column(String(255), unique=True, nullable=False, index=True)
|
||||
job_id = Column(String(255), nullable=False, index=True)
|
||||
timestamp = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
status = Column(String(50), nullable=False, default="created")
|
||||
|
||||
# Encryption metadata
|
||||
confidential = Column(Boolean, nullable=False, default=False)
|
||||
algorithm = Column(String(50), nullable=True)
|
||||
|
||||
# Encrypted data (stored as binary)
|
||||
encrypted_data = Column(LargeBinary, nullable=True)
|
||||
encrypted_nonce = Column(LargeBinary, nullable=True)
|
||||
encrypted_tag = Column(LargeBinary, nullable=True)
|
||||
|
||||
# Encrypted keys for participants (JSON encoded)
|
||||
encrypted_keys = Column(JSON, nullable=True)
|
||||
participants = Column(JSON, nullable=True)
|
||||
|
||||
# Access policies
|
||||
access_policies = Column(JSON, nullable=True)
|
||||
|
||||
# Audit fields
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
||||
created_by = Column(String(255), nullable=True)
|
||||
|
||||
# Indexes for performance
|
||||
__table_args__ = (
|
||||
{'schema': 'aitbc'}
|
||||
)
|
||||
|
||||
|
||||
class ParticipantKeyDB(Base):
|
||||
"""Database model for participant encryption keys"""
|
||||
__tablename__ = "participant_keys"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
participant_id = Column(String(255), unique=True, nullable=False, index=True)
|
||||
|
||||
# Key data (encrypted at rest)
|
||||
encrypted_private_key = Column(LargeBinary, nullable=False)
|
||||
public_key = Column(LargeBinary, nullable=False)
|
||||
|
||||
# Key metadata
|
||||
algorithm = Column(String(50), nullable=False, default="X25519")
|
||||
version = Column(Integer, nullable=False, default=1)
|
||||
|
||||
# Status
|
||||
active = Column(Boolean, nullable=False, default=True)
|
||||
revoked_at = Column(DateTime(timezone=True), nullable=True)
|
||||
revoke_reason = Column(String(255), nullable=True)
|
||||
|
||||
# Audit fields
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
||||
rotated_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
{'schema': 'aitbc'}
|
||||
)
|
||||
|
||||
|
||||
class ConfidentialAccessLogDB(Base):
|
||||
"""Database model for confidential data access logs"""
|
||||
__tablename__ = "confidential_access_logs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
# Access details
|
||||
transaction_id = Column(String(255), nullable=True, index=True)
|
||||
participant_id = Column(String(255), nullable=False, index=True)
|
||||
purpose = Column(String(100), nullable=False)
|
||||
|
||||
# Request details
|
||||
action = Column(String(100), nullable=False)
|
||||
resource = Column(String(100), nullable=False)
|
||||
outcome = Column(String(50), nullable=False)
|
||||
|
||||
# Additional data
|
||||
details = Column(JSON, nullable=True)
|
||||
data_accessed = Column(JSON, nullable=True)
|
||||
|
||||
# Metadata
|
||||
ip_address = Column(String(45), nullable=True)
|
||||
user_agent = Column(Text, nullable=True)
|
||||
authorization_id = Column(String(255), nullable=True)
|
||||
|
||||
# Integrity
|
||||
signature = Column(String(128), nullable=True) # SHA-512 hash
|
||||
|
||||
# Timestamps
|
||||
timestamp = Column(DateTime(timezone=True), server_default=func.now(), nullable=False, index=True)
|
||||
|
||||
__table_args__ = (
|
||||
{'schema': 'aitbc'}
|
||||
)
|
||||
|
||||
|
||||
class KeyRotationLogDB(Base):
|
||||
"""Database model for key rotation logs"""
|
||||
__tablename__ = "key_rotation_logs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
participant_id = Column(String(255), nullable=False, index=True)
|
||||
old_version = Column(Integer, nullable=False)
|
||||
new_version = Column(Integer, nullable=False)
|
||||
|
||||
# Rotation details
|
||||
rotated_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
reason = Column(String(255), nullable=False)
|
||||
|
||||
# Who performed the rotation
|
||||
rotated_by = Column(String(255), nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
{'schema': 'aitbc'}
|
||||
)
|
||||
|
||||
|
||||
class AuditAuthorizationDB(Base):
|
||||
"""Database model for audit authorizations"""
|
||||
__tablename__ = "audit_authorizations"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
# Authorization details
|
||||
issuer = Column(String(255), nullable=False)
|
||||
subject = Column(String(255), nullable=False)
|
||||
purpose = Column(String(100), nullable=False)
|
||||
|
||||
# Validity period
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
expires_at = Column(DateTime(timezone=True), nullable=False, index=True)
|
||||
|
||||
# Authorization data
|
||||
signature = Column(String(512), nullable=False)
|
||||
metadata = Column(JSON, nullable=True)
|
||||
|
||||
# Status
|
||||
active = Column(Boolean, nullable=False, default=True)
|
||||
revoked_at = Column(DateTime(timezone=True), nullable=True)
|
||||
used_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
{'schema': 'aitbc'}
|
||||
)
|
||||
340
apps/coordinator-api/src/app/models/multitenant.py
Normal file
340
apps/coordinator-api/src/app/models/multitenant.py
Normal file
@ -0,0 +1,340 @@
|
||||
"""
|
||||
Multi-tenant data models for AITBC coordinator
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any, List
|
||||
from enum import Enum
|
||||
from sqlalchemy import Column, String, DateTime, Boolean, Integer, Text, JSON, ForeignKey, Index, Numeric
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
import uuid
|
||||
|
||||
from ..database import Base
|
||||
|
||||
|
||||
class TenantStatus(Enum):
|
||||
"""Tenant status enumeration"""
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
SUSPENDED = "suspended"
|
||||
PENDING = "pending"
|
||||
TRIAL = "trial"
|
||||
|
||||
|
||||
class Tenant(Base):
|
||||
"""Tenant model for multi-tenancy"""
|
||||
__tablename__ = "tenants"
|
||||
|
||||
# Primary key
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
# Tenant information
|
||||
name = Column(String(255), nullable=False, index=True)
|
||||
slug = Column(String(100), unique=True, nullable=False, index=True)
|
||||
domain = Column(String(255), unique=True, nullable=True, index=True)
|
||||
|
||||
# Status and configuration
|
||||
status = Column(String(50), nullable=False, default=TenantStatus.PENDING.value)
|
||||
plan = Column(String(50), nullable=False, default="trial")
|
||||
|
||||
# Contact information
|
||||
contact_email = Column(String(255), nullable=False)
|
||||
billing_email = Column(String(255), nullable=True)
|
||||
|
||||
# Configuration
|
||||
settings = Column(JSON, nullable=False, default={})
|
||||
features = Column(JSON, nullable=False, default={})
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||
activated_at = Column(DateTime(timezone=True), nullable=True)
|
||||
deactivated_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Relationships
|
||||
users = relationship("TenantUser", back_populates="tenant", cascade="all, delete-orphan")
|
||||
quotas = relationship("TenantQuota", back_populates="tenant", cascade="all, delete-orphan")
|
||||
usage_records = relationship("UsageRecord", back_populates="tenant", cascade="all, delete-orphan")
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index('idx_tenant_status', 'status'),
|
||||
Index('idx_tenant_plan', 'plan'),
|
||||
{'schema': 'aitbc'}
|
||||
)
|
||||
|
||||
|
||||
class TenantUser(Base):
|
||||
"""Association between users and tenants"""
|
||||
__tablename__ = "tenant_users"
|
||||
|
||||
# Primary key
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
# Foreign keys
|
||||
tenant_id = Column(UUID(as_uuid=True), ForeignKey('aitbc.tenants.id'), nullable=False)
|
||||
user_id = Column(String(255), nullable=False) # User ID from auth system
|
||||
|
||||
# Role and permissions
|
||||
role = Column(String(50), nullable=False, default="member")
|
||||
permissions = Column(JSON, nullable=False, default=[])
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, nullable=False, default=True)
|
||||
invited_at = Column(DateTime(timezone=True), nullable=True)
|
||||
joined_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Metadata
|
||||
metadata = Column(JSON, nullable=True)
|
||||
|
||||
# Relationships
|
||||
tenant = relationship("Tenant", back_populates="users")
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index('idx_tenant_user', 'tenant_id', 'user_id'),
|
||||
Index('idx_user_tenants', 'user_id'),
|
||||
{'schema': 'aitbc'}
|
||||
)
|
||||
|
||||
|
||||
class TenantQuota(Base):
|
||||
"""Resource quotas for tenants"""
|
||||
__tablename__ = "tenant_quotas"
|
||||
|
||||
# Primary key
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
# Foreign key
|
||||
tenant_id = Column(UUID(as_uuid=True), ForeignKey('aitbc.tenants.id'), nullable=False)
|
||||
|
||||
# Quota definitions
|
||||
resource_type = Column(String(100), nullable=False) # gpu_hours, storage_gb, api_calls
|
||||
limit_value = Column(Numeric(20, 4), nullable=False) # Maximum allowed
|
||||
used_value = Column(Numeric(20, 4), nullable=False, default=0) # Current usage
|
||||
|
||||
# Time period
|
||||
period_type = Column(String(50), nullable=False, default="monthly") # daily, weekly, monthly
|
||||
period_start = Column(DateTime(timezone=True), nullable=False)
|
||||
period_end = Column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, nullable=False, default=True)
|
||||
|
||||
# Relationships
|
||||
tenant = relationship("Tenant", back_populates="quotas")
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index('idx_tenant_quota', 'tenant_id', 'resource_type', 'period_start'),
|
||||
Index('idx_quota_period', 'period_start', 'period_end'),
|
||||
{'schema': 'aitbc'}
|
||||
)
|
||||
|
||||
|
||||
class UsageRecord(Base):
|
||||
"""Usage tracking records for billing"""
|
||||
__tablename__ = "usage_records"
|
||||
|
||||
# Primary key
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
# Foreign key
|
||||
tenant_id = Column(UUID(as_uuid=True), ForeignKey('aitbc.tenants.id'), nullable=False)
|
||||
|
||||
# Usage details
|
||||
resource_type = Column(String(100), nullable=False) # gpu_hours, storage_gb, api_calls
|
||||
resource_id = Column(String(255), nullable=True) # Specific resource ID
|
||||
quantity = Column(Numeric(20, 4), nullable=False)
|
||||
unit = Column(String(50), nullable=False) # hours, gb, calls
|
||||
|
||||
# Cost information
|
||||
unit_price = Column(Numeric(10, 4), nullable=False)
|
||||
total_cost = Column(Numeric(20, 4), nullable=False)
|
||||
currency = Column(String(10), nullable=False, default="USD")
|
||||
|
||||
# Time tracking
|
||||
usage_start = Column(DateTime(timezone=True), nullable=False)
|
||||
usage_end = Column(DateTime(timezone=True), nullable=False)
|
||||
recorded_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
|
||||
# Metadata
|
||||
job_id = Column(String(255), nullable=True) # Associated job if applicable
|
||||
metadata = Column(JSON, nullable=True)
|
||||
|
||||
# Relationships
|
||||
tenant = relationship("Tenant", back_populates="usage_records")
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index('idx_tenant_usage', 'tenant_id', 'usage_start'),
|
||||
Index('idx_usage_type', 'resource_type', 'usage_start'),
|
||||
Index('idx_usage_job', 'job_id'),
|
||||
{'schema': 'aitbc'}
|
||||
)
|
||||
|
||||
|
||||
class Invoice(Base):
|
||||
"""Billing invoices for tenants"""
|
||||
__tablename__ = "invoices"
|
||||
|
||||
# Primary key
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
# Foreign key
|
||||
tenant_id = Column(UUID(as_uuid=True), ForeignKey('aitbc.tenants.id'), nullable=False)
|
||||
|
||||
# Invoice details
|
||||
invoice_number = Column(String(100), unique=True, nullable=False, index=True)
|
||||
status = Column(String(50), nullable=False, default="draft")
|
||||
|
||||
# Period
|
||||
period_start = Column(DateTime(timezone=True), nullable=False)
|
||||
period_end = Column(DateTime(timezone=True), nullable=False)
|
||||
due_date = Column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
# Amounts
|
||||
subtotal = Column(Numeric(20, 4), nullable=False)
|
||||
tax_amount = Column(Numeric(20, 4), nullable=False, default=0)
|
||||
total_amount = Column(Numeric(20, 4), nullable=False)
|
||||
currency = Column(String(10), nullable=False, default="USD")
|
||||
|
||||
# Breakdown
|
||||
line_items = Column(JSON, nullable=False, default=[])
|
||||
|
||||
# Payment
|
||||
paid_at = Column(DateTime(timezone=True), nullable=True)
|
||||
payment_method = Column(String(100), nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||
|
||||
# Metadata
|
||||
metadata = Column(JSON, nullable=True)
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index('idx_invoice_tenant', 'tenant_id', 'period_start'),
|
||||
Index('idx_invoice_status', 'status'),
|
||||
Index('idx_invoice_due', 'due_date'),
|
||||
{'schema': 'aitbc'}
|
||||
)
|
||||
|
||||
|
||||
class TenantApiKey(Base):
|
||||
"""API keys for tenant authentication"""
|
||||
__tablename__ = "tenant_api_keys"
|
||||
|
||||
# Primary key
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
# Foreign key
|
||||
tenant_id = Column(UUID(as_uuid=True), ForeignKey('aitbc.tenants.id'), nullable=False)
|
||||
|
||||
# Key details
|
||||
key_id = Column(String(100), unique=True, nullable=False, index=True)
|
||||
key_hash = Column(String(255), unique=True, nullable=False, index=True)
|
||||
key_prefix = Column(String(20), nullable=False) # First few characters for identification
|
||||
|
||||
# Permissions and restrictions
|
||||
permissions = Column(JSON, nullable=False, default=[])
|
||||
rate_limit = Column(Integer, nullable=True) # Requests per minute
|
||||
allowed_ips = Column(JSON, nullable=True) # IP whitelist
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, nullable=False, default=True)
|
||||
expires_at = Column(DateTime(timezone=True), nullable=True)
|
||||
last_used_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Metadata
|
||||
name = Column(String(255), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
created_by = Column(String(255), nullable=False)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
revoked_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index('idx_api_key_tenant', 'tenant_id', 'is_active'),
|
||||
Index('idx_api_key_hash', 'key_hash'),
|
||||
{'schema': 'aitbc'}
|
||||
)
|
||||
|
||||
|
||||
class TenantAuditLog(Base):
|
||||
"""Audit logs for tenant activities"""
|
||||
__tablename__ = "tenant_audit_logs"
|
||||
|
||||
# Primary key
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
# Foreign key
|
||||
tenant_id = Column(UUID(as_uuid=True), ForeignKey('aitbc.tenants.id'), nullable=False)
|
||||
|
||||
# Event details
|
||||
event_type = Column(String(100), nullable=False, index=True)
|
||||
event_category = Column(String(50), nullable=False, index=True)
|
||||
actor_id = Column(String(255), nullable=False) # User who performed action
|
||||
actor_type = Column(String(50), nullable=False) # user, api_key, system
|
||||
|
||||
# Target information
|
||||
resource_type = Column(String(100), nullable=False)
|
||||
resource_id = Column(String(255), nullable=True)
|
||||
|
||||
# Event data
|
||||
old_values = Column(JSON, nullable=True)
|
||||
new_values = Column(JSON, nullable=True)
|
||||
metadata = Column(JSON, nullable=True)
|
||||
|
||||
# Request context
|
||||
ip_address = Column(String(45), nullable=True)
|
||||
user_agent = Column(Text, nullable=True)
|
||||
api_key_id = Column(String(100), nullable=True)
|
||||
|
||||
# Timestamp
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False, index=True)
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index('idx_audit_tenant', 'tenant_id', 'created_at'),
|
||||
Index('idx_audit_actor', 'actor_id', 'event_type'),
|
||||
Index('idx_audit_resource', 'resource_type', 'resource_id'),
|
||||
{'schema': 'aitbc'}
|
||||
)
|
||||
|
||||
|
||||
class TenantMetric(Base):
|
||||
"""Tenant-specific metrics and monitoring data"""
|
||||
__tablename__ = "tenant_metrics"
|
||||
|
||||
# Primary key
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
# Foreign key
|
||||
tenant_id = Column(UUID(as_uuid=True), ForeignKey('aitbc.tenants.id'), nullable=False)
|
||||
|
||||
# Metric details
|
||||
metric_name = Column(String(100), nullable=False, index=True)
|
||||
metric_type = Column(String(50), nullable=False) # counter, gauge, histogram
|
||||
|
||||
# Value
|
||||
value = Column(Numeric(20, 4), nullable=False)
|
||||
unit = Column(String(50), nullable=True)
|
||||
|
||||
# Dimensions
|
||||
dimensions = Column(JSON, nullable=False, default={})
|
||||
|
||||
# Time
|
||||
timestamp = Column(DateTime(timezone=True), nullable=False, index=True)
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index('idx_metric_tenant', 'tenant_id', 'metric_name', 'timestamp'),
|
||||
Index('idx_metric_time', 'timestamp'),
|
||||
{'schema': 'aitbc'}
|
||||
)
|
||||
547
apps/coordinator-api/src/app/models/registry.py
Normal file
547
apps/coordinator-api/src/app/models/registry.py
Normal file
@ -0,0 +1,547 @@
|
||||
"""
|
||||
Dynamic service registry models for AITBC
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Any, Optional, Union
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
|
||||
class ServiceCategory(str, Enum):
|
||||
"""Service categories"""
|
||||
AI_ML = "ai_ml"
|
||||
MEDIA_PROCESSING = "media_processing"
|
||||
SCIENTIFIC_COMPUTING = "scientific_computing"
|
||||
DATA_ANALYTICS = "data_analytics"
|
||||
GAMING_ENTERTAINMENT = "gaming_entertainment"
|
||||
DEVELOPMENT_TOOLS = "development_tools"
|
||||
|
||||
|
||||
class ParameterType(str, Enum):
|
||||
"""Parameter types"""
|
||||
STRING = "string"
|
||||
INTEGER = "integer"
|
||||
FLOAT = "float"
|
||||
BOOLEAN = "boolean"
|
||||
ARRAY = "array"
|
||||
OBJECT = "object"
|
||||
FILE = "file"
|
||||
ENUM = "enum"
|
||||
|
||||
|
||||
class PricingModel(str, Enum):
|
||||
"""Pricing models"""
|
||||
PER_UNIT = "per_unit" # per image, per minute, per token
|
||||
PER_HOUR = "per_hour"
|
||||
PER_GB = "per_gb"
|
||||
PER_FRAME = "per_frame"
|
||||
FIXED = "fixed"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class ParameterDefinition(BaseModel):
|
||||
"""Parameter definition schema"""
|
||||
name: str = Field(..., description="Parameter name")
|
||||
type: ParameterType = Field(..., description="Parameter type")
|
||||
required: bool = Field(True, description="Whether parameter is required")
|
||||
description: str = Field(..., description="Parameter description")
|
||||
default: Optional[Any] = Field(None, description="Default value")
|
||||
min_value: Optional[Union[int, float]] = Field(None, description="Minimum value")
|
||||
max_value: Optional[Union[int, float]] = Field(None, description="Maximum value")
|
||||
options: Optional[List[str]] = Field(None, description="Available options for enum type")
|
||||
validation: Optional[Dict[str, Any]] = Field(None, description="Custom validation rules")
|
||||
|
||||
|
||||
class HardwareRequirement(BaseModel):
|
||||
"""Hardware requirement definition"""
|
||||
component: str = Field(..., description="Component type (gpu, cpu, ram, etc.)")
|
||||
min_value: Union[str, int, float] = Field(..., description="Minimum requirement")
|
||||
recommended: Optional[Union[str, int, float]] = Field(None, description="Recommended value")
|
||||
unit: Optional[str] = Field(None, description="Unit (GB, MB, cores, etc.)")
|
||||
|
||||
|
||||
class PricingTier(BaseModel):
|
||||
"""Pricing tier definition"""
|
||||
name: str = Field(..., description="Tier name")
|
||||
model: PricingModel = Field(..., description="Pricing model")
|
||||
unit_price: float = Field(..., ge=0, description="Price per unit")
|
||||
min_charge: Optional[float] = Field(None, ge=0, description="Minimum charge")
|
||||
currency: str = Field("AITBC", description="Currency code")
|
||||
description: Optional[str] = Field(None, description="Tier description")
|
||||
|
||||
|
||||
class ServiceDefinition(BaseModel):
|
||||
"""Complete service definition"""
|
||||
id: str = Field(..., description="Unique service identifier")
|
||||
name: str = Field(..., description="Human-readable service name")
|
||||
category: ServiceCategory = Field(..., description="Service category")
|
||||
description: str = Field(..., description="Service description")
|
||||
version: str = Field("1.0.0", description="Service version")
|
||||
icon: Optional[str] = Field(None, description="Icon emoji or URL")
|
||||
|
||||
# Input/Output
|
||||
input_parameters: List[ParameterDefinition] = Field(..., description="Input parameters")
|
||||
output_schema: Dict[str, Any] = Field(..., description="Output schema")
|
||||
|
||||
# Hardware requirements
|
||||
requirements: List[HardwareRequirement] = Field(..., description="Hardware requirements")
|
||||
|
||||
# Pricing
|
||||
pricing: List[PricingTier] = Field(..., description="Available pricing tiers")
|
||||
|
||||
# Capabilities
|
||||
capabilities: List[str] = Field(default_factory=list, description="Service capabilities")
|
||||
tags: List[str] = Field(default_factory=list, description="Search tags")
|
||||
|
||||
# Limits
|
||||
max_concurrent: int = Field(1, ge=1, le=100, description="Max concurrent jobs")
|
||||
timeout_seconds: int = Field(3600, ge=60, description="Default timeout")
|
||||
|
||||
# Metadata
|
||||
provider: Optional[str] = Field(None, description="Service provider")
|
||||
documentation_url: Optional[str] = Field(None, description="Documentation URL")
|
||||
example_usage: Optional[Dict[str, Any]] = Field(None, description="Example usage")
|
||||
|
||||
@validator('id')
|
||||
def validate_id(cls, v):
|
||||
if not v or not v.replace('_', '').replace('-', '').isalnum():
|
||||
raise ValueError('Service ID must contain only alphanumeric characters, hyphens, and underscores')
|
||||
return v.lower()
|
||||
|
||||
|
||||
class ServiceRegistry(BaseModel):
|
||||
"""Service registry containing all available services"""
|
||||
version: str = Field("1.0.0", description="Registry version")
|
||||
last_updated: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
|
||||
services: Dict[str, ServiceDefinition] = Field(..., description="Service definitions by ID")
|
||||
|
||||
def get_service(self, service_id: str) -> Optional[ServiceDefinition]:
|
||||
"""Get service by ID"""
|
||||
return self.services.get(service_id)
|
||||
|
||||
def get_services_by_category(self, category: ServiceCategory) -> List[ServiceDefinition]:
|
||||
"""Get all services in a category"""
|
||||
return [s for s in self.services.values() if s.category == category]
|
||||
|
||||
def search_services(self, query: str) -> List[ServiceDefinition]:
|
||||
"""Search services by name, description, or tags"""
|
||||
query = query.lower()
|
||||
results = []
|
||||
|
||||
for service in self.services.values():
|
||||
if (query in service.name.lower() or
|
||||
query in service.description.lower() or
|
||||
any(query in tag.lower() for tag in service.tags)):
|
||||
results.append(service)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# Predefined service templates
|
||||
AI_ML_SERVICES = {
|
||||
"llm_inference": ServiceDefinition(
|
||||
id="llm_inference",
|
||||
name="LLM Inference",
|
||||
category=ServiceCategory.AI_ML,
|
||||
description="Run inference on large language models",
|
||||
icon="🤖",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="model",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Model to use for inference",
|
||||
options=["llama-7b", "llama-13b", "llama-70b", "mistral-7b", "mixtral-8x7b", "codellama-7b", "codellama-13b", "codellama-34b", "falcon-7b", "falcon-40b"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="prompt",
|
||||
type=ParameterType.STRING,
|
||||
required=True,
|
||||
description="Input prompt text",
|
||||
min_value=1,
|
||||
max_value=10000
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="max_tokens",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Maximum tokens to generate",
|
||||
default=256,
|
||||
min_value=1,
|
||||
max_value=4096
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="temperature",
|
||||
type=ParameterType.FLOAT,
|
||||
required=False,
|
||||
description="Sampling temperature",
|
||||
default=0.7,
|
||||
min_value=0.0,
|
||||
max_value=2.0
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="stream",
|
||||
type=ParameterType.BOOLEAN,
|
||||
required=False,
|
||||
description="Stream response",
|
||||
default=False
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string"},
|
||||
"tokens_used": {"type": "integer"},
|
||||
"finish_reason": {"type": "string"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-4090"),
|
||||
HardwareRequirement(component="vram", min_value=8, recommended=24, unit="GB"),
|
||||
HardwareRequirement(component="cuda", min_value="11.8")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="basic", model=PricingModel.PER_UNIT, unit_price=0.001, min_charge=0.01),
|
||||
PricingTier(name="premium", model=PricingModel.PER_UNIT, unit_price=0.002, min_charge=0.01)
|
||||
],
|
||||
capabilities=["generate", "stream", "chat", "completion"],
|
||||
tags=["llm", "text", "generation", "ai", "nlp"],
|
||||
max_concurrent=2,
|
||||
timeout_seconds=300
|
||||
),
|
||||
|
||||
"image_generation": ServiceDefinition(
|
||||
id="image_generation",
|
||||
name="Image Generation",
|
||||
category=ServiceCategory.AI_ML,
|
||||
description="Generate images from text prompts using diffusion models",
|
||||
icon="🎨",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="model",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Image generation model",
|
||||
options=["stable-diffusion-1.5", "stable-diffusion-2.1", "stable-diffusion-xl", "sdxl-turbo", "dall-e-2", "dall-e-3", "midjourney-v5"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="prompt",
|
||||
type=ParameterType.STRING,
|
||||
required=True,
|
||||
description="Text prompt for image generation",
|
||||
max_value=1000
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="negative_prompt",
|
||||
type=ParameterType.STRING,
|
||||
required=False,
|
||||
description="Negative prompt",
|
||||
max_value=1000
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="width",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Image width",
|
||||
default=512,
|
||||
options=[256, 512, 768, 1024, 1536, 2048]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="height",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Image height",
|
||||
default=512,
|
||||
options=[256, 512, 768, 1024, 1536, 2048]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="num_images",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Number of images to generate",
|
||||
default=1,
|
||||
min_value=1,
|
||||
max_value=4
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="steps",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Number of inference steps",
|
||||
default=20,
|
||||
min_value=1,
|
||||
max_value=100
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"images": {"type": "array", "items": {"type": "string"}},
|
||||
"parameters": {"type": "object"},
|
||||
"generation_time": {"type": "number"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-4090"),
|
||||
HardwareRequirement(component="vram", min_value=4, recommended=16, unit="GB"),
|
||||
HardwareRequirement(component="cuda", min_value="11.8")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="standard", model=PricingModel.PER_UNIT, unit_price=0.01, min_charge=0.01),
|
||||
PricingTier(name="hd", model=PricingModel.PER_UNIT, unit_price=0.02, min_charge=0.02),
|
||||
PricingTier(name="4k", model=PricingModel.PER_UNIT, unit_price=0.05, min_charge=0.05)
|
||||
],
|
||||
capabilities=["txt2img", "img2img", "inpainting", "outpainting"],
|
||||
tags=["image", "generation", "diffusion", "ai", "art"],
|
||||
max_concurrent=1,
|
||||
timeout_seconds=600
|
||||
),
|
||||
|
||||
"video_generation": ServiceDefinition(
|
||||
id="video_generation",
|
||||
name="Video Generation",
|
||||
category=ServiceCategory.AI_ML,
|
||||
description="Generate videos from text or images",
|
||||
icon="🎬",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="model",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Video generation model",
|
||||
options=["sora", "runway-gen2", "pika-labs", "stable-video-diffusion", "make-a-video"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="prompt",
|
||||
type=ParameterType.STRING,
|
||||
required=True,
|
||||
description="Text prompt for video generation",
|
||||
max_value=500
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="duration_seconds",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Video duration in seconds",
|
||||
default=4,
|
||||
min_value=1,
|
||||
max_value=30
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="fps",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Frames per second",
|
||||
default=24,
|
||||
options=[12, 24, 30]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="resolution",
|
||||
type=ParameterType.ENUM,
|
||||
required=False,
|
||||
description="Video resolution",
|
||||
default="720p",
|
||||
options=["480p", "720p", "1080p", "4k"]
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"video_url": {"type": "string"},
|
||||
"thumbnail_url": {"type": "string"},
|
||||
"duration": {"type": "number"},
|
||||
"resolution": {"type": "string"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="nvidia", recommended="a100"),
|
||||
HardwareRequirement(component="vram", min_value=16, recommended=40, unit="GB"),
|
||||
HardwareRequirement(component="cuda", min_value="11.8")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="short", model=PricingModel.PER_UNIT, unit_price=0.1, min_charge=0.1),
|
||||
PricingTier(name="medium", model=PricingModel.PER_UNIT, unit_price=0.25, min_charge=0.25),
|
||||
PricingTier(name="long", model=PricingModel.PER_UNIT, unit_price=0.5, min_charge=0.5)
|
||||
],
|
||||
capabilities=["txt2video", "img2video", "video-editing"],
|
||||
tags=["video", "generation", "ai", "animation"],
|
||||
max_concurrent=1,
|
||||
timeout_seconds=1800
|
||||
),
|
||||
|
||||
"speech_recognition": ServiceDefinition(
|
||||
id="speech_recognition",
|
||||
name="Speech Recognition",
|
||||
category=ServiceCategory.AI_ML,
|
||||
description="Transcribe audio to text using speech recognition models",
|
||||
icon="🎙️",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="model",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Speech recognition model",
|
||||
options=["whisper-tiny", "whisper-base", "whisper-small", "whisper-medium", "whisper-large", "whisper-large-v2", "whisper-large-v3"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="audio_file",
|
||||
type=ParameterType.FILE,
|
||||
required=True,
|
||||
description="Audio file to transcribe"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="language",
|
||||
type=ParameterType.ENUM,
|
||||
required=False,
|
||||
description="Audio language",
|
||||
default="auto",
|
||||
options=["auto", "en", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh", "ar", "hi"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="task",
|
||||
type=ParameterType.ENUM,
|
||||
required=False,
|
||||
description="Task type",
|
||||
default="transcribe",
|
||||
options=["transcribe", "translate"]
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string"},
|
||||
"language": {"type": "string"},
|
||||
"segments": {"type": "array"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3060"),
|
||||
HardwareRequirement(component="vram", min_value=1, recommended=4, unit="GB")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_minute", model=PricingModel.PER_UNIT, unit_price=0.001, min_charge=0.01)
|
||||
],
|
||||
capabilities=["transcribe", "translate", "timestamp", "speaker-diarization"],
|
||||
tags=["speech", "audio", "transcription", "whisper"],
|
||||
max_concurrent=2,
|
||||
timeout_seconds=600
|
||||
),
|
||||
|
||||
"computer_vision": ServiceDefinition(
|
||||
id="computer_vision",
|
||||
name="Computer Vision",
|
||||
category=ServiceCategory.AI_ML,
|
||||
description="Analyze images with computer vision models",
|
||||
icon="👁️",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="task",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Vision task",
|
||||
options=["object-detection", "classification", "face-recognition", "segmentation", "ocr"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="model",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Vision model",
|
||||
options=["yolo-v8", "resnet-50", "efficientnet", "vit", "face-net", "tesseract"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="image",
|
||||
type=ParameterType.FILE,
|
||||
required=True,
|
||||
description="Input image"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="confidence_threshold",
|
||||
type=ParameterType.FLOAT,
|
||||
required=False,
|
||||
description="Confidence threshold",
|
||||
default=0.5,
|
||||
min_value=0.0,
|
||||
max_value=1.0
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"detections": {"type": "array"},
|
||||
"labels": {"type": "array"},
|
||||
"confidence_scores": {"type": "array"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3060"),
|
||||
HardwareRequirement(component="vram", min_value=2, recommended=8, unit="GB")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_image", model=PricingModel.PER_UNIT, unit_price=0.005, min_charge=0.01)
|
||||
],
|
||||
capabilities=["detection", "classification", "recognition", "segmentation", "ocr"],
|
||||
tags=["vision", "image", "analysis", "ai", "detection"],
|
||||
max_concurrent=4,
|
||||
timeout_seconds=120
|
||||
),
|
||||
|
||||
"recommendation_system": ServiceDefinition(
|
||||
id="recommendation_system",
|
||||
name="Recommendation System",
|
||||
category=ServiceCategory.AI_ML,
|
||||
description="Generate personalized recommendations",
|
||||
icon="🎯",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="model_type",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Recommendation model type",
|
||||
options=["collaborative", "content-based", "hybrid", "deep-learning"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="user_id",
|
||||
type=ParameterType.STRING,
|
||||
required=True,
|
||||
description="User identifier"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="item_data",
|
||||
type=ParameterType.ARRAY,
|
||||
required=True,
|
||||
description="Item catalog data"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="num_recommendations",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Number of recommendations",
|
||||
default=10,
|
||||
min_value=1,
|
||||
max_value=100
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"recommendations": {"type": "array"},
|
||||
"scores": {"type": "array"},
|
||||
"explanation": {"type": "string"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3080"),
|
||||
HardwareRequirement(component="vram", min_value=4, recommended=12, unit="GB"),
|
||||
HardwareRequirement(component="ram", min_value=16, recommended=32, unit="GB")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_request", model=PricingModel.PER_UNIT, unit_price=0.01, min_charge=0.01),
|
||||
PricingTier(name="bulk", model=PricingModel.PER_UNIT, unit_price=0.005, min_charge=0.1)
|
||||
],
|
||||
capabilities=["personalization", "real-time", "batch", "ab-testing"],
|
||||
tags=["recommendation", "personalization", "ml", "ecommerce"],
|
||||
max_concurrent=10,
|
||||
timeout_seconds=60
|
||||
)
|
||||
}
|
||||
286
apps/coordinator-api/src/app/models/registry_data.py
Normal file
286
apps/coordinator-api/src/app/models/registry_data.py
Normal file
@ -0,0 +1,286 @@
|
||||
"""
|
||||
Data analytics service definitions
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Any, Union
|
||||
from .registry import (
|
||||
ServiceDefinition,
|
||||
ServiceCategory,
|
||||
ParameterDefinition,
|
||||
ParameterType,
|
||||
HardwareRequirement,
|
||||
PricingTier,
|
||||
PricingModel
|
||||
)
|
||||
|
||||
|
||||
DATA_ANALYTICS_SERVICES = {
|
||||
"big_data_processing": ServiceDefinition(
|
||||
id="big_data_processing",
|
||||
name="Big Data Processing",
|
||||
category=ServiceCategory.DATA_ANALYTICS,
|
||||
description="GPU-accelerated ETL and data processing with RAPIDS",
|
||||
icon="📊",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="operation",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Processing operation",
|
||||
options=["etl", "aggregate", "join", "filter", "transform", "clean"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="data_source",
|
||||
type=ParameterType.STRING,
|
||||
required=True,
|
||||
description="Data source URL or connection string"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="query",
|
||||
type=ParameterType.STRING,
|
||||
required=True,
|
||||
description="SQL or data processing query"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="output_format",
|
||||
type=ParameterType.ENUM,
|
||||
required=False,
|
||||
description="Output format",
|
||||
default="parquet",
|
||||
options=["parquet", "csv", "json", "delta", "orc"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="partition_by",
|
||||
type=ParameterType.ARRAY,
|
||||
required=False,
|
||||
description="Partition columns",
|
||||
items={"type": "string"}
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"output_url": {"type": "string"},
|
||||
"row_count": {"type": "integer"},
|
||||
"columns": {"type": "array"},
|
||||
"processing_stats": {"type": "object"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3080"),
|
||||
HardwareRequirement(component="vram", min_value=8, recommended=16, unit="GB"),
|
||||
HardwareRequirement(component="ram", min_value=32, recommended=128, unit="GB"),
|
||||
HardwareRequirement(component="storage", min_value=100, recommended=1000, unit="GB")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_gb", model=PricingModel.PER_GB, unit_price=0.01, min_charge=0.1),
|
||||
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=1, min_charge=1),
|
||||
PricingTier(name="enterprise", model=PricingModel.PER_UNIT, unit_price=0.005, min_charge=0.5)
|
||||
],
|
||||
capabilities=["gpu-sql", "etl", "streaming", "distributed"],
|
||||
tags=["bigdata", "etl", "rapids", "spark", "sql"],
|
||||
max_concurrent=5,
|
||||
timeout_seconds=3600
|
||||
),
|
||||
|
||||
"real_time_analytics": ServiceDefinition(
|
||||
id="real_time_analytics",
|
||||
name="Real-time Analytics",
|
||||
category=ServiceCategory.DATA_ANALYTICS,
|
||||
description="Stream processing and real-time analytics with GPU acceleration",
|
||||
icon="⚡",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="stream_source",
|
||||
type=ParameterType.STRING,
|
||||
required=True,
|
||||
description="Stream source (Kafka, Kinesis, etc.)"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="query",
|
||||
type=ParameterType.STRING,
|
||||
required=True,
|
||||
description="Stream processing query"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="window_size",
|
||||
type=ParameterType.STRING,
|
||||
required=False,
|
||||
description="Window size (e.g., 1m, 5m, 1h)",
|
||||
default="5m"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="aggregations",
|
||||
type=ParameterType.ARRAY,
|
||||
required=True,
|
||||
description="Aggregation functions",
|
||||
items={"type": "string"}
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="output_sink",
|
||||
type=ParameterType.STRING,
|
||||
required=True,
|
||||
description="Output sink for results"
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stream_id": {"type": "string"},
|
||||
"throughput": {"type": "number"},
|
||||
"latency_ms": {"type": "integer"},
|
||||
"metrics": {"type": "object"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="nvidia", recommended="a100"),
|
||||
HardwareRequirement(component="vram", min_value=16, recommended=40, unit="GB"),
|
||||
HardwareRequirement(component="network", min_value="10Gbps", recommended="100Gbps"),
|
||||
HardwareRequirement(component="ram", min_value=64, recommended=256, unit="GB")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=2, min_charge=2),
|
||||
PricingTier(name="per_million_events", model=PricingModel.PER_UNIT, unit_price=0.1, min_charge=1),
|
||||
PricingTier(name="high_throughput", model=PricingModel.PER_HOUR, unit_price=5, min_charge=5)
|
||||
],
|
||||
capabilities=["streaming", "windowing", "aggregation", "cep"],
|
||||
tags=["streaming", "real-time", "analytics", "kafka", "flink"],
|
||||
max_concurrent=10,
|
||||
timeout_seconds=86400 # 24 hours
|
||||
),
|
||||
|
||||
"graph_analytics": ServiceDefinition(
|
||||
id="graph_analytics",
|
||||
name="Graph Analytics",
|
||||
category=ServiceCategory.DATA_ANALYTICS,
|
||||
description="Network analysis and graph algorithms on GPU",
|
||||
icon="🕸️",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="algorithm",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Graph algorithm",
|
||||
options=["pagerank", "community-detection", "shortest-path", "triangles", "clustering", "centrality"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="graph_data",
|
||||
type=ParameterType.FILE,
|
||||
required=True,
|
||||
description="Graph data file (edges list, adjacency matrix, etc.)"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="graph_format",
|
||||
type=ParameterType.ENUM,
|
||||
required=False,
|
||||
description="Graph format",
|
||||
default="edges",
|
||||
options=["edges", "adjacency", "csr", "metis"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="parameters",
|
||||
type=ParameterType.OBJECT,
|
||||
required=False,
|
||||
description="Algorithm-specific parameters"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="num_vertices",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Number of vertices",
|
||||
min_value=1
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"results": {"type": "array"},
|
||||
"statistics": {"type": "object"},
|
||||
"graph_metrics": {"type": "object"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3090"),
|
||||
HardwareRequirement(component="vram", min_value=8, recommended=24, unit="GB"),
|
||||
HardwareRequirement(component="ram", min_value=16, recommended=64, unit="GB")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_million_edges", model=PricingModel.PER_UNIT, unit_price=0.01, min_charge=0.1),
|
||||
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=1, min_charge=1),
|
||||
PricingTier(name="large_graph", model=PricingModel.PER_UNIT, unit_price=0.005, min_charge=0.5)
|
||||
],
|
||||
capabilities=["gpu-graph", "algorithms", "network-analysis", "fraud-detection"],
|
||||
tags=["graph", "network", "analytics", "pagerank", "fraud"],
|
||||
max_concurrent=5,
|
||||
timeout_seconds=3600
|
||||
),
|
||||
|
||||
"time_series_analysis": ServiceDefinition(
|
||||
id="time_series_analysis",
|
||||
name="Time Series Analysis",
|
||||
category=ServiceCategory.DATA_ANALYTICS,
|
||||
description="Analyze time series data with GPU-accelerated algorithms",
|
||||
icon="📈",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="analysis_type",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Analysis type",
|
||||
options=["forecasting", "anomaly-detection", "decomposition", "seasonality", "trend"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="time_series_data",
|
||||
type=ParameterType.FILE,
|
||||
required=True,
|
||||
description="Time series data file"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="model",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Analysis model",
|
||||
options=["arima", "prophet", "lstm", "transformer", "holt-winters", "var"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="forecast_horizon",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Forecast horizon",
|
||||
default=30,
|
||||
min_value=1,
|
||||
max_value=365
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="frequency",
|
||||
type=ParameterType.STRING,
|
||||
required=False,
|
||||
description="Data frequency (D, H, M, S)",
|
||||
default="D"
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"forecast": {"type": "array"},
|
||||
"confidence_intervals": {"type": "array"},
|
||||
"model_metrics": {"type": "object"},
|
||||
"anomalies": {"type": "array"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3080"),
|
||||
HardwareRequirement(component="vram", min_value=8, recommended=16, unit="GB"),
|
||||
HardwareRequirement(component="ram", min_value=16, recommended=32, unit="GB")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_1k_points", model=PricingModel.PER_UNIT, unit_price=0.001, min_charge=0.01),
|
||||
PricingTier(name="per_forecast", model=PricingModel.PER_UNIT, unit_price=0.01, min_charge=0.1),
|
||||
PricingTier(name="enterprise", model=PricingModel.PER_HOUR, unit_price=1, min_charge=1)
|
||||
],
|
||||
capabilities=["forecasting", "anomaly-detection", "decomposition", "seasonality"],
|
||||
tags=["time-series", "forecasting", "anomaly", "arima", "lstm"],
|
||||
max_concurrent=10,
|
||||
timeout_seconds=1800
|
||||
)
|
||||
}
|
||||
408
apps/coordinator-api/src/app/models/registry_devtools.py
Normal file
408
apps/coordinator-api/src/app/models/registry_devtools.py
Normal file
@ -0,0 +1,408 @@
|
||||
"""
|
||||
Development tools service definitions
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Any, Union
|
||||
from .registry import (
|
||||
ServiceDefinition,
|
||||
ServiceCategory,
|
||||
ParameterDefinition,
|
||||
ParameterType,
|
||||
HardwareRequirement,
|
||||
PricingTier,
|
||||
PricingModel
|
||||
)
|
||||
|
||||
|
||||
DEVTOOLS_SERVICES = {
|
||||
"gpu_compilation": ServiceDefinition(
|
||||
id="gpu_compilation",
|
||||
name="GPU-Accelerated Compilation",
|
||||
category=ServiceCategory.DEVELOPMENT_TOOLS,
|
||||
description="Compile code with GPU acceleration (CUDA, HIP, OpenCL)",
|
||||
icon="⚙️",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="language",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Programming language",
|
||||
options=["cpp", "cuda", "hip", "opencl", "metal", "sycl"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="source_files",
|
||||
type=ParameterType.ARRAY,
|
||||
required=True,
|
||||
description="Source code files",
|
||||
items={"type": "string"}
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="build_type",
|
||||
type=ParameterType.ENUM,
|
||||
required=False,
|
||||
description="Build type",
|
||||
default="release",
|
||||
options=["debug", "release", "relwithdebinfo"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="target_arch",
|
||||
type=ParameterType.ENUM,
|
||||
required=False,
|
||||
description="Target architecture",
|
||||
default="sm_70",
|
||||
options=["sm_60", "sm_70", "sm_80", "sm_86", "sm_89", "sm_90"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="optimization_level",
|
||||
type=ParameterType.ENUM,
|
||||
required=False,
|
||||
description="Optimization level",
|
||||
default="O2",
|
||||
options=["O0", "O1", "O2", "O3", "Os"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="parallel_jobs",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Number of parallel compilation jobs",
|
||||
default=4,
|
||||
min_value=1,
|
||||
max_value=64
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"binary_url": {"type": "string"},
|
||||
"build_log": {"type": "string"},
|
||||
"compilation_time": {"type": "number"},
|
||||
"binary_size": {"type": "integer"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3080"),
|
||||
HardwareRequirement(component="vram", min_value=4, recommended=8, unit="GB"),
|
||||
HardwareRequirement(component="cpu", min_value=8, recommended=16, unit="cores"),
|
||||
HardwareRequirement(component="ram", min_value=16, recommended=32, unit="GB"),
|
||||
HardwareRequirement(component="cuda", min_value="11.8")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_minute", model=PricingModel.PER_UNIT, unit_price=0.01, min_charge=0.1),
|
||||
PricingTier(name="per_file", model=PricingModel.PER_UNIT, unit_price=0.001, min_charge=0.01),
|
||||
PricingTier(name="enterprise", model=PricingModel.PER_HOUR, unit_price=1, min_charge=1)
|
||||
],
|
||||
capabilities=["cuda", "hip", "parallel-compilation", "incremental"],
|
||||
tags=["compilation", "cuda", "gpu", "cpp", "build"],
|
||||
max_concurrent=5,
|
||||
timeout_seconds=1800
|
||||
),
|
||||
|
||||
"model_training": ServiceDefinition(
|
||||
id="model_training",
|
||||
name="ML Model Training",
|
||||
category=ServiceCategory.DEVELOPMENT_TOOLS,
|
||||
description="Fine-tune or train machine learning models on client data",
|
||||
icon="🧠",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="model_type",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Model type",
|
||||
options=["transformer", "cnn", "rnn", "gan", "diffusion", "custom"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="base_model",
|
||||
type=ParameterType.STRING,
|
||||
required=False,
|
||||
description="Base model to fine-tune"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="training_data",
|
||||
type=ParameterType.FILE,
|
||||
required=True,
|
||||
description="Training dataset"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="validation_data",
|
||||
type=ParameterType.FILE,
|
||||
required=False,
|
||||
description="Validation dataset"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="epochs",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Number of training epochs",
|
||||
default=10,
|
||||
min_value=1,
|
||||
max_value=1000
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="batch_size",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Batch size",
|
||||
default=32,
|
||||
min_value=1,
|
||||
max_value=1024
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="learning_rate",
|
||||
type=ParameterType.FLOAT,
|
||||
required=False,
|
||||
description="Learning rate",
|
||||
default=0.001,
|
||||
min_value=0.00001,
|
||||
max_value=1
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="hyperparameters",
|
||||
type=ParameterType.OBJECT,
|
||||
required=False,
|
||||
description="Additional hyperparameters"
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"model_url": {"type": "string"},
|
||||
"training_metrics": {"type": "object"},
|
||||
"loss_curves": {"type": "array"},
|
||||
"validation_scores": {"type": "object"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="nvidia", recommended="a100"),
|
||||
HardwareRequirement(component="vram", min_value=16, recommended=40, unit="GB"),
|
||||
HardwareRequirement(component="cpu", min_value=16, recommended=32, unit="cores"),
|
||||
HardwareRequirement(component="ram", min_value=32, recommended=128, unit="GB"),
|
||||
HardwareRequirement(component="storage", min_value=100, recommended=1000, unit="GB")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_epoch", model=PricingModel.PER_UNIT, unit_price=0.1, min_charge=1),
|
||||
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=2, min_charge=2),
|
||||
PricingTier(name="enterprise", model=PricingModel.PER_UNIT, unit_price=0.05, min_charge=0.5)
|
||||
],
|
||||
capabilities=["fine-tuning", "training", "hyperparameter-tuning", "distributed"],
|
||||
tags=["ml", "training", "fine-tuning", "pytorch", "tensorflow"],
|
||||
max_concurrent=2,
|
||||
timeout_seconds=86400 # 24 hours
|
||||
),
|
||||
|
||||
"data_processing": ServiceDefinition(
|
||||
id="data_processing",
|
||||
name="Large Dataset Processing",
|
||||
category=ServiceCategory.DEVELOPMENT_TOOLS,
|
||||
description="Preprocess and transform large datasets",
|
||||
icon="📦",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="operation",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Processing operation",
|
||||
options=["clean", "transform", "normalize", "augment", "split", "encode"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="input_data",
|
||||
type=ParameterType.FILE,
|
||||
required=True,
|
||||
description="Input dataset"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="output_format",
|
||||
type=ParameterType.ENUM,
|
||||
required=False,
|
||||
description="Output format",
|
||||
default="parquet",
|
||||
options=["csv", "json", "parquet", "hdf5", "feather", "pickle"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="chunk_size",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Processing chunk size",
|
||||
default=10000,
|
||||
min_value=100,
|
||||
max_value=1000000
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="parameters",
|
||||
type=ParameterType.OBJECT,
|
||||
required=False,
|
||||
description="Operation-specific parameters"
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"output_url": {"type": "string"},
|
||||
"processing_stats": {"type": "object"},
|
||||
"data_quality": {"type": "object"},
|
||||
"row_count": {"type": "integer"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="any", recommended="nvidia"),
|
||||
HardwareRequirement(component="vram", min_value=4, recommended=16, unit="GB"),
|
||||
HardwareRequirement(component="ram", min_value=16, recommended=64, unit="GB"),
|
||||
HardwareRequirement(component="storage", min_value=100, recommended=1000, unit="GB")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_gb", model=PricingModel.PER_GB, unit_price=0.01, min_charge=0.1),
|
||||
PricingTier(name="per_million_rows", model=PricingModel.PER_UNIT, unit_price=0.01, min_charge=0.1),
|
||||
PricingTier(name="enterprise", model=PricingModel.PER_HOUR, unit_price=1, min_charge=1)
|
||||
],
|
||||
capabilities=["gpu-processing", "parallel", "streaming", "validation"],
|
||||
tags=["data", "preprocessing", "etl", "cleaning", "transformation"],
|
||||
max_concurrent=5,
|
||||
timeout_seconds=3600
|
||||
),
|
||||
|
||||
"simulation_testing": ServiceDefinition(
|
||||
id="simulation_testing",
|
||||
name="Hardware-in-the-Loop Testing",
|
||||
category=ServiceCategory.DEVELOPMENT_TOOLS,
|
||||
description="Run hardware simulations and testing workflows",
|
||||
icon="🔬",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="test_type",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Test type",
|
||||
options=["hardware", "firmware", "software", "integration", "performance"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="test_suite",
|
||||
type=ParameterType.FILE,
|
||||
required=True,
|
||||
description="Test suite configuration"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="hardware_config",
|
||||
type=ParameterType.OBJECT,
|
||||
required=True,
|
||||
description="Hardware configuration"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="duration",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Test duration in hours",
|
||||
default=1,
|
||||
min_value=0.1,
|
||||
max_value=168 # 1 week
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="parallel_tests",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Number of parallel tests",
|
||||
default=1,
|
||||
min_value=1,
|
||||
max_value=10
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"test_results": {"type": "array"},
|
||||
"performance_metrics": {"type": "object"},
|
||||
"failure_logs": {"type": "array"},
|
||||
"coverage_report": {"type": "object"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="any", recommended="nvidia"),
|
||||
HardwareRequirement(component="cpu", min_value=16, recommended=32, unit="cores"),
|
||||
HardwareRequirement(component="ram", min_value=32, recommended=128, unit="GB"),
|
||||
HardwareRequirement(component="storage", min_value=100, recommended=500, unit="GB")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=2, min_charge=1),
|
||||
PricingTier(name="per_test", model=PricingModel.PER_UNIT, unit_price=0.1, min_charge=0.5),
|
||||
PricingTier(name="continuous", model=PricingModel.PER_HOUR, unit_price=5, min_charge=5)
|
||||
],
|
||||
capabilities=["hardware-simulation", "automated-testing", "performance", "debugging"],
|
||||
tags=["testing", "simulation", "hardware", "hil", "verification"],
|
||||
max_concurrent=3,
|
||||
timeout_seconds=604800 # 1 week
|
||||
),
|
||||
|
||||
"code_generation": ServiceDefinition(
|
||||
id="code_generation",
|
||||
name="AI Code Generation",
|
||||
category=ServiceCategory.DEVELOPMENT_TOOLS,
|
||||
description="Generate code from natural language descriptions",
|
||||
icon="💻",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="language",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Target programming language",
|
||||
options=["python", "javascript", "cpp", "java", "go", "rust", "typescript", "sql"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="description",
|
||||
type=ParameterType.STRING,
|
||||
required=True,
|
||||
description="Natural language description of code to generate",
|
||||
max_value=2000
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="framework",
|
||||
type=ParameterType.STRING,
|
||||
required=False,
|
||||
description="Target framework or library"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="code_style",
|
||||
type=ParameterType.ENUM,
|
||||
required=False,
|
||||
description="Code style preferences",
|
||||
default="standard",
|
||||
options=["standard", "functional", "oop", "minimalist"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="include_comments",
|
||||
type=ParameterType.BOOLEAN,
|
||||
required=False,
|
||||
description="Include explanatory comments",
|
||||
default=True
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="include_tests",
|
||||
type=ParameterType.BOOLEAN,
|
||||
required=False,
|
||||
description="Generate unit tests",
|
||||
default=False
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"generated_code": {"type": "string"},
|
||||
"explanation": {"type": "string"},
|
||||
"usage_example": {"type": "string"},
|
||||
"test_code": {"type": "string"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3080"),
|
||||
HardwareRequirement(component="vram", min_value=8, recommended=16, unit="GB"),
|
||||
HardwareRequirement(component="ram", min_value=8, recommended=16, unit="GB")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_generation", model=PricingModel.PER_UNIT, unit_price=0.01, min_charge=0.01),
|
||||
PricingTier(name="per_100_lines", model=PricingModel.PER_UNIT, unit_price=0.001, min_charge=0.01),
|
||||
PricingTier(name="with_tests", model=PricingModel.PER_UNIT, unit_price=0.02, min_charge=0.02)
|
||||
],
|
||||
capabilities=["code-gen", "documentation", "test-gen", "refactoring"],
|
||||
tags=["code", "generation", "ai", "copilot", "automation"],
|
||||
max_concurrent=10,
|
||||
timeout_seconds=120
|
||||
)
|
||||
}
|
||||
307
apps/coordinator-api/src/app/models/registry_gaming.py
Normal file
307
apps/coordinator-api/src/app/models/registry_gaming.py
Normal file
@ -0,0 +1,307 @@
|
||||
"""
|
||||
Gaming & entertainment service definitions
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Any, Union
|
||||
from .registry import (
|
||||
ServiceDefinition,
|
||||
ServiceCategory,
|
||||
ParameterDefinition,
|
||||
ParameterType,
|
||||
HardwareRequirement,
|
||||
PricingTier,
|
||||
PricingModel
|
||||
)
|
||||
|
||||
|
||||
GAMING_SERVICES = {
|
||||
"cloud_gaming": ServiceDefinition(
|
||||
id="cloud_gaming",
|
||||
name="Cloud Gaming Server",
|
||||
category=ServiceCategory.GAMING_ENTERTAINMENT,
|
||||
description="Host cloud gaming sessions with GPU streaming",
|
||||
icon="🎮",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="game",
|
||||
type=ParameterType.STRING,
|
||||
required=True,
|
||||
description="Game title or executable"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="resolution",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Streaming resolution",
|
||||
options=["720p", "1080p", "1440p", "4k"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="fps",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Target frame rate",
|
||||
default=60,
|
||||
options=[30, 60, 120, 144]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="session_duration",
|
||||
type=ParameterType.INTEGER,
|
||||
required=True,
|
||||
description="Session duration in minutes",
|
||||
min_value=15,
|
||||
max_value=480
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="codec",
|
||||
type=ParameterType.ENUM,
|
||||
required=False,
|
||||
description="Streaming codec",
|
||||
default="h264",
|
||||
options=["h264", "h265", "av1", "vp9"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="region",
|
||||
type=ParameterType.STRING,
|
||||
required=False,
|
||||
description="Preferred server region"
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stream_url": {"type": "string"},
|
||||
"session_id": {"type": "string"},
|
||||
"latency_ms": {"type": "integer"},
|
||||
"quality_metrics": {"type": "object"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3080"),
|
||||
HardwareRequirement(component="vram", min_value=8, recommended=16, unit="GB"),
|
||||
HardwareRequirement(component="network", min_value="100Mbps", recommended="1Gbps"),
|
||||
HardwareRequirement(component="cpu", min_value=8, recommended=16, unit="cores"),
|
||||
HardwareRequirement(component="ram", min_value=16, recommended=32, unit="GB")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=1, min_charge=0.5),
|
||||
PricingTier(name="1080p", model=PricingModel.PER_HOUR, unit_price=1.5, min_charge=0.75),
|
||||
PricingTier(name="4k", model=PricingModel.PER_HOUR, unit_price=3, min_charge=1.5)
|
||||
],
|
||||
capabilities=["low-latency", "game-streaming", "multiplayer", "saves"],
|
||||
tags=["gaming", "cloud", "streaming", "nvidia", "gamepass"],
|
||||
max_concurrent=1,
|
||||
timeout_seconds=28800 # 8 hours
|
||||
),
|
||||
|
||||
"game_asset_baking": ServiceDefinition(
|
||||
id="game_asset_baking",
|
||||
name="Game Asset Baking",
|
||||
category=ServiceCategory.GAMING_ENTERTAINMENT,
|
||||
description="Optimize and bake game assets (textures, meshes, materials)",
|
||||
icon="🎨",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="asset_type",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Asset type",
|
||||
options=["texture", "mesh", "material", "animation", "terrain"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="input_assets",
|
||||
type=ParameterType.ARRAY,
|
||||
required=True,
|
||||
description="Input asset files",
|
||||
items={"type": "string"}
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="target_platform",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Target platform",
|
||||
options=["pc", "mobile", "console", "web", "vr"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="optimization_level",
|
||||
type=ParameterType.ENUM,
|
||||
required=False,
|
||||
description="Optimization level",
|
||||
default="balanced",
|
||||
options=["fast", "balanced", "maximum"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="texture_formats",
|
||||
type=ParameterType.ARRAY,
|
||||
required=False,
|
||||
description="Output texture formats",
|
||||
default=["dds", "astc"],
|
||||
items={"type": "string"}
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"baked_assets": {"type": "array"},
|
||||
"compression_stats": {"type": "object"},
|
||||
"optimization_report": {"type": "object"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3080"),
|
||||
HardwareRequirement(component="vram", min_value=8, recommended=16, unit="GB"),
|
||||
HardwareRequirement(component="ram", min_value=16, recommended=32, unit="GB"),
|
||||
HardwareRequirement(component="storage", min_value=50, recommended=500, unit="GB")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_asset", model=PricingModel.PER_UNIT, unit_price=0.01, min_charge=0.1),
|
||||
PricingTier(name="per_texture", model=PricingModel.PER_UNIT, unit_price=0.005, min_charge=0.05),
|
||||
PricingTier(name="per_mesh", model=PricingModel.PER_UNIT, unit_price=0.02, min_charge=0.1)
|
||||
],
|
||||
capabilities=["texture-compression", "mesh-optimization", "lod-generation", "platform-specific"],
|
||||
tags=["gamedev", "assets", "optimization", "textures", "meshes"],
|
||||
max_concurrent=5,
|
||||
timeout_seconds=1800
|
||||
),
|
||||
|
||||
"physics_simulation": ServiceDefinition(
|
||||
id="physics_simulation",
|
||||
name="Game Physics Simulation",
|
||||
category=ServiceCategory.GAMING_ENTERTAINMENT,
|
||||
description="Run physics simulations for game development",
|
||||
icon="⚛️",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="engine",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Physics engine",
|
||||
options=["physx", "havok", "bullet", "box2d", "chipmunk"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="simulation_type",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Simulation type",
|
||||
options=["rigid-body", "soft-body", "fluid", "cloth", "destruction"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="scene_file",
|
||||
type=ParameterType.FILE,
|
||||
required=False,
|
||||
description="Scene or level file"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="parameters",
|
||||
type=ParameterType.OBJECT,
|
||||
required=True,
|
||||
description="Physics parameters"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="simulation_time",
|
||||
type=ParameterType.FLOAT,
|
||||
required=True,
|
||||
description="Simulation duration in seconds",
|
||||
min_value=0.1
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="record_frames",
|
||||
type=ParameterType.BOOLEAN,
|
||||
required=False,
|
||||
description="Record animation frames",
|
||||
default=False
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"simulation_data": {"type": "array"},
|
||||
"animation_url": {"type": "string"},
|
||||
"physics_stats": {"type": "object"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3080"),
|
||||
HardwareRequirement(component="vram", min_value=8, recommended=16, unit="GB"),
|
||||
HardwareRequirement(component="cpu", min_value=8, recommended=16, unit="cores"),
|
||||
HardwareRequirement(component="ram", min_value=16, recommended=32, unit="GB")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=1, min_charge=0.5),
|
||||
PricingTier(name="per_frame", model=PricingModel.PER_UNIT, unit_price=0.001, min_charge=0.1),
|
||||
PricingTier(name="complex", model=PricingModel.PER_HOUR, unit_price=2, min_charge=1)
|
||||
],
|
||||
capabilities=["gpu-physics", "particle-systems", "destruction", "cloth"],
|
||||
tags=["physics", "gamedev", "simulation", "physx", "havok"],
|
||||
max_concurrent=3,
|
||||
timeout_seconds=3600
|
||||
),
|
||||
|
||||
"vr_ar_rendering": ServiceDefinition(
|
||||
id="vr_ar_rendering",
|
||||
name="VR/AR Rendering",
|
||||
category=ServiceCategory.GAMING_ENTERTAINMENT,
|
||||
description="Real-time 3D rendering for VR/AR applications",
|
||||
icon="🥽",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="platform",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Target platform",
|
||||
options=["oculus", "vive", "hololens", "magic-leap", "cardboard", "webxr"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="scene_file",
|
||||
type=ParameterType.FILE,
|
||||
required=True,
|
||||
description="3D scene file"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="render_quality",
|
||||
type=ParameterType.ENUM,
|
||||
required=False,
|
||||
description="Render quality",
|
||||
default="high",
|
||||
options=["low", "medium", "high", "ultra"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="stereo_mode",
|
||||
type=ParameterType.BOOLEAN,
|
||||
required=False,
|
||||
description="Stereo rendering",
|
||||
default=True
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="target_fps",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Target frame rate",
|
||||
default=90,
|
||||
options=[60, 72, 90, 120, 144]
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"rendered_frames": {"type": "array"},
|
||||
"performance_metrics": {"type": "object"},
|
||||
"vr_package": {"type": "string"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3080"),
|
||||
HardwareRequirement(component="vram", min_value=8, recommended=16, unit="GB"),
|
||||
HardwareRequirement(component="cpu", min_value=8, recommended=16, unit="cores"),
|
||||
HardwareRequirement(component="ram", min_value=16, recommended=32, unit="GB")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_minute", model=PricingModel.PER_UNIT, unit_price=0.02, min_charge=0.5),
|
||||
PricingTier(name="per_frame", model=PricingModel.PER_UNIT, unit_price=0.001, min_charge=0.1),
|
||||
PricingTier(name="real-time", model=PricingModel.PER_HOUR, unit_price=5, min_charge=1)
|
||||
],
|
||||
capabilities=["stereo-rendering", "real-time", "low-latency", "tracking"],
|
||||
tags=["vr", "ar", "rendering", "3d", "immersive"],
|
||||
max_concurrent=2,
|
||||
timeout_seconds=3600
|
||||
)
|
||||
}
|
||||
412
apps/coordinator-api/src/app/models/registry_media.py
Normal file
412
apps/coordinator-api/src/app/models/registry_media.py
Normal file
@ -0,0 +1,412 @@
|
||||
"""
|
||||
Media processing service definitions
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Any, Union
|
||||
from .registry import (
|
||||
ServiceDefinition,
|
||||
ServiceCategory,
|
||||
ParameterDefinition,
|
||||
ParameterType,
|
||||
HardwareRequirement,
|
||||
PricingTier,
|
||||
PricingModel
|
||||
)
|
||||
|
||||
|
||||
MEDIA_PROCESSING_SERVICES = {
|
||||
"video_transcoding": ServiceDefinition(
|
||||
id="video_transcoding",
|
||||
name="Video Transcoding",
|
||||
category=ServiceCategory.MEDIA_PROCESSING,
|
||||
description="Transcode videos between formats using FFmpeg with GPU acceleration",
|
||||
icon="🎬",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="input_video",
|
||||
type=ParameterType.FILE,
|
||||
required=True,
|
||||
description="Input video file"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="output_format",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Output video format",
|
||||
options=["mp4", "webm", "avi", "mov", "mkv", "flv"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="codec",
|
||||
type=ParameterType.ENUM,
|
||||
required=False,
|
||||
description="Video codec",
|
||||
default="h264",
|
||||
options=["h264", "h265", "vp9", "av1", "mpeg4"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="resolution",
|
||||
type=ParameterType.STRING,
|
||||
required=False,
|
||||
description="Output resolution (e.g., 1920x1080)",
|
||||
validation={"pattern": r"^\d+x\d+$"}
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="bitrate",
|
||||
type=ParameterType.STRING,
|
||||
required=False,
|
||||
description="Target bitrate (e.g., 5M, 2500k)",
|
||||
validation={"pattern": r"^\d+[kM]?$"}
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="fps",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Output frame rate",
|
||||
min_value=1,
|
||||
max_value=120
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="gpu_acceleration",
|
||||
type=ParameterType.BOOLEAN,
|
||||
required=False,
|
||||
description="Use GPU acceleration",
|
||||
default=True
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"output_url": {"type": "string"},
|
||||
"metadata": {"type": "object"},
|
||||
"duration": {"type": "number"},
|
||||
"file_size": {"type": "integer"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="any", recommended="nvidia"),
|
||||
HardwareRequirement(component="vram", min_value=2, recommended=8, unit="GB"),
|
||||
HardwareRequirement(component="ram", min_value=8, recommended=16, unit="GB"),
|
||||
HardwareRequirement(component="storage", min_value=50, unit="GB")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_minute", model=PricingModel.PER_UNIT, unit_price=0.005, min_charge=0.01),
|
||||
PricingTier(name="per_gb", model=PricingModel.PER_GB, unit_price=0.01, min_charge=0.01),
|
||||
PricingTier(name="4k_premium", model=PricingModel.PER_UNIT, unit_price=0.02, min_charge=0.05)
|
||||
],
|
||||
capabilities=["transcode", "compress", "resize", "format-convert"],
|
||||
tags=["video", "ffmpeg", "transcoding", "encoding", "gpu"],
|
||||
max_concurrent=2,
|
||||
timeout_seconds=3600
|
||||
),
|
||||
|
||||
"video_streaming": ServiceDefinition(
|
||||
id="video_streaming",
|
||||
name="Live Video Streaming",
|
||||
category=ServiceCategory.MEDIA_PROCESSING,
|
||||
description="Real-time video transcoding for adaptive bitrate streaming",
|
||||
icon="📡",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="stream_url",
|
||||
type=ParameterType.STRING,
|
||||
required=True,
|
||||
description="Input stream URL"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="output_formats",
|
||||
type=ParameterType.ARRAY,
|
||||
required=True,
|
||||
description="Output formats for adaptive streaming",
|
||||
default=["720p", "1080p", "4k"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="duration_minutes",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Streaming duration in minutes",
|
||||
default=60,
|
||||
min_value=1,
|
||||
max_value=480
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="protocol",
|
||||
type=ParameterType.ENUM,
|
||||
required=False,
|
||||
description="Streaming protocol",
|
||||
default="hls",
|
||||
options=["hls", "dash", "rtmp", "webrtc"]
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stream_url": {"type": "string"},
|
||||
"playlist_url": {"type": "string"},
|
||||
"bitrates": {"type": "array"},
|
||||
"duration": {"type": "number"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3080"),
|
||||
HardwareRequirement(component="vram", min_value=8, recommended=16, unit="GB"),
|
||||
HardwareRequirement(component="network", min_value="1Gbps", recommended="10Gbps"),
|
||||
HardwareRequirement(component="ram", min_value=16, recommended=32, unit="GB")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_minute", model=PricingModel.PER_UNIT, unit_price=0.01, min_charge=0.5),
|
||||
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=0.5, min_charge=0.5)
|
||||
],
|
||||
capabilities=["live-transcoding", "adaptive-bitrate", "multi-format", "low-latency"],
|
||||
tags=["streaming", "live", "transcoding", "real-time"],
|
||||
max_concurrent=5,
|
||||
timeout_seconds=28800 # 8 hours
|
||||
),
|
||||
|
||||
"3d_rendering": ServiceDefinition(
|
||||
id="3d_rendering",
|
||||
name="3D Rendering",
|
||||
category=ServiceCategory.MEDIA_PROCESSING,
|
||||
description="Render 3D scenes using Blender, Unreal Engine, or V-Ray",
|
||||
icon="🎭",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="engine",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Rendering engine",
|
||||
options=["blender-cycles", "blender-eevee", "unreal-engine", "v-ray", "octane"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="scene_file",
|
||||
type=ParameterType.FILE,
|
||||
required=True,
|
||||
description="3D scene file (.blend, .ueproject, etc)"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="resolution_x",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Output width",
|
||||
default=1920,
|
||||
min_value=1,
|
||||
max_value=8192
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="resolution_y",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Output height",
|
||||
default=1080,
|
||||
min_value=1,
|
||||
max_value=8192
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="samples",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Samples per pixel (path tracing)",
|
||||
default=128,
|
||||
min_value=1,
|
||||
max_value=10000
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="frame_start",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Start frame for animation",
|
||||
default=1,
|
||||
min_value=1
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="frame_end",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="End frame for animation",
|
||||
default=1,
|
||||
min_value=1
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="output_format",
|
||||
type=ParameterType.ENUM,
|
||||
required=False,
|
||||
description="Output image format",
|
||||
default="png",
|
||||
options=["png", "jpg", "exr", "bmp", "tiff", "hdr"]
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"rendered_images": {"type": "array"},
|
||||
"metadata": {"type": "object"},
|
||||
"render_time": {"type": "number"},
|
||||
"frame_count": {"type": "integer"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-4090"),
|
||||
HardwareRequirement(component="vram", min_value=8, recommended=24, unit="GB"),
|
||||
HardwareRequirement(component="ram", min_value=16, recommended=64, unit="GB"),
|
||||
HardwareRequirement(component="cpu", min_value=8, recommended=16, unit="cores")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_frame", model=PricingModel.PER_FRAME, unit_price=0.01, min_charge=0.1),
|
||||
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=0.5, min_charge=0.5),
|
||||
PricingTier(name="4k_premium", model=PricingModel.PER_FRAME, unit_price=0.05, min_charge=0.5)
|
||||
],
|
||||
capabilities=["path-tracing", "ray-tracing", "animation", "gpu-render"],
|
||||
tags=["3d", "rendering", "blender", "unreal", "v-ray"],
|
||||
max_concurrent=2,
|
||||
timeout_seconds=7200
|
||||
),
|
||||
|
||||
"image_processing": ServiceDefinition(
|
||||
id="image_processing",
|
||||
name="Batch Image Processing",
|
||||
category=ServiceCategory.MEDIA_PROCESSING,
|
||||
description="Process images in bulk with filters, effects, and format conversion",
|
||||
icon="🖼️",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="images",
|
||||
type=ParameterType.ARRAY,
|
||||
required=True,
|
||||
description="Array of image files or URLs"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="operations",
|
||||
type=ParameterType.ARRAY,
|
||||
required=True,
|
||||
description="Processing operations to apply",
|
||||
items={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"type": "string"},
|
||||
"params": {"type": "object"}
|
||||
}
|
||||
}
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="output_format",
|
||||
type=ParameterType.ENUM,
|
||||
required=False,
|
||||
description="Output format",
|
||||
default="jpg",
|
||||
options=["jpg", "png", "webp", "avif", "tiff", "bmp"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="quality",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Output quality (1-100)",
|
||||
default=90,
|
||||
min_value=1,
|
||||
max_value=100
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="resize",
|
||||
type=ParameterType.STRING,
|
||||
required=False,
|
||||
description="Resize dimensions (e.g., 1920x1080, 50%)",
|
||||
validation={"pattern": r"^\d+x\d+|^\d+%$"}
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"processed_images": {"type": "array"},
|
||||
"count": {"type": "integer"},
|
||||
"total_size": {"type": "integer"},
|
||||
"processing_time": {"type": "number"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="any", recommended="nvidia"),
|
||||
HardwareRequirement(component="vram", min_value=1, recommended=4, unit="GB"),
|
||||
HardwareRequirement(component="ram", min_value=4, recommended=16, unit="GB")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_image", model=PricingModel.PER_UNIT, unit_price=0.001, min_charge=0.01),
|
||||
PricingTier(name="bulk_100", model=PricingModel.PER_UNIT, unit_price=0.0005, min_charge=0.05),
|
||||
PricingTier(name="bulk_1000", model=PricingModel.PER_UNIT, unit_price=0.0002, min_charge=0.2)
|
||||
],
|
||||
capabilities=["resize", "filter", "format-convert", "batch", "watermark"],
|
||||
tags=["image", "processing", "batch", "filter", "conversion"],
|
||||
max_concurrent=10,
|
||||
timeout_seconds=600
|
||||
),
|
||||
|
||||
"audio_processing": ServiceDefinition(
|
||||
id="audio_processing",
|
||||
name="Audio Processing",
|
||||
category=ServiceCategory.MEDIA_PROCESSING,
|
||||
description="Process audio files with effects, noise reduction, and format conversion",
|
||||
icon="🎵",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="audio_file",
|
||||
type=ParameterType.FILE,
|
||||
required=True,
|
||||
description="Input audio file"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="operations",
|
||||
type=ParameterType.ARRAY,
|
||||
required=True,
|
||||
description="Audio operations to apply",
|
||||
items={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"type": "string"},
|
||||
"params": {"type": "object"}
|
||||
}
|
||||
}
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="output_format",
|
||||
type=ParameterType.ENUM,
|
||||
required=False,
|
||||
description="Output format",
|
||||
default="mp3",
|
||||
options=["mp3", "wav", "flac", "aac", "ogg", "m4a"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="sample_rate",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Output sample rate",
|
||||
default=44100,
|
||||
options=[22050, 44100, 48000, 96000, 192000]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="bitrate",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Output bitrate (kbps)",
|
||||
default=320,
|
||||
options=[128, 192, 256, 320, 512, 1024]
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"output_url": {"type": "string"},
|
||||
"metadata": {"type": "object"},
|
||||
"duration": {"type": "number"},
|
||||
"file_size": {"type": "integer"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="any", recommended="nvidia"),
|
||||
HardwareRequirement(component="ram", min_value=2, recommended=8, unit="GB")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_minute", model=PricingModel.PER_UNIT, unit_price=0.002, min_charge=0.01),
|
||||
PricingTier(name="per_effect", model=PricingModel.PER_UNIT, unit_price=0.005, min_charge=0.01)
|
||||
],
|
||||
capabilities=["noise-reduction", "effects", "format-convert", "enhancement"],
|
||||
tags=["audio", "processing", "effects", "noise-reduction"],
|
||||
max_concurrent=5,
|
||||
timeout_seconds=300
|
||||
)
|
||||
}
|
||||
406
apps/coordinator-api/src/app/models/registry_scientific.py
Normal file
406
apps/coordinator-api/src/app/models/registry_scientific.py
Normal file
@ -0,0 +1,406 @@
|
||||
"""
|
||||
Scientific computing service definitions
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Any, Union
|
||||
from .registry import (
|
||||
ServiceDefinition,
|
||||
ServiceCategory,
|
||||
ParameterDefinition,
|
||||
ParameterType,
|
||||
HardwareRequirement,
|
||||
PricingTier,
|
||||
PricingModel
|
||||
)
|
||||
|
||||
|
||||
SCIENTIFIC_COMPUTING_SERVICES = {
|
||||
"molecular_dynamics": ServiceDefinition(
|
||||
id="molecular_dynamics",
|
||||
name="Molecular Dynamics Simulation",
|
||||
category=ServiceCategory.SCIENTIFIC_COMPUTING,
|
||||
description="Run molecular dynamics simulations using GROMACS or NAMD",
|
||||
icon="🧬",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="software",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="MD software package",
|
||||
options=["gromacs", "namd", "amber", "lammps", "desmond"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="structure_file",
|
||||
type=ParameterType.FILE,
|
||||
required=True,
|
||||
description="Molecular structure file (PDB, MOL2, etc)"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="topology_file",
|
||||
type=ParameterType.FILE,
|
||||
required=False,
|
||||
description="Topology file"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="force_field",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Force field to use",
|
||||
options=["AMBER", "CHARMM", "OPLS", "GROMOS", "DREIDING"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="simulation_time_ns",
|
||||
type=ParameterType.FLOAT,
|
||||
required=True,
|
||||
description="Simulation time in nanoseconds",
|
||||
min_value=0.1,
|
||||
max_value=1000
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="temperature_k",
|
||||
type=ParameterType.FLOAT,
|
||||
required=False,
|
||||
description="Temperature in Kelvin",
|
||||
default=300,
|
||||
min_value=0,
|
||||
max_value=500
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="pressure_bar",
|
||||
type=ParameterType.FLOAT,
|
||||
required=False,
|
||||
description="Pressure in bar",
|
||||
default=1,
|
||||
min_value=0,
|
||||
max_value=1000
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="time_step_fs",
|
||||
type=ParameterType.FLOAT,
|
||||
required=False,
|
||||
description="Time step in femtoseconds",
|
||||
default=2,
|
||||
min_value=0.5,
|
||||
max_value=5
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"trajectory_url": {"type": "string"},
|
||||
"log_url": {"type": "string"},
|
||||
"energy_data": {"type": "array"},
|
||||
"simulation_stats": {"type": "object"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="nvidia", recommended="a100"),
|
||||
HardwareRequirement(component="vram", min_value=16, recommended=40, unit="GB"),
|
||||
HardwareRequirement(component="cpu", min_value=16, recommended=64, unit="cores"),
|
||||
HardwareRequirement(component="ram", min_value=32, recommended=256, unit="GB"),
|
||||
HardwareRequirement(component="storage", min_value=100, recommended=1000, unit="GB")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_ns", model=PricingModel.PER_UNIT, unit_price=0.1, min_charge=1),
|
||||
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=2, min_charge=2),
|
||||
PricingTier(name="bulk_100ns", model=PricingModel.PER_UNIT, unit_price=0.05, min_charge=5)
|
||||
],
|
||||
capabilities=["gpu-accelerated", "parallel", "ensemble", "free-energy"],
|
||||
tags=["molecular", "dynamics", "simulation", "biophysics", "chemistry"],
|
||||
max_concurrent=4,
|
||||
timeout_seconds=86400 # 24 hours
|
||||
),
|
||||
|
||||
"weather_modeling": ServiceDefinition(
|
||||
id="weather_modeling",
|
||||
name="Weather Modeling",
|
||||
category=ServiceCategory.SCIENTIFIC_COMPUTING,
|
||||
description="Run weather prediction and climate simulations",
|
||||
icon="🌦️",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="model",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Weather model",
|
||||
options=["WRF", "MM5", "IFS", "GFS", "ECMWF"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="region",
|
||||
type=ParameterType.OBJECT,
|
||||
required=True,
|
||||
description="Geographic region bounds",
|
||||
properties={
|
||||
"lat_min": {"type": "number"},
|
||||
"lat_max": {"type": "number"},
|
||||
"lon_min": {"type": "number"},
|
||||
"lon_max": {"type": "number"}
|
||||
}
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="forecast_hours",
|
||||
type=ParameterType.INTEGER,
|
||||
required=True,
|
||||
description="Forecast length in hours",
|
||||
min_value=1,
|
||||
max_value=384 # 16 days
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="resolution_km",
|
||||
type=ParameterType.FLOAT,
|
||||
required=False,
|
||||
description="Spatial resolution in kilometers",
|
||||
default=10,
|
||||
options=[1, 3, 5, 10, 25, 50]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="output_variables",
|
||||
type=ParameterType.ARRAY,
|
||||
required=False,
|
||||
description="Variables to output",
|
||||
default=["temperature", "precipitation", "wind", "pressure"],
|
||||
items={"type": "string"}
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"forecast_data": {"type": "array"},
|
||||
"visualization_urls": {"type": "array"},
|
||||
"metadata": {"type": "object"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="cpu", min_value=32, recommended=128, unit="cores"),
|
||||
HardwareRequirement(component="ram", min_value=64, recommended=512, unit="GB"),
|
||||
HardwareRequirement(component="storage", min_value=500, recommended=5000, unit="GB"),
|
||||
HardwareRequirement(component="network", min_value="10Gbps", recommended="100Gbps")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=5, min_charge=10),
|
||||
PricingTier(name="per_day", model=PricingModel.PER_UNIT, unit_price=100, min_charge=100),
|
||||
PricingTier(name="high_res", model=PricingModel.PER_HOUR, unit_price=10, min_charge=20)
|
||||
],
|
||||
capabilities=["forecast", "climate", "ensemble", "data-assimilation"],
|
||||
tags=["weather", "climate", "forecast", "meteorology", "atmosphere"],
|
||||
max_concurrent=2,
|
||||
timeout_seconds=172800 # 48 hours
|
||||
),
|
||||
|
||||
"financial_modeling": ServiceDefinition(
|
||||
id="financial_modeling",
|
||||
name="Financial Modeling",
|
||||
category=ServiceCategory.SCIENTIFIC_COMPUTING,
|
||||
description="Run Monte Carlo simulations and risk analysis for financial models",
|
||||
icon="📊",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="model_type",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Financial model type",
|
||||
options=["monte-carlo", "option-pricing", "risk-var", "portfolio-optimization", "credit-risk"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="parameters",
|
||||
type=ParameterType.OBJECT,
|
||||
required=True,
|
||||
description="Model parameters"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="num_simulations",
|
||||
type=ParameterType.INTEGER,
|
||||
required=True,
|
||||
description="Number of Monte Carlo simulations",
|
||||
default=10000,
|
||||
min_value=1000,
|
||||
max_value=10000000
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="time_steps",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Number of time steps",
|
||||
default=252,
|
||||
min_value=1,
|
||||
max_value=10000
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="confidence_levels",
|
||||
type=ParameterType.ARRAY,
|
||||
required=False,
|
||||
description="Confidence levels for VaR",
|
||||
default=[0.95, 0.99],
|
||||
items={"type": "number", "minimum": 0, "maximum": 1}
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"results": {"type": "array"},
|
||||
"statistics": {"type": "object"},
|
||||
"risk_metrics": {"type": "object"},
|
||||
"confidence_intervals": {"type": "array"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3080"),
|
||||
HardwareRequirement(component="vram", min_value=8, recommended=16, unit="GB"),
|
||||
HardwareRequirement(component="cpu", min_value=8, recommended=32, unit="cores"),
|
||||
HardwareRequirement(component="ram", min_value=16, recommended=64, unit="GB")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_simulation", model=PricingModel.PER_UNIT, unit_price=0.00001, min_charge=0.1),
|
||||
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=1, min_charge=1),
|
||||
PricingTier(name="enterprise", model=PricingModel.PER_UNIT, unit_price=0.000005, min_charge=0.5)
|
||||
],
|
||||
capabilities=["monte-carlo", "var", "option-pricing", "portfolio", "risk-analysis"],
|
||||
tags=["finance", "risk", "monte-carlo", "var", "options"],
|
||||
max_concurrent=10,
|
||||
timeout_seconds=3600
|
||||
),
|
||||
|
||||
"physics_simulation": ServiceDefinition(
|
||||
id="physics_simulation",
|
||||
name="Physics Simulation",
|
||||
category=ServiceCategory.SCIENTIFIC_COMPUTING,
|
||||
description="Run particle physics and fluid dynamics simulations",
|
||||
icon="⚛️",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="simulation_type",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Physics simulation type",
|
||||
options=["particle-physics", "fluid-dynamics", "electromagnetics", "quantum", "astrophysics"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="solver",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Simulation solver",
|
||||
options=["geant4", "fluent", "comsol", "openfoam", "lammps", "gadget"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="geometry_file",
|
||||
type=ParameterType.FILE,
|
||||
required=False,
|
||||
description="Geometry or mesh file"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="initial_conditions",
|
||||
type=ParameterType.OBJECT,
|
||||
required=True,
|
||||
description="Initial conditions and parameters"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="simulation_time",
|
||||
type=ParameterType.FLOAT,
|
||||
required=True,
|
||||
description="Simulation time",
|
||||
min_value=0.001
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="particles",
|
||||
type=ParameterType.INTEGER,
|
||||
required=False,
|
||||
description="Number of particles",
|
||||
default=1000000,
|
||||
min_value=1000,
|
||||
max_value=100000000
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"results_url": {"type": "string"},
|
||||
"data_arrays": {"type": "object"},
|
||||
"visualizations": {"type": "array"},
|
||||
"statistics": {"type": "object"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="nvidia", recommended="a100"),
|
||||
HardwareRequirement(component="vram", min_value=16, recommended=40, unit="GB"),
|
||||
HardwareRequirement(component="cpu", min_value=16, recommended=64, unit="cores"),
|
||||
HardwareRequirement(component="ram", min_value=32, recommended=256, unit="GB"),
|
||||
HardwareRequirement(component="storage", min_value=100, recommended=1000, unit="GB")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=2, min_charge=2),
|
||||
PricingTier(name="per_particle", model=PricingModel.PER_UNIT, unit_price=0.000001, min_charge=1),
|
||||
PricingTier(name="hpc", model=PricingModel.PER_HOUR, unit_price=5, min_charge=5)
|
||||
],
|
||||
capabilities=["gpu-accelerated", "parallel", "mpi", "large-scale"],
|
||||
tags=["physics", "simulation", "particle", "fluid", "cfd"],
|
||||
max_concurrent=4,
|
||||
timeout_seconds=86400
|
||||
),
|
||||
|
||||
"bioinformatics": ServiceDefinition(
|
||||
id="bioinformatics",
|
||||
name="Bioinformatics Analysis",
|
||||
category=ServiceCategory.SCIENTIFIC_COMPUTING,
|
||||
description="DNA sequencing, protein folding, and genomic analysis",
|
||||
icon="🧬",
|
||||
input_parameters=[
|
||||
ParameterDefinition(
|
||||
name="analysis_type",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Bioinformatics analysis type",
|
||||
options=["dna-sequencing", "protein-folding", "alignment", "phylogeny", "variant-calling"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="sequence_file",
|
||||
type=ParameterType.FILE,
|
||||
required=True,
|
||||
description="Input sequence file (FASTA, FASTQ, BAM, etc)"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="reference_file",
|
||||
type=ParameterType.FILE,
|
||||
required=False,
|
||||
description="Reference genome or protein structure"
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="algorithm",
|
||||
type=ParameterType.ENUM,
|
||||
required=True,
|
||||
description="Analysis algorithm",
|
||||
options=["blast", "bowtie", "bwa", "alphafold", "gatk", "clustal"]
|
||||
),
|
||||
ParameterDefinition(
|
||||
name="parameters",
|
||||
type=ParameterType.OBJECT,
|
||||
required=False,
|
||||
description="Algorithm-specific parameters"
|
||||
)
|
||||
],
|
||||
output_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"results_file": {"type": "string"},
|
||||
"alignment_file": {"type": "string"},
|
||||
"annotations": {"type": "array"},
|
||||
"statistics": {"type": "object"}
|
||||
}
|
||||
},
|
||||
requirements=[
|
||||
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3090"),
|
||||
HardwareRequirement(component="vram", min_value=8, recommended=24, unit="GB"),
|
||||
HardwareRequirement(component="cpu", min_value=16, recommended=32, unit="cores"),
|
||||
HardwareRequirement(component="ram", min_value=32, recommended=128, unit="GB"),
|
||||
HardwareRequirement(component="storage", min_value=100, recommended=500, unit="GB")
|
||||
],
|
||||
pricing=[
|
||||
PricingTier(name="per_mb", model=PricingModel.PER_UNIT, unit_price=0.001, min_charge=0.1),
|
||||
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=1, min_charge=1),
|
||||
PricingTier(name="protein_folding", model=PricingModel.PER_UNIT, unit_price=0.01, min_charge=0.5)
|
||||
],
|
||||
capabilities=["sequencing", "alignment", "folding", "annotation", "variant-calling"],
|
||||
tags=["bioinformatics", "genomics", "proteomics", "dna", "sequencing"],
|
||||
max_concurrent=5,
|
||||
timeout_seconds=7200
|
||||
)
|
||||
}
|
||||
380
apps/coordinator-api/src/app/models/services.py
Normal file
380
apps/coordinator-api/src/app/models/services.py
Normal file
@ -0,0 +1,380 @@
|
||||
"""
|
||||
Service schemas for common GPU workloads
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field, validator
|
||||
import re
|
||||
|
||||
|
||||
class ServiceType(str, Enum):
|
||||
"""Supported service types"""
|
||||
WHISPER = "whisper"
|
||||
STABLE_DIFFUSION = "stable_diffusion"
|
||||
LLM_INFERENCE = "llm_inference"
|
||||
FFMPEG = "ffmpeg"
|
||||
BLENDER = "blender"
|
||||
|
||||
|
||||
# Whisper Service Schemas
|
||||
class WhisperModel(str, Enum):
|
||||
"""Supported Whisper models"""
|
||||
TINY = "tiny"
|
||||
BASE = "base"
|
||||
SMALL = "small"
|
||||
MEDIUM = "medium"
|
||||
LARGE = "large"
|
||||
LARGE_V2 = "large-v2"
|
||||
LARGE_V3 = "large-v3"
|
||||
|
||||
|
||||
class WhisperLanguage(str, Enum):
|
||||
"""Supported languages"""
|
||||
AUTO = "auto"
|
||||
EN = "en"
|
||||
ES = "es"
|
||||
FR = "fr"
|
||||
DE = "de"
|
||||
IT = "it"
|
||||
PT = "pt"
|
||||
RU = "ru"
|
||||
JA = "ja"
|
||||
KO = "ko"
|
||||
ZH = "zh"
|
||||
|
||||
|
||||
class WhisperTask(str, Enum):
|
||||
"""Whisper task types"""
|
||||
TRANSCRIBE = "transcribe"
|
||||
TRANSLATE = "translate"
|
||||
|
||||
|
||||
class WhisperRequest(BaseModel):
|
||||
"""Whisper transcription request"""
|
||||
audio_url: str = Field(..., description="URL of audio file to transcribe")
|
||||
model: WhisperModel = Field(WhisperModel.BASE, description="Whisper model to use")
|
||||
language: WhisperLanguage = Field(WhisperLanguage.AUTO, description="Source language")
|
||||
task: WhisperTask = Field(WhisperTask.TRANSCRIBE, description="Task to perform")
|
||||
temperature: float = Field(0.0, ge=0.0, le=1.0, description="Sampling temperature")
|
||||
best_of: int = Field(5, ge=1, le=10, description="Number of candidates")
|
||||
beam_size: int = Field(5, ge=1, le=10, description="Beam size for decoding")
|
||||
patience: float = Field(1.0, ge=0.0, le=2.0, description="Beam search patience")
|
||||
suppress_tokens: Optional[List[int]] = Field(None, description="Tokens to suppress")
|
||||
initial_prompt: Optional[str] = Field(None, description="Initial prompt for context")
|
||||
condition_on_previous_text: bool = Field(True, description="Condition on previous text")
|
||||
fp16: bool = Field(True, description="Use FP16 for faster inference")
|
||||
verbose: bool = Field(False, description="Include verbose output")
|
||||
|
||||
def get_constraints(self) -> Dict[str, Any]:
|
||||
"""Get hardware constraints for this request"""
|
||||
vram_requirements = {
|
||||
WhisperModel.TINY: 1,
|
||||
WhisperModel.BASE: 1,
|
||||
WhisperModel.SMALL: 2,
|
||||
WhisperModel.MEDIUM: 5,
|
||||
WhisperModel.LARGE: 10,
|
||||
WhisperModel.LARGE_V2: 10,
|
||||
WhisperModel.LARGE_V3: 10,
|
||||
}
|
||||
|
||||
return {
|
||||
"models": ["whisper"],
|
||||
"min_vram_gb": vram_requirements[self.model],
|
||||
"gpu": "nvidia", # Whisper requires CUDA
|
||||
}
|
||||
|
||||
|
||||
# Stable Diffusion Service Schemas
|
||||
class SDModel(str, Enum):
|
||||
"""Supported Stable Diffusion models"""
|
||||
SD_1_5 = "stable-diffusion-1.5"
|
||||
SD_2_1 = "stable-diffusion-2.1"
|
||||
SDXL = "stable-diffusion-xl"
|
||||
SDXL_TURBO = "sdxl-turbo"
|
||||
SDXL_REFINER = "sdxl-refiner"
|
||||
|
||||
|
||||
class SDSize(str, Enum):
|
||||
"""Standard image sizes"""
|
||||
SQUARE_512 = "512x512"
|
||||
PORTRAIT_512 = "512x768"
|
||||
LANDSCAPE_512 = "768x512"
|
||||
SQUARE_768 = "768x768"
|
||||
PORTRAIT_768 = "768x1024"
|
||||
LANDSCAPE_768 = "1024x768"
|
||||
SQUARE_1024 = "1024x1024"
|
||||
PORTRAIT_1024 = "1024x1536"
|
||||
LANDSCAPE_1024 = "1536x1024"
|
||||
|
||||
|
||||
class StableDiffusionRequest(BaseModel):
|
||||
"""Stable Diffusion image generation request"""
|
||||
prompt: str = Field(..., min_length=1, max_length=1000, description="Text prompt")
|
||||
negative_prompt: Optional[str] = Field(None, max_length=1000, description="Negative prompt")
|
||||
model: SDModel = Field(SD_1_5, description="Model to use")
|
||||
size: SDSize = Field(SDSize.SQUARE_512, description="Image size")
|
||||
num_images: int = Field(1, ge=1, le=4, description="Number of images to generate")
|
||||
num_inference_steps: int = Field(20, ge=1, le=100, description="Number of inference steps")
|
||||
guidance_scale: float = Field(7.5, ge=1.0, le=20.0, description="Guidance scale")
|
||||
seed: Optional[Union[int, List[int]]] = Field(None, description="Random seed(s)")
|
||||
scheduler: str = Field("DPMSolverMultistepScheduler", description="Scheduler to use")
|
||||
enable_safety_checker: bool = Field(True, description="Enable safety checker")
|
||||
lora: Optional[str] = Field(None, description="LoRA model to use")
|
||||
lora_scale: float = Field(1.0, ge=0.0, le=2.0, description="LoRA strength")
|
||||
|
||||
@validator('seed')
|
||||
def validate_seed(cls, v):
|
||||
if v is not None and isinstance(v, list):
|
||||
if len(v) > 4:
|
||||
raise ValueError("Maximum 4 seeds allowed")
|
||||
return v
|
||||
|
||||
def get_constraints(self) -> Dict[str, Any]:
|
||||
"""Get hardware constraints for this request"""
|
||||
vram_requirements = {
|
||||
SDModel.SD_1_5: 4,
|
||||
SDModel.SD_2_1: 4,
|
||||
SDModel.SDXL: 8,
|
||||
SDModel.SDXL_TURBO: 8,
|
||||
SDModel.SDXL_REFINER: 8,
|
||||
}
|
||||
|
||||
size_map = {
|
||||
"512": 512,
|
||||
"768": 768,
|
||||
"1024": 1024,
|
||||
"1536": 1536,
|
||||
}
|
||||
|
||||
# Extract max dimension from size
|
||||
max_dim = max(size_map[s.split('x')[0]] for s in SDSize)
|
||||
|
||||
return {
|
||||
"models": ["stable-diffusion"],
|
||||
"min_vram_gb": vram_requirements[self.model],
|
||||
"gpu": "nvidia", # SD requires CUDA
|
||||
"cuda": "11.8", # Minimum CUDA version
|
||||
}
|
||||
|
||||
|
||||
# LLM Inference Service Schemas
|
||||
class LLMModel(str, Enum):
|
||||
"""Supported LLM models"""
|
||||
LLAMA_7B = "llama-7b"
|
||||
LLAMA_13B = "llama-13b"
|
||||
LLAMA_70B = "llama-70b"
|
||||
MISTRAL_7B = "mistral-7b"
|
||||
MIXTRAL_8X7B = "mixtral-8x7b"
|
||||
CODELLAMA_7B = "codellama-7b"
|
||||
CODELLAMA_13B = "codellama-13b"
|
||||
CODELLAMA_34B = "codellama-34b"
|
||||
|
||||
|
||||
class LLMRequest(BaseModel):
|
||||
"""LLM inference request"""
|
||||
model: LLMModel = Field(..., description="Model to use")
|
||||
prompt: str = Field(..., min_length=1, max_length=10000, description="Input prompt")
|
||||
max_tokens: int = Field(256, ge=1, le=4096, description="Maximum tokens to generate")
|
||||
temperature: float = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
|
||||
top_p: float = Field(0.9, ge=0.0, le=1.0, description="Top-p sampling")
|
||||
top_k: int = Field(40, ge=0, le=100, description="Top-k sampling")
|
||||
repetition_penalty: float = Field(1.1, ge=0.0, le=2.0, description="Repetition penalty")
|
||||
stop_sequences: Optional[List[str]] = Field(None, description="Stop sequences")
|
||||
stream: bool = Field(False, description="Stream response")
|
||||
|
||||
def get_constraints(self) -> Dict[str, Any]:
|
||||
"""Get hardware constraints for this request"""
|
||||
vram_requirements = {
|
||||
LLMModel.LLAMA_7B: 8,
|
||||
LLMModel.LLAMA_13B: 16,
|
||||
LLMModel.LLAMA_70B: 64,
|
||||
LLMModel.MISTRAL_7B: 8,
|
||||
LLMModel.MIXTRAL_8X7B: 48,
|
||||
LLMModel.CODELLAMA_7B: 8,
|
||||
LLMModel.CODELLAMA_13B: 16,
|
||||
LLMModel.CODELLAMA_34B: 32,
|
||||
}
|
||||
|
||||
return {
|
||||
"models": ["llm"],
|
||||
"min_vram_gb": vram_requirements[self.model],
|
||||
"gpu": "nvidia", # LLMs require CUDA
|
||||
"cuda": "11.8",
|
||||
}
|
||||
|
||||
|
||||
# FFmpeg Service Schemas
|
||||
class FFmpegCodec(str, Enum):
|
||||
"""Supported video codecs"""
|
||||
H264 = "h264"
|
||||
H265 = "h265"
|
||||
VP9 = "vp9"
|
||||
AV1 = "av1"
|
||||
|
||||
|
||||
class FFmpegPreset(str, Enum):
|
||||
"""Encoding presets"""
|
||||
ULTRAFAST = "ultrafast"
|
||||
SUPERFAST = "superfast"
|
||||
VERYFAST = "veryfast"
|
||||
FASTER = "faster"
|
||||
FAST = "fast"
|
||||
MEDIUM = "medium"
|
||||
SLOW = "slow"
|
||||
SLOWER = "slower"
|
||||
VERYSLOW = "veryslow"
|
||||
|
||||
|
||||
class FFmpegRequest(BaseModel):
|
||||
"""FFmpeg video processing request"""
|
||||
input_url: str = Field(..., description="URL of input video")
|
||||
output_format: str = Field("mp4", description="Output format")
|
||||
codec: FFmpegCodec = Field(FFmpegCodec.H264, description="Video codec")
|
||||
preset: FFmpegPreset = Field(FFmpegPreset.MEDIUM, description="Encoding preset")
|
||||
crf: int = Field(23, ge=0, le=51, description="Constant rate factor")
|
||||
resolution: Optional[str] = Field(None, regex=r"^\d+x\d+$", description="Output resolution (e.g., 1920x1080)")
|
||||
bitrate: Optional[str] = Field(None, regex=r"^\d+[kM]?$", description="Target bitrate")
|
||||
fps: Optional[int] = Field(None, ge=1, le=120, description="Output frame rate")
|
||||
audio_codec: str = Field("aac", description="Audio codec")
|
||||
audio_bitrate: str = Field("128k", description="Audio bitrate")
|
||||
custom_args: Optional[List[str]] = Field(None, description="Custom FFmpeg arguments")
|
||||
|
||||
def get_constraints(self) -> Dict[str, Any]:
|
||||
"""Get hardware constraints for this request"""
|
||||
# NVENC support for H.264/H.265
|
||||
if self.codec in [FFmpegCodec.H264, FFmpegCodec.H265]:
|
||||
return {
|
||||
"models": ["ffmpeg"],
|
||||
"gpu": "nvidia", # NVENC requires NVIDIA
|
||||
"min_vram_gb": 4,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"models": ["ffmpeg"],
|
||||
"gpu": "any", # CPU encoding possible
|
||||
}
|
||||
|
||||
|
||||
# Blender Service Schemas
|
||||
class BlenderEngine(str, Enum):
|
||||
"""Blender render engines"""
|
||||
CYCLES = "cycles"
|
||||
EEVEE = "eevee"
|
||||
EEVEE_NEXT = "eevee-next"
|
||||
|
||||
|
||||
class BlenderFormat(str, Enum):
|
||||
"""Output formats"""
|
||||
PNG = "png"
|
||||
JPG = "jpg"
|
||||
EXR = "exr"
|
||||
BMP = "bmp"
|
||||
TIFF = "tiff"
|
||||
|
||||
|
||||
class BlenderRequest(BaseModel):
|
||||
"""Blender rendering request"""
|
||||
blend_file_url: str = Field(..., description="URL of .blend file")
|
||||
engine: BlenderEngine = Field(BlenderEngine.CYCLES, description="Render engine")
|
||||
format: BlenderFormat = Field(BlenderFormat.PNG, description="Output format")
|
||||
resolution_x: int = Field(1920, ge=1, le=65536, description="Image width")
|
||||
resolution_y: int = Field(1080, ge=1, le=65536, description="Image height")
|
||||
resolution_percentage: int = Field(100, ge=1, le=100, description="Resolution scale")
|
||||
samples: int = Field(128, ge=1, le=10000, description="Samples (Cycles only)")
|
||||
frame_start: int = Field(1, ge=1, description="Start frame")
|
||||
frame_end: int = Field(1, ge=1, description="End frame")
|
||||
frame_step: int = Field(1, ge=1, description="Frame step")
|
||||
denoise: bool = Field(True, description="Enable denoising")
|
||||
transparent: bool = Field(False, description="Transparent background")
|
||||
custom_args: Optional[List[str]] = Field(None, description="Custom Blender arguments")
|
||||
|
||||
@validator('frame_end')
|
||||
def validate_frame_range(cls, v, values):
|
||||
if 'frame_start' in values and v < values['frame_start']:
|
||||
raise ValueError("frame_end must be >= frame_start")
|
||||
return v
|
||||
|
||||
def get_constraints(self) -> Dict[str, Any]:
|
||||
"""Get hardware constraints for this request"""
|
||||
# Calculate VRAM based on resolution and samples
|
||||
pixel_count = self.resolution_x * self.resolution_y
|
||||
samples_multiplier = 1 if self.engine == BlenderEngine.EEVEE else self.samples / 100
|
||||
|
||||
estimated_vram = int((pixel_count * samples_multiplier) / (1024 * 1024))
|
||||
|
||||
return {
|
||||
"models": ["blender"],
|
||||
"min_vram_gb": max(4, estimated_vram),
|
||||
"gpu": "nvidia" if self.engine == BlenderEngine.CYCLES else "any",
|
||||
}
|
||||
|
||||
|
||||
# Unified Service Request
|
||||
class ServiceRequest(BaseModel):
|
||||
"""Unified service request wrapper"""
|
||||
service_type: ServiceType = Field(..., description="Type of service")
|
||||
request_data: Dict[str, Any] = Field(..., description="Service-specific request data")
|
||||
|
||||
def get_service_request(self) -> Union[
|
||||
WhisperRequest,
|
||||
StableDiffusionRequest,
|
||||
LLMRequest,
|
||||
FFmpegRequest,
|
||||
BlenderRequest
|
||||
]:
|
||||
"""Parse and return typed service request"""
|
||||
service_classes = {
|
||||
ServiceType.WHISPER: WhisperRequest,
|
||||
ServiceType.STABLE_DIFFUSION: StableDiffusionRequest,
|
||||
ServiceType.LLM_INFERENCE: LLMRequest,
|
||||
ServiceType.FFMPEG: FFmpegRequest,
|
||||
ServiceType.BLENDER: BlenderRequest,
|
||||
}
|
||||
|
||||
service_class = service_classes[self.service_type]
|
||||
return service_class(**self.request_data)
|
||||
|
||||
|
||||
# Service Response Schemas
|
||||
class ServiceResponse(BaseModel):
|
||||
"""Base service response"""
|
||||
job_id: str = Field(..., description="Job ID")
|
||||
service_type: ServiceType = Field(..., description="Service type")
|
||||
status: str = Field(..., description="Job status")
|
||||
estimated_completion: Optional[str] = Field(None, description="Estimated completion time")
|
||||
|
||||
|
||||
class WhisperResponse(BaseModel):
|
||||
"""Whisper transcription response"""
|
||||
text: str = Field(..., description="Transcribed text")
|
||||
language: str = Field(..., description="Detected language")
|
||||
segments: Optional[List[Dict[str, Any]]] = Field(None, description="Transcription segments")
|
||||
|
||||
|
||||
class StableDiffusionResponse(BaseModel):
|
||||
"""Stable Diffusion image generation response"""
|
||||
images: List[str] = Field(..., description="Generated image URLs")
|
||||
parameters: Dict[str, Any] = Field(..., description="Generation parameters")
|
||||
nsfw_content_detected: List[bool] = Field(..., description="NSFW detection results")
|
||||
|
||||
|
||||
class LLMResponse(BaseModel):
|
||||
"""LLM inference response"""
|
||||
text: str = Field(..., description="Generated text")
|
||||
finish_reason: str = Field(..., description="Reason for generation stop")
|
||||
tokens_used: int = Field(..., description="Number of tokens used")
|
||||
|
||||
|
||||
class FFmpegResponse(BaseModel):
|
||||
"""FFmpeg processing response"""
|
||||
output_url: str = Field(..., description="URL of processed video")
|
||||
metadata: Dict[str, Any] = Field(..., description="Video metadata")
|
||||
duration: float = Field(..., description="Video duration")
|
||||
|
||||
|
||||
class BlenderResponse(BaseModel):
|
||||
"""Blender rendering response"""
|
||||
images: List[str] = Field(..., description="Rendered image URLs")
|
||||
metadata: Dict[str, Any] = Field(..., description="Render metadata")
|
||||
render_time: float = Field(..., description="Render time in seconds")
|
||||
428
apps/coordinator-api/src/app/repositories/confidential.py
Normal file
428
apps/coordinator-api/src/app/repositories/confidential.py
Normal file
@ -0,0 +1,428 @@
|
||||
"""
|
||||
Repository layer for confidential transactions
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
import json
|
||||
from base64 import b64encode, b64decode
|
||||
|
||||
from sqlalchemy import select, update, delete, and_, or_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from ..models.confidential import (
|
||||
ConfidentialTransactionDB,
|
||||
ParticipantKeyDB,
|
||||
ConfidentialAccessLogDB,
|
||||
KeyRotationLogDB,
|
||||
AuditAuthorizationDB
|
||||
)
|
||||
from ..models import (
|
||||
ConfidentialTransaction,
|
||||
KeyPair,
|
||||
ConfidentialAccessLog,
|
||||
KeyRotationLog,
|
||||
AuditAuthorization
|
||||
)
|
||||
from ..database import get_async_session
|
||||
|
||||
|
||||
class ConfidentialTransactionRepository:
|
||||
"""Repository for confidential transaction operations"""
|
||||
|
||||
async def create(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
transaction: ConfidentialTransaction
|
||||
) -> ConfidentialTransactionDB:
|
||||
"""Create a new confidential transaction"""
|
||||
db_transaction = ConfidentialTransactionDB(
|
||||
transaction_id=transaction.transaction_id,
|
||||
job_id=transaction.job_id,
|
||||
status=transaction.status,
|
||||
confidential=transaction.confidential,
|
||||
algorithm=transaction.algorithm,
|
||||
encrypted_data=b64decode(transaction.encrypted_data) if transaction.encrypted_data else None,
|
||||
encrypted_keys=transaction.encrypted_keys,
|
||||
participants=transaction.participants,
|
||||
access_policies=transaction.access_policies,
|
||||
created_by=transaction.participants[0] if transaction.participants else None
|
||||
)
|
||||
|
||||
session.add(db_transaction)
|
||||
await session.commit()
|
||||
await session.refresh(db_transaction)
|
||||
|
||||
return db_transaction
|
||||
|
||||
async def get_by_id(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
transaction_id: str
|
||||
) -> Optional[ConfidentialTransactionDB]:
|
||||
"""Get transaction by ID"""
|
||||
stmt = select(ConfidentialTransactionDB).where(
|
||||
ConfidentialTransactionDB.transaction_id == transaction_id
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_job_id(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
job_id: str
|
||||
) -> Optional[ConfidentialTransactionDB]:
|
||||
"""Get transaction by job ID"""
|
||||
stmt = select(ConfidentialTransactionDB).where(
|
||||
ConfidentialTransactionDB.job_id == job_id
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_by_participant(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
participant_id: str,
|
||||
limit: int = 100,
|
||||
offset: int = 0
|
||||
) -> List[ConfidentialTransactionDB]:
|
||||
"""List transactions for a participant"""
|
||||
stmt = select(ConfidentialTransactionDB).where(
|
||||
ConfidentialTransactionDB.participants.contains([participant_id])
|
||||
).offset(offset).limit(limit)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
async def update_status(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
transaction_id: str,
|
||||
status: str
|
||||
) -> bool:
|
||||
"""Update transaction status"""
|
||||
stmt = update(ConfidentialTransactionDB).where(
|
||||
ConfidentialTransactionDB.transaction_id == transaction_id
|
||||
).values(status=status)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
return result.rowcount > 0
|
||||
|
||||
async def delete(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
transaction_id: str
|
||||
) -> bool:
|
||||
"""Delete a transaction"""
|
||||
stmt = delete(ConfidentialTransactionDB).where(
|
||||
ConfidentialTransactionDB.transaction_id == transaction_id
|
||||
)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
return result.rowcount > 0
|
||||
|
||||
|
||||
class ParticipantKeyRepository:
|
||||
"""Repository for participant key operations"""
|
||||
|
||||
async def create(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
key_pair: KeyPair
|
||||
) -> ParticipantKeyDB:
|
||||
"""Store a new key pair"""
|
||||
# In production, private_key should be encrypted with master key
|
||||
db_key = ParticipantKeyDB(
|
||||
participant_id=key_pair.participant_id,
|
||||
encrypted_private_key=key_pair.private_key,
|
||||
public_key=key_pair.public_key,
|
||||
algorithm=key_pair.algorithm,
|
||||
version=key_pair.version,
|
||||
active=True
|
||||
)
|
||||
|
||||
session.add(db_key)
|
||||
await session.commit()
|
||||
await session.refresh(db_key)
|
||||
|
||||
return db_key
|
||||
|
||||
async def get_by_participant(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
participant_id: str,
|
||||
active_only: bool = True
|
||||
) -> Optional[ParticipantKeyDB]:
|
||||
"""Get key pair for participant"""
|
||||
stmt = select(ParticipantKeyDB).where(
|
||||
ParticipantKeyDB.participant_id == participant_id
|
||||
)
|
||||
|
||||
if active_only:
|
||||
stmt = stmt.where(ParticipantKeyDB.active == True)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_active(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
participant_id: str,
|
||||
active: bool,
|
||||
reason: Optional[str] = None
|
||||
) -> bool:
|
||||
"""Update key active status"""
|
||||
stmt = update(ParticipantKeyDB).where(
|
||||
ParticipantKeyDB.participant_id == participant_id
|
||||
).values(
|
||||
active=active,
|
||||
revoked_at=datetime.utcnow() if not active else None,
|
||||
revoke_reason=reason
|
||||
)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
return result.rowcount > 0
|
||||
|
||||
async def rotate(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
participant_id: str,
|
||||
new_key_pair: KeyPair
|
||||
) -> ParticipantKeyDB:
|
||||
"""Rotate to new key pair"""
|
||||
# Deactivate old key
|
||||
await self.update_active(session, participant_id, False, "rotation")
|
||||
|
||||
# Store new key
|
||||
return await self.create(session, new_key_pair)
|
||||
|
||||
async def list_active(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
limit: int = 100,
|
||||
offset: int = 0
|
||||
) -> List[ParticipantKeyDB]:
|
||||
"""List active keys"""
|
||||
stmt = select(ParticipantKeyDB).where(
|
||||
ParticipantKeyDB.active == True
|
||||
).offset(offset).limit(limit)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
class AccessLogRepository:
|
||||
"""Repository for access log operations"""
|
||||
|
||||
async def create(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
log: ConfidentialAccessLog
|
||||
) -> ConfidentialAccessLogDB:
|
||||
"""Create access log entry"""
|
||||
db_log = ConfidentialAccessLogDB(
|
||||
transaction_id=log.transaction_id,
|
||||
participant_id=log.participant_id,
|
||||
purpose=log.purpose,
|
||||
action=log.action,
|
||||
resource=log.resource,
|
||||
outcome=log.outcome,
|
||||
details=log.details,
|
||||
data_accessed=log.data_accessed,
|
||||
ip_address=log.ip_address,
|
||||
user_agent=log.user_agent,
|
||||
authorization_id=log.authorized_by,
|
||||
signature=log.signature
|
||||
)
|
||||
|
||||
session.add(db_log)
|
||||
await session.commit()
|
||||
await session.refresh(db_log)
|
||||
|
||||
return db_log
|
||||
|
||||
async def query(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
transaction_id: Optional[str] = None,
|
||||
participant_id: Optional[str] = None,
|
||||
purpose: Optional[str] = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0
|
||||
) -> List[ConfidentialAccessLogDB]:
|
||||
"""Query access logs"""
|
||||
stmt = select(ConfidentialAccessLogDB)
|
||||
|
||||
# Build filters
|
||||
filters = []
|
||||
if transaction_id:
|
||||
filters.append(ConfidentialAccessLogDB.transaction_id == transaction_id)
|
||||
if participant_id:
|
||||
filters.append(ConfidentialAccessLogDB.participant_id == participant_id)
|
||||
if purpose:
|
||||
filters.append(ConfidentialAccessLogDB.purpose == purpose)
|
||||
if start_time:
|
||||
filters.append(ConfidentialAccessLogDB.timestamp >= start_time)
|
||||
if end_time:
|
||||
filters.append(ConfidentialAccessLogDB.timestamp <= end_time)
|
||||
|
||||
if filters:
|
||||
stmt = stmt.where(and_(*filters))
|
||||
|
||||
# Order by timestamp descending
|
||||
stmt = stmt.order_by(ConfidentialAccessLogDB.timestamp.desc())
|
||||
stmt = stmt.offset(offset).limit(limit)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
async def count(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
transaction_id: Optional[str] = None,
|
||||
participant_id: Optional[str] = None,
|
||||
purpose: Optional[str] = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None
|
||||
) -> int:
|
||||
"""Count access logs matching criteria"""
|
||||
stmt = select(ConfidentialAccessLogDB)
|
||||
|
||||
# Build filters
|
||||
filters = []
|
||||
if transaction_id:
|
||||
filters.append(ConfidentialAccessLogDB.transaction_id == transaction_id)
|
||||
if participant_id:
|
||||
filters.append(ConfidentialAccessLogDB.participant_id == participant_id)
|
||||
if purpose:
|
||||
filters.append(ConfidentialAccessLogDB.purpose == purpose)
|
||||
if start_time:
|
||||
filters.append(ConfidentialAccessLogDB.timestamp >= start_time)
|
||||
if end_time:
|
||||
filters.append(ConfidentialAccessLogDB.timestamp <= end_time)
|
||||
|
||||
if filters:
|
||||
stmt = stmt.where(and_(*filters))
|
||||
|
||||
result = await session.execute(stmt)
|
||||
return len(result.all())
|
||||
|
||||
|
||||
class KeyRotationRepository:
|
||||
"""Repository for key rotation logs"""
|
||||
|
||||
async def create(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
log: KeyRotationLog
|
||||
) -> KeyRotationLogDB:
|
||||
"""Create key rotation log"""
|
||||
db_log = KeyRotationLogDB(
|
||||
participant_id=log.participant_id,
|
||||
old_version=log.old_version,
|
||||
new_version=log.new_version,
|
||||
rotated_at=log.rotated_at,
|
||||
reason=log.reason
|
||||
)
|
||||
|
||||
session.add(db_log)
|
||||
await session.commit()
|
||||
await session.refresh(db_log)
|
||||
|
||||
return db_log
|
||||
|
||||
async def list_by_participant(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
participant_id: str,
|
||||
limit: int = 50
|
||||
) -> List[KeyRotationLogDB]:
|
||||
"""List rotation logs for participant"""
|
||||
stmt = select(KeyRotationLogDB).where(
|
||||
KeyRotationLogDB.participant_id == participant_id
|
||||
).order_by(KeyRotationLogDB.rotated_at.desc()).limit(limit)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
class AuditAuthorizationRepository:
|
||||
"""Repository for audit authorizations"""
|
||||
|
||||
async def create(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
auth: AuditAuthorization
|
||||
) -> AuditAuthorizationDB:
|
||||
"""Create audit authorization"""
|
||||
db_auth = AuditAuthorizationDB(
|
||||
issuer=auth.issuer,
|
||||
subject=auth.subject,
|
||||
purpose=auth.purpose,
|
||||
created_at=auth.created_at,
|
||||
expires_at=auth.expires_at,
|
||||
signature=auth.signature,
|
||||
metadata=auth.__dict__
|
||||
)
|
||||
|
||||
session.add(db_auth)
|
||||
await session.commit()
|
||||
await session.refresh(db_auth)
|
||||
|
||||
return db_auth
|
||||
|
||||
async def get_valid(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
authorization_id: str
|
||||
) -> Optional[AuditAuthorizationDB]:
|
||||
"""Get valid authorization"""
|
||||
stmt = select(AuditAuthorizationDB).where(
|
||||
and_(
|
||||
AuditAuthorizationDB.id == authorization_id,
|
||||
AuditAuthorizationDB.active == True,
|
||||
AuditAuthorizationDB.expires_at > datetime.utcnow()
|
||||
)
|
||||
)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def revoke(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
authorization_id: str
|
||||
) -> bool:
|
||||
"""Revoke authorization"""
|
||||
stmt = update(AuditAuthorizationDB).where(
|
||||
AuditAuthorizationDB.id == authorization_id
|
||||
).values(active=False, revoked_at=datetime.utcnow())
|
||||
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
return result.rowcount > 0
|
||||
|
||||
async def cleanup_expired(
|
||||
self,
|
||||
session: AsyncSession
|
||||
) -> int:
|
||||
"""Clean up expired authorizations"""
|
||||
stmt = update(AuditAuthorizationDB).where(
|
||||
AuditAuthorizationDB.expires_at < datetime.utcnow()
|
||||
).values(active=False)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
return result.rowcount
|
||||
@ -5,5 +5,7 @@ from .miner import router as miner
|
||||
from .admin import router as admin
|
||||
from .marketplace import router as marketplace
|
||||
from .explorer import router as explorer
|
||||
from .services import router as services
|
||||
from .registry import router as registry
|
||||
|
||||
__all__ = ["client", "miner", "admin", "marketplace", "explorer"]
|
||||
__all__ = ["client", "miner", "admin", "marketplace", "explorer", "services", "registry"]
|
||||
|
||||
423
apps/coordinator-api/src/app/routers/confidential.py
Normal file
423
apps/coordinator-api/src/app/routers/confidential.py
Normal file
@ -0,0 +1,423 @@
|
||||
"""
|
||||
API endpoints for confidential transactions
|
||||
"""
|
||||
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, HTTPException, Depends, Request
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
import json
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from ..models import (
|
||||
ConfidentialTransaction,
|
||||
ConfidentialTransactionCreate,
|
||||
ConfidentialTransactionView,
|
||||
ConfidentialAccessRequest,
|
||||
ConfidentialAccessResponse,
|
||||
KeyRegistrationRequest,
|
||||
KeyRegistrationResponse,
|
||||
AccessLogQuery,
|
||||
AccessLogResponse
|
||||
)
|
||||
from ..services.encryption import EncryptionService, EncryptedData
|
||||
from ..services.key_management import KeyManager, KeyManagementError
|
||||
from ..services.access_control import AccessController
|
||||
from ..auth import get_api_key
|
||||
from ..logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Initialize router and security
|
||||
router = APIRouter(prefix="/confidential", tags=["confidential"])
|
||||
security = HTTPBearer()
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
# Global instances (in production, inject via DI)
|
||||
encryption_service: Optional[EncryptionService] = None
|
||||
key_manager: Optional[KeyManager] = None
|
||||
access_controller: Optional[AccessController] = None
|
||||
|
||||
|
||||
def get_encryption_service() -> EncryptionService:
|
||||
"""Get encryption service instance"""
|
||||
global encryption_service
|
||||
if encryption_service is None:
|
||||
# Initialize with key manager
|
||||
from ..services.key_management import FileKeyStorage
|
||||
key_storage = FileKeyStorage("/tmp/aitbc_keys")
|
||||
key_manager = KeyManager(key_storage)
|
||||
encryption_service = EncryptionService(key_manager)
|
||||
return encryption_service
|
||||
|
||||
|
||||
def get_key_manager() -> KeyManager:
|
||||
"""Get key manager instance"""
|
||||
global key_manager
|
||||
if key_manager is None:
|
||||
from ..services.key_management import FileKeyStorage
|
||||
key_storage = FileKeyStorage("/tmp/aitbc_keys")
|
||||
key_manager = KeyManager(key_storage)
|
||||
return key_manager
|
||||
|
||||
|
||||
def get_access_controller() -> AccessController:
|
||||
"""Get access controller instance"""
|
||||
global access_controller
|
||||
if access_controller is None:
|
||||
from ..services.access_control import PolicyStore
|
||||
policy_store = PolicyStore()
|
||||
access_controller = AccessController(policy_store)
|
||||
return access_controller
|
||||
|
||||
|
||||
@router.post("/transactions", response_model=ConfidentialTransactionView)
|
||||
async def create_confidential_transaction(
|
||||
request: ConfidentialTransactionCreate,
|
||||
api_key: str = Depends(get_api_key)
|
||||
):
|
||||
"""Create a new confidential transaction with optional encryption"""
|
||||
try:
|
||||
# Generate transaction ID
|
||||
transaction_id = f"ctx-{datetime.utcnow().timestamp()}"
|
||||
|
||||
# Create base transaction
|
||||
transaction = ConfidentialTransaction(
|
||||
transaction_id=transaction_id,
|
||||
job_id=request.job_id,
|
||||
timestamp=datetime.utcnow(),
|
||||
status="created",
|
||||
amount=request.amount,
|
||||
pricing=request.pricing,
|
||||
settlement_details=request.settlement_details,
|
||||
confidential=request.confidential,
|
||||
participants=request.participants,
|
||||
access_policies=request.access_policies
|
||||
)
|
||||
|
||||
# Encrypt sensitive data if requested
|
||||
if request.confidential and request.participants:
|
||||
# Prepare data for encryption
|
||||
sensitive_data = {
|
||||
"amount": request.amount,
|
||||
"pricing": request.pricing,
|
||||
"settlement_details": request.settlement_details
|
||||
}
|
||||
|
||||
# Remove None values
|
||||
sensitive_data = {k: v for k, v in sensitive_data.items() if v is not None}
|
||||
|
||||
if sensitive_data:
|
||||
# Encrypt data
|
||||
enc_service = get_encryption_service()
|
||||
encrypted = enc_service.encrypt(
|
||||
data=sensitive_data,
|
||||
participants=request.participants,
|
||||
include_audit=True
|
||||
)
|
||||
|
||||
# Update transaction with encrypted data
|
||||
transaction.encrypted_data = encrypted.to_dict()["ciphertext"]
|
||||
transaction.encrypted_keys = encrypted.to_dict()["encrypted_keys"]
|
||||
transaction.algorithm = encrypted.algorithm
|
||||
|
||||
# Clear plaintext fields
|
||||
transaction.amount = None
|
||||
transaction.pricing = None
|
||||
transaction.settlement_details = None
|
||||
|
||||
# Store transaction (in production, save to database)
|
||||
logger.info(f"Created confidential transaction: {transaction_id}")
|
||||
|
||||
# Return view
|
||||
return ConfidentialTransactionView(
|
||||
transaction_id=transaction.transaction_id,
|
||||
job_id=transaction.job_id,
|
||||
timestamp=transaction.timestamp,
|
||||
status=transaction.status,
|
||||
amount=transaction.amount, # Will be None if encrypted
|
||||
pricing=transaction.pricing,
|
||||
settlement_details=transaction.settlement_details,
|
||||
confidential=transaction.confidential,
|
||||
participants=transaction.participants,
|
||||
has_encrypted_data=transaction.encrypted_data is not None
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create confidential transaction: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/transactions/{transaction_id}", response_model=ConfidentialTransactionView)
|
||||
async def get_confidential_transaction(
|
||||
transaction_id: str,
|
||||
api_key: str = Depends(get_api_key)
|
||||
):
|
||||
"""Get confidential transaction metadata (without decrypting sensitive data)"""
|
||||
try:
|
||||
# Retrieve transaction (in production, query from database)
|
||||
# For now, return error as we don't have storage
|
||||
raise HTTPException(status_code=404, detail="Transaction not found")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get transaction {transaction_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/transactions/{transaction_id}/access", response_model=ConfidentialAccessResponse)
|
||||
@limiter.limit("10/minute") # Rate limit decryption requests
|
||||
async def access_confidential_data(
|
||||
request: ConfidentialAccessRequest,
|
||||
transaction_id: str,
|
||||
api_key: str = Depends(get_api_key)
|
||||
):
|
||||
"""Request access to decrypt confidential transaction data"""
|
||||
try:
|
||||
# Validate request
|
||||
if request.transaction_id != transaction_id:
|
||||
raise HTTPException(status_code=400, detail="Transaction ID mismatch")
|
||||
|
||||
# Get transaction (in production, query from database)
|
||||
# For now, create mock transaction
|
||||
transaction = ConfidentialTransaction(
|
||||
transaction_id=transaction_id,
|
||||
job_id="test-job",
|
||||
timestamp=datetime.utcnow(),
|
||||
status="completed",
|
||||
confidential=True,
|
||||
participants=["client-456", "miner-789"]
|
||||
)
|
||||
|
||||
if not transaction.confidential:
|
||||
raise HTTPException(status_code=400, detail="Transaction is not confidential")
|
||||
|
||||
# Check access authorization
|
||||
acc_controller = get_access_controller()
|
||||
if not acc_controller.verify_access(request):
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
# Decrypt data
|
||||
enc_service = get_encryption_service()
|
||||
|
||||
# Reconstruct encrypted data
|
||||
if not transaction.encrypted_data or not transaction.encrypted_keys:
|
||||
raise HTTPException(status_code=404, detail="No encrypted data found")
|
||||
|
||||
encrypted_data = EncryptedData.from_dict({
|
||||
"ciphertext": transaction.encrypted_data,
|
||||
"encrypted_keys": transaction.encrypted_keys,
|
||||
"algorithm": transaction.algorithm or "AES-256-GCM+X25519"
|
||||
})
|
||||
|
||||
# Decrypt for requester
|
||||
try:
|
||||
decrypted_data = enc_service.decrypt(
|
||||
encrypted_data=encrypted_data,
|
||||
participant_id=request.requester,
|
||||
purpose=request.purpose
|
||||
)
|
||||
|
||||
return ConfidentialAccessResponse(
|
||||
success=True,
|
||||
data=decrypted_data,
|
||||
access_id=f"access-{datetime.utcnow().timestamp()}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Decryption failed: {e}")
|
||||
return ConfidentialAccessResponse(
|
||||
success=False,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to access confidential data: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/transactions/{transaction_id}/audit", response_model=ConfidentialAccessResponse)
|
||||
async def audit_access_confidential_data(
|
||||
transaction_id: str,
|
||||
authorization: str,
|
||||
purpose: str = "compliance",
|
||||
api_key: str = Depends(get_api_key)
|
||||
):
|
||||
"""Audit access to confidential transaction data"""
|
||||
try:
|
||||
# Get transaction
|
||||
transaction = ConfidentialTransaction(
|
||||
transaction_id=transaction_id,
|
||||
job_id="test-job",
|
||||
timestamp=datetime.utcnow(),
|
||||
status="completed",
|
||||
confidential=True
|
||||
)
|
||||
|
||||
if not transaction.confidential:
|
||||
raise HTTPException(status_code=400, detail="Transaction is not confidential")
|
||||
|
||||
# Decrypt with audit key
|
||||
enc_service = get_encryption_service()
|
||||
|
||||
if not transaction.encrypted_data or not transaction.encrypted_keys:
|
||||
raise HTTPException(status_code=404, detail="No encrypted data found")
|
||||
|
||||
encrypted_data = EncryptedData.from_dict({
|
||||
"ciphertext": transaction.encrypted_data,
|
||||
"encrypted_keys": transaction.encrypted_keys,
|
||||
"algorithm": transaction.algorithm or "AES-256-GCM+X25519"
|
||||
})
|
||||
|
||||
# Decrypt for audit
|
||||
try:
|
||||
decrypted_data = enc_service.audit_decrypt(
|
||||
encrypted_data=encrypted_data,
|
||||
audit_authorization=authorization,
|
||||
purpose=purpose
|
||||
)
|
||||
|
||||
return ConfidentialAccessResponse(
|
||||
success=True,
|
||||
data=decrypted_data,
|
||||
access_id=f"audit-{datetime.utcnow().timestamp()}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Audit decryption failed: {e}")
|
||||
return ConfidentialAccessResponse(
|
||||
success=False,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed audit access: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/keys/register", response_model=KeyRegistrationResponse)
|
||||
async def register_encryption_key(
|
||||
request: KeyRegistrationRequest,
|
||||
api_key: str = Depends(get_api_key)
|
||||
):
|
||||
"""Register public key for confidential transactions"""
|
||||
try:
|
||||
# Get key manager
|
||||
km = get_key_manager()
|
||||
|
||||
# Check if participant already has keys
|
||||
try:
|
||||
existing_key = km.get_public_key(request.participant_id)
|
||||
if existing_key:
|
||||
# Key exists, return version
|
||||
return KeyRegistrationResponse(
|
||||
success=True,
|
||||
participant_id=request.participant_id,
|
||||
key_version=1, # Would get from storage
|
||||
registered_at=datetime.utcnow(),
|
||||
error=None
|
||||
)
|
||||
except:
|
||||
pass # Key doesn't exist, continue
|
||||
|
||||
# Generate new key pair
|
||||
key_pair = await km.generate_key_pair(request.participant_id)
|
||||
|
||||
return KeyRegistrationResponse(
|
||||
success=True,
|
||||
participant_id=request.participant_id,
|
||||
key_version=key_pair.version,
|
||||
registered_at=key_pair.created_at,
|
||||
error=None
|
||||
)
|
||||
|
||||
except KeyManagementError as e:
|
||||
logger.error(f"Key registration failed: {e}")
|
||||
return KeyRegistrationResponse(
|
||||
success=False,
|
||||
participant_id=request.participant_id,
|
||||
key_version=0,
|
||||
registered_at=datetime.utcnow(),
|
||||
error=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register key: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/keys/rotate")
|
||||
async def rotate_encryption_key(
|
||||
participant_id: str,
|
||||
api_key: str = Depends(get_api_key)
|
||||
):
|
||||
"""Rotate encryption keys for participant"""
|
||||
try:
|
||||
km = get_key_manager()
|
||||
|
||||
# Rotate keys
|
||||
new_key_pair = await km.rotate_keys(participant_id)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"participant_id": participant_id,
|
||||
"new_version": new_key_pair.version,
|
||||
"rotated_at": new_key_pair.created_at
|
||||
}
|
||||
|
||||
except KeyManagementError as e:
|
||||
logger.error(f"Key rotation failed: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to rotate keys: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/access/logs", response_model=AccessLogResponse)
|
||||
async def get_access_logs(
|
||||
query: AccessLogQuery = Depends(),
|
||||
api_key: str = Depends(get_api_key)
|
||||
):
|
||||
"""Get access logs for confidential transactions"""
|
||||
try:
|
||||
# Query logs (in production, query from database)
|
||||
# For now, return empty response
|
||||
return AccessLogResponse(
|
||||
logs=[],
|
||||
total_count=0,
|
||||
has_more=False
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get access logs: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def get_confidential_status(
|
||||
api_key: str = Depends(get_api_key)
|
||||
):
|
||||
"""Get status of confidential transaction system"""
|
||||
try:
|
||||
km = get_key_manager()
|
||||
enc_service = get_encryption_service()
|
||||
|
||||
# Get system status
|
||||
participants = await km.list_participants()
|
||||
|
||||
return {
|
||||
"enabled": True,
|
||||
"algorithm": "AES-256-GCM+X25519",
|
||||
"participants_count": len(participants),
|
||||
"transactions_count": 0, # Would query from database
|
||||
"audit_enabled": True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get status: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@ -6,6 +6,7 @@ from fastapi import status as http_status
|
||||
from ..models import MarketplaceBidRequest, MarketplaceOfferView, MarketplaceStatsView
|
||||
from ..services import MarketplaceService
|
||||
from ..storage import SessionDep
|
||||
from ..metrics import marketplace_requests_total, marketplace_errors_total
|
||||
|
||||
router = APIRouter(tags=["marketplace"])
|
||||
|
||||
@ -26,11 +27,16 @@ async def list_marketplace_offers(
|
||||
limit: int = Query(default=100, ge=1, le=500),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
) -> list[MarketplaceOfferView]:
|
||||
marketplace_requests_total.labels(endpoint="/marketplace/offers", method="GET").inc()
|
||||
service = _get_service(session)
|
||||
try:
|
||||
return service.list_offers(status=status_filter, limit=limit, offset=offset)
|
||||
except ValueError:
|
||||
marketplace_errors_total.labels(endpoint="/marketplace/offers", method="GET", error_type="invalid_request").inc()
|
||||
raise HTTPException(status_code=http_status.HTTP_400_BAD_REQUEST, detail="invalid status filter") from None
|
||||
except Exception:
|
||||
marketplace_errors_total.labels(endpoint="/marketplace/offers", method="GET", error_type="internal").inc()
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
@ -39,8 +45,13 @@ async def list_marketplace_offers(
|
||||
summary="Get marketplace summary statistics",
|
||||
)
|
||||
async def get_marketplace_stats(*, session: SessionDep) -> MarketplaceStatsView:
|
||||
marketplace_requests_total.labels(endpoint="/marketplace/stats", method="GET").inc()
|
||||
service = _get_service(session)
|
||||
return service.get_stats()
|
||||
try:
|
||||
return service.get_stats()
|
||||
except Exception:
|
||||
marketplace_errors_total.labels(endpoint="/marketplace/stats", method="GET", error_type="internal").inc()
|
||||
raise
|
||||
|
||||
|
||||
@router.post(
|
||||
@ -52,6 +63,14 @@ async def submit_marketplace_bid(
|
||||
payload: MarketplaceBidRequest,
|
||||
session: SessionDep,
|
||||
) -> dict[str, str]:
|
||||
marketplace_requests_total.labels(endpoint="/marketplace/bids", method="POST").inc()
|
||||
service = _get_service(session)
|
||||
bid = service.create_bid(payload)
|
||||
return {"id": bid.id}
|
||||
try:
|
||||
bid = service.create_bid(payload)
|
||||
return {"id": bid.id}
|
||||
except ValueError:
|
||||
marketplace_errors_total.labels(endpoint="/marketplace/bids", method="POST", error_type="invalid_request").inc()
|
||||
raise HTTPException(status_code=http_status.HTTP_400_BAD_REQUEST, detail="invalid bid data") from None
|
||||
except Exception:
|
||||
marketplace_errors_total.labels(endpoint="/marketplace/bids", method="POST", error_type="internal").inc()
|
||||
raise
|
||||
|
||||
303
apps/coordinator-api/src/app/routers/registry.py
Normal file
303
apps/coordinator-api/src/app/routers/registry.py
Normal file
@ -0,0 +1,303 @@
|
||||
"""
|
||||
Service registry router for dynamic service management
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Any, Optional
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from ..models.registry import (
|
||||
ServiceRegistry,
|
||||
ServiceDefinition,
|
||||
ServiceCategory
|
||||
)
|
||||
from ..models.registry_media import MEDIA_PROCESSING_SERVICES
|
||||
from ..models.registry_scientific import SCIENTIFIC_COMPUTING_SERVICES
|
||||
from ..models.registry_data import DATA_ANALYTICS_SERVICES
|
||||
from ..models.registry_gaming import GAMING_SERVICES
|
||||
from ..models.registry_devtools import DEVTOOLS_SERVICES
|
||||
from ..models.registry import AI_ML_SERVICES
|
||||
|
||||
router = APIRouter(prefix="/registry", tags=["service-registry"])
|
||||
|
||||
# Initialize service registry with all services
|
||||
def create_service_registry() -> ServiceRegistry:
|
||||
"""Create and populate the service registry"""
|
||||
all_services = {}
|
||||
|
||||
# Add all service categories
|
||||
all_services.update(AI_ML_SERVICES)
|
||||
all_services.update(MEDIA_PROCESSING_SERVICES)
|
||||
all_services.update(SCIENTIFIC_COMPUTING_SERVICES)
|
||||
all_services.update(DATA_ANALYTICS_SERVICES)
|
||||
all_services.update(GAMING_SERVICES)
|
||||
all_services.update(DEVTOOLS_SERVICES)
|
||||
|
||||
return ServiceRegistry(
|
||||
version="1.0.0",
|
||||
services=all_services
|
||||
)
|
||||
|
||||
# Global registry instance
|
||||
service_registry = create_service_registry()
|
||||
|
||||
|
||||
@router.get("/", response_model=ServiceRegistry)
|
||||
async def get_registry() -> ServiceRegistry:
|
||||
"""Get the complete service registry"""
|
||||
return service_registry
|
||||
|
||||
|
||||
@router.get("/services", response_model=List[ServiceDefinition])
|
||||
async def list_services(
|
||||
category: Optional[ServiceCategory] = None,
|
||||
search: Optional[str] = None
|
||||
) -> List[ServiceDefinition]:
|
||||
"""List all available services with optional filtering"""
|
||||
services = list(service_registry.services.values())
|
||||
|
||||
# Filter by category
|
||||
if category:
|
||||
services = [s for s in services if s.category == category]
|
||||
|
||||
# Search by name, description, or tags
|
||||
if search:
|
||||
search = search.lower()
|
||||
services = [
|
||||
s for s in services
|
||||
if (search in s.name.lower() or
|
||||
search in s.description.lower() or
|
||||
any(search in tag.lower() for tag in s.tags))
|
||||
]
|
||||
|
||||
return services
|
||||
|
||||
|
||||
@router.get("/services/{service_id}", response_model=ServiceDefinition)
|
||||
async def get_service(service_id: str) -> ServiceDefinition:
|
||||
"""Get a specific service definition"""
|
||||
service = service_registry.get_service(service_id)
|
||||
if not service:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Service {service_id} not found"
|
||||
)
|
||||
return service
|
||||
|
||||
|
||||
@router.get("/categories", response_model=List[Dict[str, Any]])
|
||||
async def list_categories() -> List[Dict[str, Any]]:
|
||||
"""List all service categories with counts"""
|
||||
category_counts = {}
|
||||
for service in service_registry.services.values():
|
||||
category = service.category.value
|
||||
if category not in category_counts:
|
||||
category_counts[category] = 0
|
||||
category_counts[category] += 1
|
||||
|
||||
return [
|
||||
{"category": cat, "count": count}
|
||||
for cat, count in category_counts.items()
|
||||
]
|
||||
|
||||
|
||||
@router.get("/categories/{category}", response_model=List[ServiceDefinition])
|
||||
async def get_services_by_category(category: ServiceCategory) -> List[ServiceDefinition]:
|
||||
"""Get all services in a specific category"""
|
||||
return service_registry.get_services_by_category(category)
|
||||
|
||||
|
||||
@router.get("/services/{service_id}/schema")
|
||||
async def get_service_schema(service_id: str) -> Dict[str, Any]:
|
||||
"""Get JSON schema for a service's input parameters"""
|
||||
service = service_registry.get_service(service_id)
|
||||
if not service:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Service {service_id} not found"
|
||||
)
|
||||
|
||||
# Convert input parameters to JSON schema
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for param in service.input_parameters:
|
||||
prop = {
|
||||
"type": param.type.value,
|
||||
"description": param.description
|
||||
}
|
||||
|
||||
if param.default is not None:
|
||||
prop["default"] = param.default
|
||||
if param.min_value is not None:
|
||||
prop["minimum"] = param.min_value
|
||||
if param.max_value is not None:
|
||||
prop["maximum"] = param.max_value
|
||||
if param.options:
|
||||
prop["enum"] = param.options
|
||||
if param.validation:
|
||||
prop.update(param.validation)
|
||||
|
||||
properties[param.name] = prop
|
||||
if param.required:
|
||||
required.append(param.name)
|
||||
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required
|
||||
}
|
||||
|
||||
|
||||
@router.get("/services/{service_id}/requirements")
|
||||
async def get_service_requirements(service_id: str) -> Dict[str, Any]:
|
||||
"""Get hardware requirements for a service"""
|
||||
service = service_registry.get_service(service_id)
|
||||
if not service:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Service {service_id} not found"
|
||||
)
|
||||
|
||||
return {
|
||||
"requirements": [
|
||||
{
|
||||
"component": req.component,
|
||||
"minimum": req.min_value,
|
||||
"recommended": req.recommended,
|
||||
"unit": req.unit
|
||||
}
|
||||
for req in service.requirements
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/services/{service_id}/pricing")
|
||||
async def get_service_pricing(service_id: str) -> Dict[str, Any]:
|
||||
"""Get pricing information for a service"""
|
||||
service = service_registry.get_service(service_id)
|
||||
if not service:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Service {service_id} not found"
|
||||
)
|
||||
|
||||
return {
|
||||
"pricing": [
|
||||
{
|
||||
"tier": tier.name,
|
||||
"model": tier.model.value,
|
||||
"unit_price": tier.unit_price,
|
||||
"min_charge": tier.min_charge,
|
||||
"currency": tier.currency,
|
||||
"description": tier.description
|
||||
}
|
||||
for tier in service.pricing
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.post("/services/validate")
|
||||
async def validate_service_request(
|
||||
service_id: str,
|
||||
request_data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Validate a service request against the service schema"""
|
||||
service = service_registry.get_service(service_id)
|
||||
if not service:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Service {service_id} not found"
|
||||
)
|
||||
|
||||
# Validate request data
|
||||
validation_result = {
|
||||
"valid": True,
|
||||
"errors": [],
|
||||
"warnings": []
|
||||
}
|
||||
|
||||
# Check required parameters
|
||||
provided_params = set(request_data.keys())
|
||||
required_params = {p.name for p in service.input_parameters if p.required}
|
||||
missing_params = required_params - provided_params
|
||||
|
||||
if missing_params:
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].extend([
|
||||
f"Missing required parameter: {param}"
|
||||
for param in missing_params
|
||||
])
|
||||
|
||||
# Validate parameter types and constraints
|
||||
for param in service.input_parameters:
|
||||
if param.name in request_data:
|
||||
value = request_data[param.name]
|
||||
|
||||
# Type validation (simplified)
|
||||
if param.type == "integer" and not isinstance(value, int):
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be an integer"
|
||||
)
|
||||
elif param.type == "float" and not isinstance(value, (int, float)):
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be a number"
|
||||
)
|
||||
elif param.type == "boolean" and not isinstance(value, bool):
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be a boolean"
|
||||
)
|
||||
elif param.type == "array" and not isinstance(value, list):
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be an array"
|
||||
)
|
||||
|
||||
# Value constraints
|
||||
if param.min_value is not None and value < param.min_value:
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be >= {param.min_value}"
|
||||
)
|
||||
|
||||
if param.max_value is not None and value > param.max_value:
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be <= {param.max_value}"
|
||||
)
|
||||
|
||||
# Enum options
|
||||
if param.options and value not in param.options:
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be one of: {', '.join(param.options)}"
|
||||
)
|
||||
|
||||
return validation_result
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_registry_stats() -> Dict[str, Any]:
|
||||
"""Get registry statistics"""
|
||||
total_services = len(service_registry.services)
|
||||
category_counts = {}
|
||||
|
||||
for service in service_registry.services.values():
|
||||
category = service.category.value
|
||||
if category not in category_counts:
|
||||
category_counts[category] = 0
|
||||
category_counts[category] += 1
|
||||
|
||||
# Count unique pricing models
|
||||
pricing_models = set()
|
||||
for service in service_registry.services.values():
|
||||
for tier in service.pricing:
|
||||
pricing_models.add(tier.model.value)
|
||||
|
||||
return {
|
||||
"total_services": total_services,
|
||||
"categories": category_counts,
|
||||
"pricing_models": list(pricing_models),
|
||||
"last_updated": service_registry.last_updated.isoformat()
|
||||
}
|
||||
612
apps/coordinator-api/src/app/routers/services.py
Normal file
612
apps/coordinator-api/src/app/routers/services.py
Normal file
@ -0,0 +1,612 @@
|
||||
"""
|
||||
Services router for specific GPU workloads
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Union
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Header
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from ..deps import require_client_key
|
||||
from ..models import JobCreate, JobView, JobResult
|
||||
from ..models.services import (
|
||||
ServiceType,
|
||||
ServiceRequest,
|
||||
ServiceResponse,
|
||||
WhisperRequest,
|
||||
StableDiffusionRequest,
|
||||
LLMRequest,
|
||||
FFmpegRequest,
|
||||
BlenderRequest,
|
||||
)
|
||||
from ..models.registry import ServiceRegistry, service_registry
|
||||
from ..services import JobService
|
||||
from ..storage import SessionDep
|
||||
|
||||
router = APIRouter(tags=["services"])
|
||||
|
||||
|
||||
@router.post(
|
||||
"/services/{service_type}",
|
||||
response_model=ServiceResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Submit a service-specific job",
|
||||
deprecated=True
|
||||
)
|
||||
async def submit_service_job(
|
||||
service_type: ServiceType,
|
||||
request_data: Dict[str, Any],
|
||||
session: SessionDep,
|
||||
client_id: str = Depends(require_client_key()),
|
||||
user_agent: str = Header(None),
|
||||
) -> ServiceResponse:
|
||||
"""Submit a job for a specific service type
|
||||
|
||||
DEPRECATED: Use /v1/registry/services/{service_id} endpoint instead.
|
||||
This endpoint will be removed in version 2.0.
|
||||
"""
|
||||
|
||||
# Add deprecation warning header
|
||||
from fastapi import Response
|
||||
response = Response()
|
||||
response.headers["X-Deprecated"] = "true"
|
||||
response.headers["X-Deprecation-Message"] = "Use /v1/registry/services/{service_id} instead"
|
||||
|
||||
# Check if service exists in registry
|
||||
service = service_registry.get_service(service_type.value)
|
||||
if not service:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Service {service_type} not found"
|
||||
)
|
||||
|
||||
# Validate request against service schema
|
||||
validation_result = await validate_service_request(service_type.value, request_data)
|
||||
if not validation_result["valid"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid request: {', '.join(validation_result['errors'])}"
|
||||
)
|
||||
|
||||
# Create service request wrapper
|
||||
service_request = ServiceRequest(
|
||||
service_type=service_type,
|
||||
request_data=request_data
|
||||
)
|
||||
|
||||
# Validate and parse service-specific request
|
||||
try:
|
||||
typed_request = service_request.get_service_request()
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid request for {service_type}: {str(e)}"
|
||||
)
|
||||
|
||||
# Get constraints from service request
|
||||
constraints = typed_request.get_constraints()
|
||||
|
||||
# Create job with service-specific payload
|
||||
job_payload = {
|
||||
"service_type": service_type.value,
|
||||
"service_request": request_data,
|
||||
}
|
||||
|
||||
job_create = JobCreate(
|
||||
payload=job_payload,
|
||||
constraints=constraints,
|
||||
ttl_seconds=900 # Default 15 minutes
|
||||
)
|
||||
|
||||
# Submit job
|
||||
service = JobService(session)
|
||||
job = service.create_job(client_id, job_create)
|
||||
|
||||
return ServiceResponse(
|
||||
job_id=job.job_id,
|
||||
service_type=service_type,
|
||||
status=job.state.value,
|
||||
estimated_completion=job.expires_at.isoformat()
|
||||
)
|
||||
|
||||
|
||||
# Whisper endpoints
|
||||
@router.post(
|
||||
"/services/whisper/transcribe",
|
||||
response_model=ServiceResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Transcribe audio using Whisper"
|
||||
)
|
||||
async def whisper_transcribe(
|
||||
request: WhisperRequest,
|
||||
session: SessionDep,
|
||||
client_id: str = Depends(require_client_key()),
|
||||
) -> ServiceResponse:
|
||||
"""Transcribe audio file using Whisper"""
|
||||
|
||||
job_payload = {
|
||||
"service_type": ServiceType.WHISPER.value,
|
||||
"service_request": request.dict(),
|
||||
}
|
||||
|
||||
job_create = JobCreate(
|
||||
payload=job_payload,
|
||||
constraints=request.get_constraints(),
|
||||
ttl_seconds=900
|
||||
)
|
||||
|
||||
service = JobService(session)
|
||||
job = service.create_job(client_id, job_create)
|
||||
|
||||
return ServiceResponse(
|
||||
job_id=job.job_id,
|
||||
service_type=ServiceType.WHISPER,
|
||||
status=job.state.value,
|
||||
estimated_completion=job.expires_at.isoformat()
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/services/whisper/translate",
|
||||
response_model=ServiceResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Translate audio using Whisper"
|
||||
)
|
||||
async def whisper_translate(
|
||||
request: WhisperRequest,
|
||||
session: SessionDep,
|
||||
client_id: str = Depends(require_client_key()),
|
||||
) -> ServiceResponse:
|
||||
"""Translate audio file using Whisper"""
|
||||
# Force task to be translate
|
||||
request.task = "translate"
|
||||
|
||||
job_payload = {
|
||||
"service_type": ServiceType.WHISPER.value,
|
||||
"service_request": request.dict(),
|
||||
}
|
||||
|
||||
job_create = JobCreate(
|
||||
payload=job_payload,
|
||||
constraints=request.get_constraints(),
|
||||
ttl_seconds=900
|
||||
)
|
||||
|
||||
service = JobService(session)
|
||||
job = service.create_job(client_id, job_create)
|
||||
|
||||
return ServiceResponse(
|
||||
job_id=job.job_id,
|
||||
service_type=ServiceType.WHISPER,
|
||||
status=job.state.value,
|
||||
estimated_completion=job.expires_at.isoformat()
|
||||
)
|
||||
|
||||
|
||||
# Stable Diffusion endpoints
|
||||
@router.post(
|
||||
"/services/stable-diffusion/generate",
|
||||
response_model=ServiceResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Generate images using Stable Diffusion"
|
||||
)
|
||||
async def stable_diffusion_generate(
|
||||
request: StableDiffusionRequest,
|
||||
session: SessionDep,
|
||||
client_id: str = Depends(require_client_key()),
|
||||
) -> ServiceResponse:
|
||||
"""Generate images using Stable Diffusion"""
|
||||
|
||||
job_payload = {
|
||||
"service_type": ServiceType.STABLE_DIFFUSION.value,
|
||||
"service_request": request.dict(),
|
||||
}
|
||||
|
||||
job_create = JobCreate(
|
||||
payload=job_payload,
|
||||
constraints=request.get_constraints(),
|
||||
ttl_seconds=600 # 10 minutes for image generation
|
||||
)
|
||||
|
||||
service = JobService(session)
|
||||
job = service.create_job(client_id, job_create)
|
||||
|
||||
return ServiceResponse(
|
||||
job_id=job.job_id,
|
||||
service_type=ServiceType.STABLE_DIFFUSION,
|
||||
status=job.state.value,
|
||||
estimated_completion=job.expires_at.isoformat()
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/services/stable-diffusion/img2img",
|
||||
response_model=ServiceResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Image-to-image generation"
|
||||
)
|
||||
async def stable_diffusion_img2img(
|
||||
request: StableDiffusionRequest,
|
||||
session: SessionDep,
|
||||
client_id: str = Depends(require_client_key()),
|
||||
) -> ServiceResponse:
|
||||
"""Image-to-image generation using Stable Diffusion"""
|
||||
# Add img2img specific parameters
|
||||
request_data = request.dict()
|
||||
request_data["mode"] = "img2img"
|
||||
|
||||
job_payload = {
|
||||
"service_type": ServiceType.STABLE_DIFFUSION.value,
|
||||
"service_request": request_data,
|
||||
}
|
||||
|
||||
job_create = JobCreate(
|
||||
payload=job_payload,
|
||||
constraints=request.get_constraints(),
|
||||
ttl_seconds=600
|
||||
)
|
||||
|
||||
service = JobService(session)
|
||||
job = service.create_job(client_id, job_create)
|
||||
|
||||
return ServiceResponse(
|
||||
job_id=job.job_id,
|
||||
service_type=ServiceType.STABLE_DIFFUSION,
|
||||
status=job.state.value,
|
||||
estimated_completion=job.expires_at.isoformat()
|
||||
)
|
||||
|
||||
|
||||
# LLM Inference endpoints
|
||||
@router.post(
|
||||
"/services/llm/inference",
|
||||
response_model=ServiceResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Run LLM inference"
|
||||
)
|
||||
async def llm_inference(
|
||||
request: LLMRequest,
|
||||
session: SessionDep,
|
||||
client_id: str = Depends(require_client_key()),
|
||||
) -> ServiceResponse:
|
||||
"""Run inference on a language model"""
|
||||
|
||||
job_payload = {
|
||||
"service_type": ServiceType.LLM_INFERENCE.value,
|
||||
"service_request": request.dict(),
|
||||
}
|
||||
|
||||
job_create = JobCreate(
|
||||
payload=job_payload,
|
||||
constraints=request.get_constraints(),
|
||||
ttl_seconds=300 # 5 minutes for text generation
|
||||
)
|
||||
|
||||
service = JobService(session)
|
||||
job = service.create_job(client_id, job_create)
|
||||
|
||||
return ServiceResponse(
|
||||
job_id=job.job_id,
|
||||
service_type=ServiceType.LLM_INFERENCE,
|
||||
status=job.state.value,
|
||||
estimated_completion=job.expires_at.isoformat()
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/services/llm/stream",
|
||||
summary="Stream LLM inference"
|
||||
)
|
||||
async def llm_stream(
|
||||
request: LLMRequest,
|
||||
session: SessionDep,
|
||||
client_id: str = Depends(require_client_key()),
|
||||
):
|
||||
"""Stream LLM inference response"""
|
||||
# Force streaming mode
|
||||
request.stream = True
|
||||
|
||||
job_payload = {
|
||||
"service_type": ServiceType.LLM_INFERENCE.value,
|
||||
"service_request": request.dict(),
|
||||
}
|
||||
|
||||
job_create = JobCreate(
|
||||
payload=job_payload,
|
||||
constraints=request.get_constraints(),
|
||||
ttl_seconds=300
|
||||
)
|
||||
|
||||
service = JobService(session)
|
||||
job = service.create_job(client_id, job_create)
|
||||
|
||||
# Return streaming response
|
||||
# This would implement WebSocket or Server-Sent Events
|
||||
return ServiceResponse(
|
||||
job_id=job.job_id,
|
||||
service_type=ServiceType.LLM_INFERENCE,
|
||||
status=job.state.value,
|
||||
estimated_completion=job.expires_at.isoformat()
|
||||
)
|
||||
|
||||
|
||||
# FFmpeg endpoints
|
||||
@router.post(
|
||||
"/services/ffmpeg/transcode",
|
||||
response_model=ServiceResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Transcode video using FFmpeg"
|
||||
)
|
||||
async def ffmpeg_transcode(
|
||||
request: FFmpegRequest,
|
||||
session: SessionDep,
|
||||
client_id: str = Depends(require_client_key()),
|
||||
) -> ServiceResponse:
|
||||
"""Transcode video using FFmpeg"""
|
||||
|
||||
job_payload = {
|
||||
"service_type": ServiceType.FFMPEG.value,
|
||||
"service_request": request.dict(),
|
||||
}
|
||||
|
||||
# Adjust TTL based on video length (would need to probe video)
|
||||
job_create = JobCreate(
|
||||
payload=job_payload,
|
||||
constraints=request.get_constraints(),
|
||||
ttl_seconds=1800 # 30 minutes for video transcoding
|
||||
)
|
||||
|
||||
service = JobService(session)
|
||||
job = service.create_job(client_id, job_create)
|
||||
|
||||
return ServiceResponse(
|
||||
job_id=job.job_id,
|
||||
service_type=ServiceType.FFMPEG,
|
||||
status=job.state.value,
|
||||
estimated_completion=job.expires_at.isoformat()
|
||||
)
|
||||
|
||||
|
||||
# Blender endpoints
|
||||
@router.post(
|
||||
"/services/blender/render",
|
||||
response_model=ServiceResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Render using Blender"
|
||||
)
|
||||
async def blender_render(
|
||||
request: BlenderRequest,
|
||||
session: SessionDep,
|
||||
client_id: str = Depends(require_client_key()),
|
||||
) -> ServiceResponse:
|
||||
"""Render scene using Blender"""
|
||||
|
||||
job_payload = {
|
||||
"service_type": ServiceType.BLENDER.value,
|
||||
"service_request": request.dict(),
|
||||
}
|
||||
|
||||
# Adjust TTL based on frame count
|
||||
frame_count = request.frame_end - request.frame_start + 1
|
||||
estimated_time = frame_count * 30 # 30 seconds per frame estimate
|
||||
ttl_seconds = max(600, estimated_time) # Minimum 10 minutes
|
||||
|
||||
job_create = JobCreate(
|
||||
payload=job_payload,
|
||||
constraints=request.get_constraints(),
|
||||
ttl_seconds=ttl_seconds
|
||||
)
|
||||
|
||||
service = JobService(session)
|
||||
job = service.create_job(client_id, job_create)
|
||||
|
||||
return ServiceResponse(
|
||||
job_id=job.job_id,
|
||||
service_type=ServiceType.BLENDER,
|
||||
status=job.state.value,
|
||||
estimated_completion=job.expires_at.isoformat()
|
||||
)
|
||||
|
||||
|
||||
# Utility endpoints
|
||||
@router.get(
|
||||
"/services",
|
||||
summary="List available services"
|
||||
)
|
||||
async def list_services() -> Dict[str, Any]:
|
||||
"""List all available service types and their capabilities"""
|
||||
return {
|
||||
"services": [
|
||||
{
|
||||
"type": ServiceType.WHISPER.value,
|
||||
"name": "Whisper Speech Recognition",
|
||||
"description": "Transcribe and translate audio files",
|
||||
"models": [m.value for m in WhisperModel],
|
||||
"constraints": {
|
||||
"gpu": "nvidia",
|
||||
"min_vram_gb": 1,
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": ServiceType.STABLE_DIFFUSION.value,
|
||||
"name": "Stable Diffusion",
|
||||
"description": "Generate images from text prompts",
|
||||
"models": [m.value for m in SDModel],
|
||||
"constraints": {
|
||||
"gpu": "nvidia",
|
||||
"min_vram_gb": 4,
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": ServiceType.LLM_INFERENCE.value,
|
||||
"name": "LLM Inference",
|
||||
"description": "Run inference on large language models",
|
||||
"models": [m.value for m in LLMModel],
|
||||
"constraints": {
|
||||
"gpu": "nvidia",
|
||||
"min_vram_gb": 8,
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": ServiceType.FFMPEG.value,
|
||||
"name": "FFmpeg Video Processing",
|
||||
"description": "Transcode and process video files",
|
||||
"codecs": [c.value for c in FFmpegCodec],
|
||||
"constraints": {
|
||||
"gpu": "any",
|
||||
"min_vram_gb": 0,
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": ServiceType.BLENDER.value,
|
||||
"name": "Blender Rendering",
|
||||
"description": "Render 3D scenes using Blender",
|
||||
"engines": [e.value for e in BlenderEngine],
|
||||
"constraints": {
|
||||
"gpu": "any",
|
||||
"min_vram_gb": 4,
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/services/{service_type}/schema",
|
||||
summary="Get service request schema",
|
||||
deprecated=True
|
||||
)
|
||||
async def get_service_schema(service_type: ServiceType) -> Dict[str, Any]:
|
||||
"""Get the JSON schema for a specific service type
|
||||
|
||||
DEPRECATED: Use /v1/registry/services/{service_id}/schema instead.
|
||||
This endpoint will be removed in version 2.0.
|
||||
"""
|
||||
# Get service from registry
|
||||
service = service_registry.get_service(service_type.value)
|
||||
if not service:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Service {service_type} not found"
|
||||
)
|
||||
|
||||
# Build schema from service definition
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for param in service.input_parameters:
|
||||
prop = {
|
||||
"type": param.type.value,
|
||||
"description": param.description
|
||||
}
|
||||
|
||||
if param.default is not None:
|
||||
prop["default"] = param.default
|
||||
if param.min_value is not None:
|
||||
prop["minimum"] = param.min_value
|
||||
if param.max_value is not None:
|
||||
prop["maximum"] = param.max_value
|
||||
if param.options:
|
||||
prop["enum"] = param.options
|
||||
if param.validation:
|
||||
prop.update(param.validation)
|
||||
|
||||
properties[param.name] = prop
|
||||
if param.required:
|
||||
required.append(param.name)
|
||||
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required
|
||||
}
|
||||
|
||||
return {
|
||||
"service_type": service_type.value,
|
||||
"schema": schema
|
||||
}
|
||||
|
||||
|
||||
async def validate_service_request(service_id: str, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate a service request against the service schema"""
|
||||
service = service_registry.get_service(service_id)
|
||||
if not service:
|
||||
return {"valid": False, "errors": [f"Service {service_id} not found"]}
|
||||
|
||||
validation_result = {
|
||||
"valid": True,
|
||||
"errors": [],
|
||||
"warnings": []
|
||||
}
|
||||
|
||||
# Check required parameters
|
||||
provided_params = set(request_data.keys())
|
||||
required_params = {p.name for p in service.input_parameters if p.required}
|
||||
missing_params = required_params - provided_params
|
||||
|
||||
if missing_params:
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].extend([
|
||||
f"Missing required parameter: {param}"
|
||||
for param in missing_params
|
||||
])
|
||||
|
||||
# Validate parameter types and constraints
|
||||
for param in service.input_parameters:
|
||||
if param.name in request_data:
|
||||
value = request_data[param.name]
|
||||
|
||||
# Type validation (simplified)
|
||||
if param.type == "integer" and not isinstance(value, int):
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be an integer"
|
||||
)
|
||||
elif param.type == "float" and not isinstance(value, (int, float)):
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be a number"
|
||||
)
|
||||
elif param.type == "boolean" and not isinstance(value, bool):
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be a boolean"
|
||||
)
|
||||
elif param.type == "array" and not isinstance(value, list):
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be an array"
|
||||
)
|
||||
|
||||
# Value constraints
|
||||
if param.min_value is not None and value < param.min_value:
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be >= {param.min_value}"
|
||||
)
|
||||
|
||||
if param.max_value is not None and value > param.max_value:
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be <= {param.max_value}"
|
||||
)
|
||||
|
||||
# Enum options
|
||||
if param.options and value not in param.options:
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be one of: {', '.join(param.options)}"
|
||||
)
|
||||
|
||||
return validation_result
|
||||
|
||||
|
||||
# Import models for type hints
|
||||
from ..models.services import (
|
||||
WhisperModel,
|
||||
SDModel,
|
||||
LLMModel,
|
||||
FFmpegCodec,
|
||||
FFmpegPreset,
|
||||
BlenderEngine,
|
||||
BlenderFormat,
|
||||
)
|
||||
362
apps/coordinator-api/src/app/services/access_control.py
Normal file
362
apps/coordinator-api/src/app/services/access_control.py
Normal file
@ -0,0 +1,362 @@
|
||||
"""
|
||||
Access control service for confidential transactions
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Set, Any
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
import json
|
||||
import re
|
||||
|
||||
from ..models import ConfidentialAccessRequest, ConfidentialAccessLog
|
||||
from ..settings import settings
|
||||
from ..logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AccessPurpose(str, Enum):
|
||||
"""Standard access purposes"""
|
||||
SETTLEMENT = "settlement"
|
||||
AUDIT = "audit"
|
||||
COMPLIANCE = "compliance"
|
||||
DISPUTE = "dispute"
|
||||
SUPPORT = "support"
|
||||
REPORTING = "reporting"
|
||||
|
||||
|
||||
class AccessLevel(str, Enum):
|
||||
"""Access levels for confidential data"""
|
||||
READ = "read"
|
||||
WRITE = "write"
|
||||
ADMIN = "admin"
|
||||
|
||||
|
||||
class ParticipantRole(str, Enum):
|
||||
"""Roles for transaction participants"""
|
||||
CLIENT = "client"
|
||||
MINER = "miner"
|
||||
COORDINATOR = "coordinator"
|
||||
AUDITOR = "auditor"
|
||||
REGULATOR = "regulator"
|
||||
|
||||
|
||||
class PolicyStore:
|
||||
"""Storage for access control policies"""
|
||||
|
||||
def __init__(self):
|
||||
self._policies: Dict[str, Dict] = {}
|
||||
self._role_permissions: Dict[ParticipantRole, Set[str]] = {
|
||||
ParticipantRole.CLIENT: {"read_own", "settlement_own"},
|
||||
ParticipantRole.MINER: {"read_assigned", "settlement_assigned"},
|
||||
ParticipantRole.COORDINATOR: {"read_all", "admin_all"},
|
||||
ParticipantRole.AUDITOR: {"read_all", "audit_all"},
|
||||
ParticipantRole.REGULATOR: {"read_all", "compliance_all"}
|
||||
}
|
||||
self._load_default_policies()
|
||||
|
||||
def _load_default_policies(self):
|
||||
"""Load default access policies"""
|
||||
# Client can access their own transactions
|
||||
self._policies["client_own_data"] = {
|
||||
"participants": ["client"],
|
||||
"conditions": {
|
||||
"transaction_client_id": "{requester}",
|
||||
"purpose": ["settlement", "dispute", "support"]
|
||||
},
|
||||
"access_level": AccessLevel.READ,
|
||||
"time_restrictions": None
|
||||
}
|
||||
|
||||
# Miner can access assigned transactions
|
||||
self._policies["miner_assigned_data"] = {
|
||||
"participants": ["miner"],
|
||||
"conditions": {
|
||||
"transaction_miner_id": "{requester}",
|
||||
"purpose": ["settlement"]
|
||||
},
|
||||
"access_level": AccessLevel.READ,
|
||||
"time_restrictions": None
|
||||
}
|
||||
|
||||
# Coordinator has full access
|
||||
self._policies["coordinator_full"] = {
|
||||
"participants": ["coordinator"],
|
||||
"conditions": {},
|
||||
"access_level": AccessLevel.ADMIN,
|
||||
"time_restrictions": None
|
||||
}
|
||||
|
||||
# Auditor access for compliance
|
||||
self._policies["auditor_compliance"] = {
|
||||
"participants": ["auditor", "regulator"],
|
||||
"conditions": {
|
||||
"purpose": ["audit", "compliance"]
|
||||
},
|
||||
"access_level": AccessLevel.READ,
|
||||
"time_restrictions": {
|
||||
"business_hours_only": True,
|
||||
"retention_days": 2555 # 7 years
|
||||
}
|
||||
}
|
||||
|
||||
def get_policy(self, policy_id: str) -> Optional[Dict]:
|
||||
"""Get access policy by ID"""
|
||||
return self._policies.get(policy_id)
|
||||
|
||||
def list_policies(self) -> List[str]:
|
||||
"""List all policy IDs"""
|
||||
return list(self._policies.keys())
|
||||
|
||||
def add_policy(self, policy_id: str, policy: Dict):
|
||||
"""Add new access policy"""
|
||||
self._policies[policy_id] = policy
|
||||
|
||||
def get_role_permissions(self, role: ParticipantRole) -> Set[str]:
|
||||
"""Get permissions for a role"""
|
||||
return self._role_permissions.get(role, set())
|
||||
|
||||
|
||||
class AccessController:
|
||||
"""Controls access to confidential transaction data"""
|
||||
|
||||
def __init__(self, policy_store: PolicyStore):
|
||||
self.policy_store = policy_store
|
||||
self._access_cache: Dict[str, Dict] = {}
|
||||
self._cache_ttl = timedelta(minutes=5)
|
||||
|
||||
def verify_access(self, request: ConfidentialAccessRequest) -> bool:
|
||||
"""Verify if requester has access rights"""
|
||||
try:
|
||||
# Check cache first
|
||||
cache_key = self._get_cache_key(request)
|
||||
cached_result = self._get_cached_result(cache_key)
|
||||
if cached_result is not None:
|
||||
return cached_result["allowed"]
|
||||
|
||||
# Get participant info
|
||||
participant_info = self._get_participant_info(request.requester)
|
||||
if not participant_info:
|
||||
logger.warning(f"Unknown participant: {request.requester}")
|
||||
return False
|
||||
|
||||
# Check role-based permissions
|
||||
role = participant_info.get("role")
|
||||
if not self._check_role_permissions(role, request):
|
||||
return False
|
||||
|
||||
# Check transaction-specific policies
|
||||
transaction = self._get_transaction(request.transaction_id)
|
||||
if not transaction:
|
||||
logger.warning(f"Transaction not found: {request.transaction_id}")
|
||||
return False
|
||||
|
||||
# Apply access policies
|
||||
allowed = self._apply_policies(request, participant_info, transaction)
|
||||
|
||||
# Cache result
|
||||
self._cache_result(cache_key, allowed)
|
||||
|
||||
return allowed
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Access verification failed: {e}")
|
||||
return False
|
||||
|
||||
def _check_role_permissions(self, role: str, request: ConfidentialAccessRequest) -> bool:
|
||||
"""Check if role grants access for this purpose"""
|
||||
try:
|
||||
participant_role = ParticipantRole(role.lower())
|
||||
permissions = self.policy_store.get_role_permissions(participant_role)
|
||||
|
||||
# Check purpose-based permissions
|
||||
if request.purpose == "settlement":
|
||||
return "settlement" in permissions or "settlement_own" in permissions
|
||||
elif request.purpose == "audit":
|
||||
return "audit" in permissions or "audit_all" in permissions
|
||||
elif request.purpose == "compliance":
|
||||
return "compliance" in permissions or "compliance_all" in permissions
|
||||
elif request.purpose == "dispute":
|
||||
return "dispute" in permissions or "read_own" in permissions
|
||||
elif request.purpose == "support":
|
||||
return "support" in permissions or "read_all" in permissions
|
||||
else:
|
||||
return "read" in permissions or "read_all" in permissions
|
||||
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid role: {role}")
|
||||
return False
|
||||
|
||||
def _apply_policies(
|
||||
self,
|
||||
request: ConfidentialAccessRequest,
|
||||
participant_info: Dict,
|
||||
transaction: Dict
|
||||
) -> bool:
|
||||
"""Apply access policies to request"""
|
||||
# Check if participant is in transaction participants list
|
||||
if request.requester not in transaction.get("participants", []):
|
||||
# Only coordinators, auditors, and regulators can access non-participant data
|
||||
role = participant_info.get("role", "").lower()
|
||||
if role not in ["coordinator", "auditor", "regulator"]:
|
||||
return False
|
||||
|
||||
# Check time-based restrictions
|
||||
if not self._check_time_restrictions(request.purpose, participant_info.get("role")):
|
||||
return False
|
||||
|
||||
# Check business hours for auditors
|
||||
if participant_info.get("role") == "auditor" and not self._is_business_hours():
|
||||
return False
|
||||
|
||||
# Check retention periods
|
||||
if not self._check_retention_period(transaction, participant_info.get("role")):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _check_time_restrictions(self, purpose: str, role: Optional[str]) -> bool:
|
||||
"""Check time-based access restrictions"""
|
||||
# No restrictions for settlement and dispute
|
||||
if purpose in ["settlement", "dispute"]:
|
||||
return True
|
||||
|
||||
# Audit and compliance only during business hours for non-coordinators
|
||||
if purpose in ["audit", "compliance"] and role not in ["coordinator"]:
|
||||
return self._is_business_hours()
|
||||
|
||||
return True
|
||||
|
||||
def _is_business_hours(self) -> bool:
|
||||
"""Check if current time is within business hours"""
|
||||
now = datetime.utcnow()
|
||||
|
||||
# Monday-Friday, 9 AM - 5 PM UTC
|
||||
if now.weekday() >= 5: # Weekend
|
||||
return False
|
||||
|
||||
if 9 <= now.hour < 17:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _check_retention_period(self, transaction: Dict, role: Optional[str]) -> bool:
|
||||
"""Check if data is within retention period for role"""
|
||||
transaction_date = transaction.get("timestamp", datetime.utcnow())
|
||||
|
||||
# Different retention periods for different roles
|
||||
if role == "regulator":
|
||||
retention_days = 2555 # 7 years
|
||||
elif role == "auditor":
|
||||
retention_days = 1825 # 5 years
|
||||
elif role == "coordinator":
|
||||
retention_days = 3650 # 10 years
|
||||
else:
|
||||
retention_days = 365 # 1 year
|
||||
|
||||
expiry_date = transaction_date + timedelta(days=retention_days)
|
||||
|
||||
return datetime.utcnow() <= expiry_date
|
||||
|
||||
def _get_participant_info(self, participant_id: str) -> Optional[Dict]:
|
||||
"""Get participant information"""
|
||||
# In production, query from database
|
||||
# For now, return mock data
|
||||
if participant_id.startswith("client-"):
|
||||
return {"id": participant_id, "role": "client", "active": True}
|
||||
elif participant_id.startswith("miner-"):
|
||||
return {"id": participant_id, "role": "miner", "active": True}
|
||||
elif participant_id.startswith("coordinator-"):
|
||||
return {"id": participant_id, "role": "coordinator", "active": True}
|
||||
elif participant_id.startswith("auditor-"):
|
||||
return {"id": participant_id, "role": "auditor", "active": True}
|
||||
elif participant_id.startswith("regulator-"):
|
||||
return {"id": participant_id, "role": "regulator", "active": True}
|
||||
else:
|
||||
return None
|
||||
|
||||
def _get_transaction(self, transaction_id: str) -> Optional[Dict]:
|
||||
"""Get transaction information"""
|
||||
# In production, query from database
|
||||
# For now, return mock data
|
||||
return {
|
||||
"transaction_id": transaction_id,
|
||||
"participants": ["client-456", "miner-789"],
|
||||
"timestamp": datetime.utcnow(),
|
||||
"status": "completed"
|
||||
}
|
||||
|
||||
def _get_cache_key(self, request: ConfidentialAccessRequest) -> str:
|
||||
"""Generate cache key for access request"""
|
||||
return f"{request.requester}:{request.transaction_id}:{request.purpose}"
|
||||
|
||||
def _get_cached_result(self, cache_key: str) -> Optional[Dict]:
|
||||
"""Get cached access result"""
|
||||
if cache_key in self._access_cache:
|
||||
cached = self._access_cache[cache_key]
|
||||
if datetime.utcnow() - cached["timestamp"] < self._cache_ttl:
|
||||
return cached
|
||||
else:
|
||||
del self._access_cache[cache_key]
|
||||
return None
|
||||
|
||||
def _cache_result(self, cache_key: str, allowed: bool):
|
||||
"""Cache access result"""
|
||||
self._access_cache[cache_key] = {
|
||||
"allowed": allowed,
|
||||
"timestamp": datetime.utcnow()
|
||||
}
|
||||
|
||||
def create_access_policy(
|
||||
self,
|
||||
name: str,
|
||||
participants: List[str],
|
||||
conditions: Dict[str, Any],
|
||||
access_level: AccessLevel
|
||||
) -> str:
|
||||
"""Create a new access policy"""
|
||||
policy_id = f"policy_{datetime.utcnow().timestamp()}"
|
||||
|
||||
policy = {
|
||||
"participants": participants,
|
||||
"conditions": conditions,
|
||||
"access_level": access_level,
|
||||
"time_restrictions": conditions.get("time_restrictions"),
|
||||
"created_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
self.policy_store.add_policy(policy_id, policy)
|
||||
logger.info(f"Created access policy: {policy_id}")
|
||||
|
||||
return policy_id
|
||||
|
||||
def revoke_access(self, participant_id: str, transaction_id: Optional[str] = None):
|
||||
"""Revoke access for participant"""
|
||||
# In production, update database
|
||||
# For now, clear cache
|
||||
keys_to_remove = []
|
||||
for key in self._access_cache:
|
||||
if key.startswith(f"{participant_id}:"):
|
||||
if transaction_id is None or key.split(":")[1] == transaction_id:
|
||||
keys_to_remove.append(key)
|
||||
|
||||
for key in keys_to_remove:
|
||||
del self._access_cache[key]
|
||||
|
||||
logger.info(f"Revoked access for participant: {participant_id}")
|
||||
|
||||
def get_access_summary(self, participant_id: str) -> Dict:
|
||||
"""Get summary of participant's access rights"""
|
||||
participant_info = self._get_participant_info(participant_id)
|
||||
if not participant_info:
|
||||
return {"error": "Participant not found"}
|
||||
|
||||
role = participant_info.get("role")
|
||||
permissions = self.policy_store.get_role_permissions(ParticipantRole(role))
|
||||
|
||||
return {
|
||||
"participant_id": participant_id,
|
||||
"role": role,
|
||||
"permissions": list(permissions),
|
||||
"active": participant_info.get("active", False)
|
||||
}
|
||||
532
apps/coordinator-api/src/app/services/audit_logging.py
Normal file
532
apps/coordinator-api/src/app/services/audit_logging.py
Normal file
@ -0,0 +1,532 @@
|
||||
"""
|
||||
Audit logging service for privacy compliance
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import hashlib
|
||||
import gzip
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
from ..models import ConfidentialAccessLog
|
||||
from ..settings import settings
|
||||
from ..logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuditEvent:
|
||||
"""Structured audit event"""
|
||||
event_id: str
|
||||
timestamp: datetime
|
||||
event_type: str
|
||||
participant_id: str
|
||||
transaction_id: Optional[str]
|
||||
action: str
|
||||
resource: str
|
||||
outcome: str
|
||||
details: Dict[str, Any]
|
||||
ip_address: Optional[str]
|
||||
user_agent: Optional[str]
|
||||
authorization: Optional[str]
|
||||
signature: Optional[str]
|
||||
|
||||
|
||||
class AuditLogger:
|
||||
"""Tamper-evident audit logging for privacy compliance"""
|
||||
|
||||
def __init__(self, log_dir: str = "/var/log/aitbc/audit"):
|
||||
self.log_dir = Path(log_dir)
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Current log file
|
||||
self.current_file = None
|
||||
self.current_hash = None
|
||||
|
||||
# Async writer task
|
||||
self.write_queue = asyncio.Queue(maxsize=10000)
|
||||
self.writer_task = None
|
||||
|
||||
# Chain of hashes for integrity
|
||||
self.chain_hash = self._load_chain_hash()
|
||||
|
||||
async def start(self):
|
||||
"""Start the background writer task"""
|
||||
if self.writer_task is None:
|
||||
self.writer_task = asyncio.create_task(self._background_writer())
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the background writer task"""
|
||||
if self.writer_task:
|
||||
self.writer_task.cancel()
|
||||
try:
|
||||
await self.writer_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self.writer_task = None
|
||||
|
||||
async def log_access(
|
||||
self,
|
||||
participant_id: str,
|
||||
transaction_id: Optional[str],
|
||||
action: str,
|
||||
outcome: str,
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
authorization: Optional[str] = None
|
||||
):
|
||||
"""Log access to confidential data"""
|
||||
event = AuditEvent(
|
||||
event_id=self._generate_event_id(),
|
||||
timestamp=datetime.utcnow(),
|
||||
event_type="access",
|
||||
participant_id=participant_id,
|
||||
transaction_id=transaction_id,
|
||||
action=action,
|
||||
resource="confidential_transaction",
|
||||
outcome=outcome,
|
||||
details=details or {},
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
authorization=authorization,
|
||||
signature=None
|
||||
)
|
||||
|
||||
# Add signature for tamper-evidence
|
||||
event.signature = self._sign_event(event)
|
||||
|
||||
# Queue for writing
|
||||
await self.write_queue.put(event)
|
||||
|
||||
async def log_key_operation(
|
||||
self,
|
||||
participant_id: str,
|
||||
operation: str,
|
||||
key_version: int,
|
||||
outcome: str,
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""Log key management operations"""
|
||||
event = AuditEvent(
|
||||
event_id=self._generate_event_id(),
|
||||
timestamp=datetime.utcnow(),
|
||||
event_type="key_operation",
|
||||
participant_id=participant_id,
|
||||
transaction_id=None,
|
||||
action=operation,
|
||||
resource="encryption_key",
|
||||
outcome=outcome,
|
||||
details={**(details or {}), "key_version": key_version},
|
||||
ip_address=None,
|
||||
user_agent=None,
|
||||
authorization=None,
|
||||
signature=None
|
||||
)
|
||||
|
||||
event.signature = self._sign_event(event)
|
||||
await self.write_queue.put(event)
|
||||
|
||||
async def log_policy_change(
|
||||
self,
|
||||
participant_id: str,
|
||||
policy_id: str,
|
||||
change_type: str,
|
||||
outcome: str,
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""Log access policy changes"""
|
||||
event = AuditEvent(
|
||||
event_id=self._generate_event_id(),
|
||||
timestamp=datetime.utcnow(),
|
||||
event_type="policy_change",
|
||||
participant_id=participant_id,
|
||||
transaction_id=None,
|
||||
action=change_type,
|
||||
resource="access_policy",
|
||||
outcome=outcome,
|
||||
details={**(details or {}), "policy_id": policy_id},
|
||||
ip_address=None,
|
||||
user_agent=None,
|
||||
authorization=None,
|
||||
signature=None
|
||||
)
|
||||
|
||||
event.signature = self._sign_event(event)
|
||||
await self.write_queue.put(event)
|
||||
|
||||
def query_logs(
|
||||
self,
|
||||
participant_id: Optional[str] = None,
|
||||
transaction_id: Optional[str] = None,
|
||||
event_type: Optional[str] = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
limit: int = 100
|
||||
) -> List[AuditEvent]:
|
||||
"""Query audit logs"""
|
||||
results = []
|
||||
|
||||
# Get list of log files to search
|
||||
log_files = self._get_log_files(start_time, end_time)
|
||||
|
||||
for log_file in log_files:
|
||||
try:
|
||||
# Read and decompress if needed
|
||||
if log_file.suffix == ".gz":
|
||||
with gzip.open(log_file, "rt") as f:
|
||||
for line in f:
|
||||
event = self._parse_log_line(line.strip())
|
||||
if self._matches_query(event, participant_id, transaction_id, event_type, start_time, end_time):
|
||||
results.append(event)
|
||||
if len(results) >= limit:
|
||||
return results
|
||||
else:
|
||||
with open(log_file, "r") as f:
|
||||
for line in f:
|
||||
event = self._parse_log_line(line.strip())
|
||||
if self._matches_query(event, participant_id, transaction_id, event_type, start_time, end_time):
|
||||
results.append(event)
|
||||
if len(results) >= limit:
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read log file {log_file}: {e}")
|
||||
continue
|
||||
|
||||
# Sort by timestamp (newest first)
|
||||
results.sort(key=lambda x: x.timestamp, reverse=True)
|
||||
|
||||
return results[:limit]
|
||||
|
||||
def verify_integrity(self, start_date: Optional[datetime] = None) -> Dict[str, Any]:
|
||||
"""Verify integrity of audit logs"""
|
||||
if start_date is None:
|
||||
start_date = datetime.utcnow() - timedelta(days=30)
|
||||
|
||||
results = {
|
||||
"verified_files": 0,
|
||||
"total_files": 0,
|
||||
"integrity_violations": [],
|
||||
"chain_valid": True
|
||||
}
|
||||
|
||||
log_files = self._get_log_files(start_date)
|
||||
|
||||
for log_file in log_files:
|
||||
results["total_files"] += 1
|
||||
|
||||
try:
|
||||
# Verify file hash
|
||||
file_hash = self._calculate_file_hash(log_file)
|
||||
stored_hash = self._get_stored_hash(log_file)
|
||||
|
||||
if file_hash != stored_hash:
|
||||
results["integrity_violations"].append({
|
||||
"file": str(log_file),
|
||||
"expected": stored_hash,
|
||||
"actual": file_hash
|
||||
})
|
||||
results["chain_valid"] = False
|
||||
else:
|
||||
results["verified_files"] += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to verify {log_file}: {e}")
|
||||
results["integrity_violations"].append({
|
||||
"file": str(log_file),
|
||||
"error": str(e)
|
||||
})
|
||||
results["chain_valid"] = False
|
||||
|
||||
return results
|
||||
|
||||
def export_logs(
|
||||
self,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
format: str = "json",
|
||||
include_signatures: bool = True
|
||||
) -> str:
|
||||
"""Export audit logs for compliance reporting"""
|
||||
events = self.query_logs(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
limit=10000
|
||||
)
|
||||
|
||||
if format == "json":
|
||||
export_data = {
|
||||
"export_metadata": {
|
||||
"start_time": start_time.isoformat(),
|
||||
"end_time": end_time.isoformat(),
|
||||
"event_count": len(events),
|
||||
"exported_at": datetime.utcnow().isoformat(),
|
||||
"include_signatures": include_signatures
|
||||
},
|
||||
"events": []
|
||||
}
|
||||
|
||||
for event in events:
|
||||
event_dict = asdict(event)
|
||||
event_dict["timestamp"] = event.timestamp.isoformat()
|
||||
|
||||
if not include_signatures:
|
||||
event_dict.pop("signature", None)
|
||||
|
||||
export_data["events"].append(event_dict)
|
||||
|
||||
return json.dumps(export_data, indent=2)
|
||||
|
||||
elif format == "csv":
|
||||
import csv
|
||||
import io
|
||||
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
|
||||
# Header
|
||||
header = [
|
||||
"event_id", "timestamp", "event_type", "participant_id",
|
||||
"transaction_id", "action", "resource", "outcome",
|
||||
"ip_address", "user_agent"
|
||||
]
|
||||
if include_signatures:
|
||||
header.append("signature")
|
||||
writer.writerow(header)
|
||||
|
||||
# Events
|
||||
for event in events:
|
||||
row = [
|
||||
event.event_id,
|
||||
event.timestamp.isoformat(),
|
||||
event.event_type,
|
||||
event.participant_id,
|
||||
event.transaction_id,
|
||||
event.action,
|
||||
event.resource,
|
||||
event.outcome,
|
||||
event.ip_address,
|
||||
event.user_agent
|
||||
]
|
||||
if include_signatures:
|
||||
row.append(event.signature)
|
||||
writer.writerow(row)
|
||||
|
||||
return output.getvalue()
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported export format: {format}")
|
||||
|
||||
async def _background_writer(self):
|
||||
"""Background task for writing audit events"""
|
||||
while True:
|
||||
try:
|
||||
# Get batch of events
|
||||
events = []
|
||||
while len(events) < 100:
|
||||
try:
|
||||
# Use asyncio.wait_for for timeout
|
||||
event = await asyncio.wait_for(
|
||||
self.write_queue.get(),
|
||||
timeout=1.0
|
||||
)
|
||||
events.append(event)
|
||||
except asyncio.TimeoutError:
|
||||
if events:
|
||||
break
|
||||
continue
|
||||
|
||||
# Write events
|
||||
if events:
|
||||
self._write_events(events)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Background writer error: {e}")
|
||||
# Brief pause to avoid error loops
|
||||
await asyncio.sleep(1)
|
||||
|
||||
def _write_events(self, events: List[AuditEvent]):
|
||||
"""Write events to current log file"""
|
||||
try:
|
||||
self._rotate_if_needed()
|
||||
|
||||
with open(self.current_file, "a") as f:
|
||||
for event in events:
|
||||
# Convert to JSON line
|
||||
event_dict = asdict(event)
|
||||
event_dict["timestamp"] = event.timestamp.isoformat()
|
||||
|
||||
# Write with signature
|
||||
line = json.dumps(event_dict, separators=(",", ":")) + "\n"
|
||||
f.write(line)
|
||||
f.flush()
|
||||
|
||||
# Update chain hash
|
||||
self._update_chain_hash(events[-1])
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write audit events: {e}")
|
||||
|
||||
def _rotate_if_needed(self):
|
||||
"""Rotate log file if needed"""
|
||||
now = datetime.utcnow()
|
||||
today = now.date()
|
||||
|
||||
# Check if we need a new file
|
||||
if self.current_file is None:
|
||||
self._new_log_file(today)
|
||||
else:
|
||||
file_date = datetime.fromisoformat(
|
||||
self.current_file.stem.split("_")[1]
|
||||
).date()
|
||||
|
||||
if file_date != today:
|
||||
self._new_log_file(today)
|
||||
|
||||
def _new_log_file(self, date):
|
||||
"""Create new log file for date"""
|
||||
filename = f"audit_{date.isoformat()}.log"
|
||||
self.current_file = self.log_dir / filename
|
||||
|
||||
# Write header with metadata
|
||||
if not self.current_file.exists():
|
||||
header = {
|
||||
"created_at": datetime.utcnow().isoformat(),
|
||||
"version": "1.0",
|
||||
"format": "jsonl",
|
||||
"previous_hash": self.chain_hash
|
||||
}
|
||||
|
||||
with open(self.current_file, "w") as f:
|
||||
f.write(f"# {json.dumps(header)}\n")
|
||||
|
||||
def _generate_event_id(self) -> str:
|
||||
"""Generate unique event ID"""
|
||||
return f"evt_{datetime.utcnow().timestamp()}_{os.urandom(4).hex()}"
|
||||
|
||||
def _sign_event(self, event: AuditEvent) -> str:
|
||||
"""Sign event for tamper-evidence"""
|
||||
# Create canonical representation
|
||||
event_data = {
|
||||
"event_id": event.event_id,
|
||||
"timestamp": event.timestamp.isoformat(),
|
||||
"participant_id": event.participant_id,
|
||||
"action": event.action,
|
||||
"outcome": event.outcome
|
||||
}
|
||||
|
||||
# Hash with previous chain hash
|
||||
data = json.dumps(event_data, separators=(",", ":"), sort_keys=True)
|
||||
combined = f"{self.chain_hash}:{data}".encode()
|
||||
|
||||
return hashlib.sha256(combined).hexdigest()
|
||||
|
||||
def _update_chain_hash(self, last_event: AuditEvent):
|
||||
"""Update chain hash with new event"""
|
||||
self.chain_hash = last_event.signature or self.chain_hash
|
||||
|
||||
# Store chain hash for integrity checking
|
||||
chain_file = self.log_dir / "chain.hash"
|
||||
with open(chain_file, "w") as f:
|
||||
f.write(self.chain_hash)
|
||||
|
||||
def _load_chain_hash(self) -> str:
|
||||
"""Load previous chain hash"""
|
||||
chain_file = self.log_dir / "chain.hash"
|
||||
if chain_file.exists():
|
||||
with open(chain_file, "r") as f:
|
||||
return f.read().strip()
|
||||
return "0" * 64 # Initial hash
|
||||
|
||||
def _get_log_files(self, start_time: Optional[datetime], end_time: Optional[datetime]) -> List[Path]:
|
||||
"""Get list of log files to search"""
|
||||
files = []
|
||||
|
||||
for file in self.log_dir.glob("audit_*.log*"):
|
||||
try:
|
||||
# Extract date from filename
|
||||
date_str = file.stem.split("_")[1]
|
||||
file_date = datetime.fromisoformat(date_str).date()
|
||||
|
||||
# Check if file is in range
|
||||
file_start = datetime.combine(file_date, datetime.min.time())
|
||||
file_end = file_start + timedelta(days=1)
|
||||
|
||||
if (not start_time or file_end >= start_time) and \
|
||||
(not end_time or file_start <= end_time):
|
||||
files.append(file)
|
||||
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return sorted(files)
|
||||
|
||||
def _parse_log_line(self, line: str) -> Optional[AuditEvent]:
|
||||
"""Parse log line into event"""
|
||||
if line.startswith("#"):
|
||||
return None # Skip header
|
||||
|
||||
try:
|
||||
data = json.loads(line)
|
||||
data["timestamp"] = datetime.fromisoformat(data["timestamp"])
|
||||
return AuditEvent(**data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse log line: {e}")
|
||||
return None
|
||||
|
||||
def _matches_query(
|
||||
self,
|
||||
event: Optional[AuditEvent],
|
||||
participant_id: Optional[str],
|
||||
transaction_id: Optional[str],
|
||||
event_type: Optional[str],
|
||||
start_time: Optional[datetime],
|
||||
end_time: Optional[datetime]
|
||||
) -> bool:
|
||||
"""Check if event matches query criteria"""
|
||||
if not event:
|
||||
return False
|
||||
|
||||
if participant_id and event.participant_id != participant_id:
|
||||
return False
|
||||
|
||||
if transaction_id and event.transaction_id != transaction_id:
|
||||
return False
|
||||
|
||||
if event_type and event.event_type != event_type:
|
||||
return False
|
||||
|
||||
if start_time and event.timestamp < start_time:
|
||||
return False
|
||||
|
||||
if end_time and event.timestamp > end_time:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _calculate_file_hash(self, file_path: Path) -> str:
|
||||
"""Calculate SHA-256 hash of file"""
|
||||
hash_sha256 = hashlib.sha256()
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(4096), b""):
|
||||
hash_sha256.update(chunk)
|
||||
|
||||
return hash_sha256.hexdigest()
|
||||
|
||||
def _get_stored_hash(self, file_path: Path) -> str:
|
||||
"""Get stored hash for file"""
|
||||
hash_file = file_path.with_suffix(".hash")
|
||||
if hash_file.exists():
|
||||
with open(hash_file, "r") as f:
|
||||
return f.read().strip()
|
||||
return ""
|
||||
|
||||
|
||||
# Global audit logger instance
|
||||
audit_logger = AuditLogger()
|
||||
349
apps/coordinator-api/src/app/services/encryption.py
Normal file
349
apps/coordinator-api/src/app/services/encryption.py
Normal file
@ -0,0 +1,349 @@
|
||||
"""
|
||||
Encryption service for confidential transactions
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import base64
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from datetime import datetime, timedelta
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
|
||||
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat, PrivateFormat, NoEncryption
|
||||
|
||||
from ..models import ConfidentialTransaction, AccessLog
|
||||
from ..settings import settings
|
||||
from ..logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class EncryptedData:
|
||||
"""Container for encrypted data and keys"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ciphertext: bytes,
|
||||
encrypted_keys: Dict[str, bytes],
|
||||
algorithm: str = "AES-256-GCM+X25519",
|
||||
nonce: Optional[bytes] = None,
|
||||
tag: Optional[bytes] = None
|
||||
):
|
||||
self.ciphertext = ciphertext
|
||||
self.encrypted_keys = encrypted_keys
|
||||
self.algorithm = algorithm
|
||||
self.nonce = nonce
|
||||
self.tag = tag
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for storage"""
|
||||
return {
|
||||
"ciphertext": base64.b64encode(self.ciphertext).decode(),
|
||||
"encrypted_keys": {
|
||||
participant: base64.b64encode(key).decode()
|
||||
for participant, key in self.encrypted_keys.items()
|
||||
},
|
||||
"algorithm": self.algorithm,
|
||||
"nonce": base64.b64encode(self.nonce).decode() if self.nonce else None,
|
||||
"tag": base64.b64encode(self.tag).decode() if self.tag else None
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "EncryptedData":
|
||||
"""Create from dictionary"""
|
||||
return cls(
|
||||
ciphertext=base64.b64decode(data["ciphertext"]),
|
||||
encrypted_keys={
|
||||
participant: base64.b64decode(key)
|
||||
for participant, key in data["encrypted_keys"].items()
|
||||
},
|
||||
algorithm=data["algorithm"],
|
||||
nonce=base64.b64decode(data["nonce"]) if data.get("nonce") else None,
|
||||
tag=base64.b64decode(data["tag"]) if data.get("tag") else None
|
||||
)
|
||||
|
||||
|
||||
class EncryptionService:
|
||||
"""Service for encrypting/decrypting confidential transaction data"""
|
||||
|
||||
def __init__(self, key_manager: "KeyManager"):
|
||||
self.key_manager = key_manager
|
||||
self.backend = default_backend()
|
||||
self.algorithm = "AES-256-GCM+X25519"
|
||||
|
||||
def encrypt(
|
||||
self,
|
||||
data: Dict[str, Any],
|
||||
participants: List[str],
|
||||
include_audit: bool = True
|
||||
) -> EncryptedData:
|
||||
"""Encrypt data for multiple participants
|
||||
|
||||
Args:
|
||||
data: Data to encrypt
|
||||
participants: List of participant IDs who can decrypt
|
||||
include_audit: Whether to include audit escrow key
|
||||
|
||||
Returns:
|
||||
EncryptedData container with ciphertext and encrypted keys
|
||||
"""
|
||||
try:
|
||||
# Generate random DEK (Data Encryption Key)
|
||||
dek = os.urandom(32) # 256-bit key for AES-256
|
||||
nonce = os.urandom(12) # 96-bit nonce for GCM
|
||||
|
||||
# Serialize and encrypt data
|
||||
plaintext = json.dumps(data, separators=(",", ":")).encode()
|
||||
aesgcm = AESGCM(dek)
|
||||
ciphertext = aesgcm.encrypt(nonce, plaintext, None)
|
||||
|
||||
# Extract tag (included in ciphertext for GCM)
|
||||
tag = ciphertext[-16:]
|
||||
actual_ciphertext = ciphertext[:-16]
|
||||
|
||||
# Encrypt DEK for each participant
|
||||
encrypted_keys = {}
|
||||
for participant in participants:
|
||||
try:
|
||||
public_key = self.key_manager.get_public_key(participant)
|
||||
encrypted_dek = self._encrypt_dek(dek, public_key)
|
||||
encrypted_keys[participant] = encrypted_dek
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to encrypt DEK for participant {participant}: {e}")
|
||||
continue
|
||||
|
||||
# Add audit escrow if requested
|
||||
if include_audit:
|
||||
try:
|
||||
audit_public_key = self.key_manager.get_audit_key()
|
||||
encrypted_dek = self._encrypt_dek(dek, audit_public_key)
|
||||
encrypted_keys["audit"] = encrypted_dek
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to encrypt DEK for audit: {e}")
|
||||
|
||||
return EncryptedData(
|
||||
ciphertext=actual_ciphertext,
|
||||
encrypted_keys=encrypted_keys,
|
||||
algorithm=self.algorithm,
|
||||
nonce=nonce,
|
||||
tag=tag
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Encryption failed: {e}")
|
||||
raise EncryptionError(f"Failed to encrypt data: {e}")
|
||||
|
||||
def decrypt(
|
||||
self,
|
||||
encrypted_data: EncryptedData,
|
||||
participant_id: str,
|
||||
purpose: str = "access"
|
||||
) -> Dict[str, Any]:
|
||||
"""Decrypt data for a specific participant
|
||||
|
||||
Args:
|
||||
encrypted_data: The encrypted data container
|
||||
participant_id: ID of the participant requesting decryption
|
||||
purpose: Purpose of decryption for audit logging
|
||||
|
||||
Returns:
|
||||
Decrypted data as dictionary
|
||||
"""
|
||||
try:
|
||||
# Get participant's private key
|
||||
private_key = self.key_manager.get_private_key(participant_id)
|
||||
|
||||
# Get encrypted DEK for participant
|
||||
if participant_id not in encrypted_data.encrypted_keys:
|
||||
raise AccessDeniedError(f"Participant {participant_id} not authorized")
|
||||
|
||||
encrypted_dek = encrypted_data.encrypted_keys[participant_id]
|
||||
|
||||
# Decrypt DEK
|
||||
dek = self._decrypt_dek(encrypted_dek, private_key)
|
||||
|
||||
# Reconstruct ciphertext with tag
|
||||
full_ciphertext = encrypted_data.ciphertext + encrypted_data.tag
|
||||
|
||||
# Decrypt data
|
||||
aesgcm = AESGCM(dek)
|
||||
plaintext = aesgcm.decrypt(encrypted_data.nonce, full_ciphertext, None)
|
||||
|
||||
data = json.loads(plaintext.decode())
|
||||
|
||||
# Log access
|
||||
self._log_access(
|
||||
transaction_id=None, # Will be set by caller
|
||||
participant_id=participant_id,
|
||||
purpose=purpose,
|
||||
success=True
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Decryption failed for participant {participant_id}: {e}")
|
||||
self._log_access(
|
||||
transaction_id=None,
|
||||
participant_id=participant_id,
|
||||
purpose=purpose,
|
||||
success=False,
|
||||
error=str(e)
|
||||
)
|
||||
raise DecryptionError(f"Failed to decrypt data: {e}")
|
||||
|
||||
def audit_decrypt(
|
||||
self,
|
||||
encrypted_data: EncryptedData,
|
||||
audit_authorization: str,
|
||||
purpose: str = "audit"
|
||||
) -> Dict[str, Any]:
|
||||
"""Decrypt data for audit purposes
|
||||
|
||||
Args:
|
||||
encrypted_data: The encrypted data container
|
||||
audit_authorization: Authorization token for audit access
|
||||
purpose: Purpose of decryption
|
||||
|
||||
Returns:
|
||||
Decrypted data as dictionary
|
||||
"""
|
||||
try:
|
||||
# Verify audit authorization
|
||||
if not self.key_manager.verify_audit_authorization(audit_authorization):
|
||||
raise AccessDeniedError("Invalid audit authorization")
|
||||
|
||||
# Get audit private key
|
||||
audit_private_key = self.key_manager.get_audit_private_key(audit_authorization)
|
||||
|
||||
# Decrypt using audit key
|
||||
if "audit" not in encrypted_data.encrypted_keys:
|
||||
raise AccessDeniedError("Audit escrow not available")
|
||||
|
||||
encrypted_dek = encrypted_data.encrypted_keys["audit"]
|
||||
dek = self._decrypt_dek(encrypted_dek, audit_private_key)
|
||||
|
||||
# Decrypt data
|
||||
full_ciphertext = encrypted_data.ciphertext + encrypted_data.tag
|
||||
aesgcm = AESGCM(dek)
|
||||
plaintext = aesgcm.decrypt(encrypted_data.nonce, full_ciphertext, None)
|
||||
|
||||
data = json.loads(plaintext.decode())
|
||||
|
||||
# Log audit access
|
||||
self._log_access(
|
||||
transaction_id=None,
|
||||
participant_id="audit",
|
||||
purpose=f"audit:{purpose}",
|
||||
success=True,
|
||||
authorization=audit_authorization
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Audit decryption failed: {e}")
|
||||
raise DecryptionError(f"Failed to decrypt for audit: {e}")
|
||||
|
||||
def _encrypt_dek(self, dek: bytes, public_key: X25519PublicKey) -> bytes:
|
||||
"""Encrypt DEK using ECIES with X25519"""
|
||||
# Generate ephemeral key pair
|
||||
ephemeral_private = X25519PrivateKey.generate()
|
||||
ephemeral_public = ephemeral_private.public_key()
|
||||
|
||||
# Perform ECDH
|
||||
shared_key = ephemeral_private.exchange(public_key)
|
||||
|
||||
# Derive encryption key from shared secret
|
||||
derived_key = HKDF(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=None,
|
||||
info=b"AITBC-DEK-Encryption",
|
||||
backend=self.backend
|
||||
).derive(shared_key)
|
||||
|
||||
# Encrypt DEK with AES-GCM
|
||||
aesgcm = AESGCM(derived_key)
|
||||
nonce = os.urandom(12)
|
||||
encrypted_dek = aesgcm.encrypt(nonce, dek, None)
|
||||
|
||||
# Return ephemeral public key + nonce + encrypted DEK
|
||||
return (
|
||||
ephemeral_public.public_bytes(Encoding.Raw, PublicFormat.Raw) +
|
||||
nonce +
|
||||
encrypted_dek
|
||||
)
|
||||
|
||||
def _decrypt_dek(self, encrypted_dek: bytes, private_key: X25519PrivateKey) -> bytes:
|
||||
"""Decrypt DEK using ECIES with X25519"""
|
||||
# Extract components
|
||||
ephemeral_public_bytes = encrypted_dek[:32]
|
||||
nonce = encrypted_dek[32:44]
|
||||
dek_ciphertext = encrypted_dek[44:]
|
||||
|
||||
# Reconstruct ephemeral public key
|
||||
ephemeral_public = X25519PublicKey.from_public_bytes(ephemeral_public_bytes)
|
||||
|
||||
# Perform ECDH
|
||||
shared_key = private_key.exchange(ephemeral_public)
|
||||
|
||||
# Derive decryption key
|
||||
derived_key = HKDF(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=None,
|
||||
info=b"AITBC-DEK-Encryption",
|
||||
backend=self.backend
|
||||
).derive(shared_key)
|
||||
|
||||
# Decrypt DEK
|
||||
aesgcm = AESGCM(derived_key)
|
||||
dek = aesgcm.decrypt(nonce, dek_ciphertext, None)
|
||||
|
||||
return dek
|
||||
|
||||
def _log_access(
|
||||
self,
|
||||
transaction_id: Optional[str],
|
||||
participant_id: str,
|
||||
purpose: str,
|
||||
success: bool,
|
||||
error: Optional[str] = None,
|
||||
authorization: Optional[str] = None
|
||||
):
|
||||
"""Log access to confidential data"""
|
||||
try:
|
||||
log_entry = {
|
||||
"transaction_id": transaction_id,
|
||||
"participant_id": participant_id,
|
||||
"purpose": purpose,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"success": success,
|
||||
"error": error,
|
||||
"authorization": authorization
|
||||
}
|
||||
|
||||
# In production, this would go to secure audit log
|
||||
logger.info(f"Confidential data access: {json.dumps(log_entry)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to log access: {e}")
|
||||
|
||||
|
||||
class EncryptionError(Exception):
|
||||
"""Base exception for encryption errors"""
|
||||
pass
|
||||
|
||||
|
||||
class DecryptionError(EncryptionError):
|
||||
"""Exception for decryption errors"""
|
||||
pass
|
||||
|
||||
|
||||
class AccessDeniedError(EncryptionError):
|
||||
"""Exception for access denied errors"""
|
||||
pass
|
||||
435
apps/coordinator-api/src/app/services/hsm_key_manager.py
Normal file
435
apps/coordinator-api/src/app/services/hsm_key_manager.py
Normal file
@ -0,0 +1,435 @@
|
||||
"""
|
||||
HSM-backed key management for production use
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
|
||||
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
from ..models import KeyPair, KeyRotationLog, AuditAuthorization
|
||||
from ..repositories.confidential import (
|
||||
ParticipantKeyRepository,
|
||||
KeyRotationRepository
|
||||
)
|
||||
from ..settings import settings
|
||||
from ..logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class HSMProvider(ABC):
|
||||
"""Abstract base class for HSM providers"""
|
||||
|
||||
@abstractmethod
|
||||
async def generate_key(self, key_id: str) -> Tuple[bytes, bytes]:
|
||||
"""Generate key pair in HSM, return (public_key, key_handle)"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def sign_with_key(self, key_handle: bytes, data: bytes) -> bytes:
|
||||
"""Sign data with HSM-stored private key"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def derive_shared_secret(self, key_handle: bytes, public_key: bytes) -> bytes:
|
||||
"""Derive shared secret using ECDH"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete_key(self, key_handle: bytes) -> bool:
|
||||
"""Delete key from HSM"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def list_keys(self) -> List[str]:
|
||||
"""List all key IDs in HSM"""
|
||||
pass
|
||||
|
||||
|
||||
class SoftwareHSMProvider(HSMProvider):
|
||||
"""Software-based HSM provider for development/testing"""
|
||||
|
||||
def __init__(self):
|
||||
self._keys: Dict[str, X25519PrivateKey] = {}
|
||||
self._backend = default_backend()
|
||||
|
||||
async def generate_key(self, key_id: str) -> Tuple[bytes, bytes]:
|
||||
"""Generate key pair in memory"""
|
||||
private_key = X25519PrivateKey.generate()
|
||||
public_key = private_key.public_key()
|
||||
|
||||
# Store private key (in production, this would be in secure hardware)
|
||||
self._keys[key_id] = private_key
|
||||
|
||||
return (
|
||||
public_key.public_bytes(Encoding.Raw, PublicFormat.Raw),
|
||||
key_id.encode() # Use key_id as handle
|
||||
)
|
||||
|
||||
async def sign_with_key(self, key_handle: bytes, data: bytes) -> bytes:
|
||||
"""Sign with stored private key"""
|
||||
key_id = key_handle.decode()
|
||||
private_key = self._keys.get(key_id)
|
||||
|
||||
if not private_key:
|
||||
raise ValueError(f"Key not found: {key_id}")
|
||||
|
||||
# For X25519, we don't sign - we exchange
|
||||
# This is a placeholder for actual HSM operations
|
||||
return b"signature_placeholder"
|
||||
|
||||
async def derive_shared_secret(self, key_handle: bytes, public_key: bytes) -> bytes:
|
||||
"""Derive shared secret"""
|
||||
key_id = key_handle.decode()
|
||||
private_key = self._keys.get(key_id)
|
||||
|
||||
if not private_key:
|
||||
raise ValueError(f"Key not found: {key_id}")
|
||||
|
||||
peer_public = X25519PublicKey.from_public_bytes(public_key)
|
||||
return private_key.exchange(peer_public)
|
||||
|
||||
async def delete_key(self, key_handle: bytes) -> bool:
|
||||
"""Delete key from memory"""
|
||||
key_id = key_handle.decode()
|
||||
if key_id in self._keys:
|
||||
del self._keys[key_id]
|
||||
return True
|
||||
return False
|
||||
|
||||
async def list_keys(self) -> List[str]:
|
||||
"""List all keys"""
|
||||
return list(self._keys.keys())
|
||||
|
||||
|
||||
class AzureKeyVaultProvider(HSMProvider):
|
||||
"""Azure Key Vault HSM provider for production"""
|
||||
|
||||
def __init__(self, vault_url: str, credential):
|
||||
from azure.keyvault.keys.crypto import CryptographyClient
|
||||
from azure.keyvault.keys import KeyClient
|
||||
from azure.identity import DefaultAzureCredential
|
||||
|
||||
self.vault_url = vault_url
|
||||
self.credential = credential or DefaultAzureCredential()
|
||||
self.key_client = KeyClient(vault_url, self.credential)
|
||||
self.crypto_client = None
|
||||
|
||||
async def generate_key(self, key_id: str) -> Tuple[bytes, bytes]:
|
||||
"""Generate key in Azure Key Vault"""
|
||||
# Create EC-HSM key
|
||||
key = await self.key_client.create_ec_key(
|
||||
key_id,
|
||||
curve="P-256" # Azure doesn't support X25519 directly
|
||||
)
|
||||
|
||||
# Get public key
|
||||
public_key = key.key.cryptography_client.public_key()
|
||||
public_bytes = public_key.public_bytes(
|
||||
Encoding.Raw,
|
||||
PublicFormat.Raw
|
||||
)
|
||||
|
||||
return public_bytes, key.id.encode()
|
||||
|
||||
async def sign_with_key(self, key_handle: bytes, data: bytes) -> bytes:
|
||||
"""Sign with Azure Key Vault"""
|
||||
key_id = key_handle.decode()
|
||||
crypto_client = self.key_client.get_cryptography_client(key_id)
|
||||
|
||||
sign_result = await crypto_client.sign("ES256", data)
|
||||
return sign_result.signature
|
||||
|
||||
async def derive_shared_secret(self, key_handle: bytes, public_key: bytes) -> bytes:
|
||||
"""Derive shared secret (not directly supported in Azure)"""
|
||||
# Would need to use a different approach
|
||||
raise NotImplementedError("ECDH not supported in Azure Key Vault")
|
||||
|
||||
async def delete_key(self, key_handle: bytes) -> bool:
|
||||
"""Delete key from Azure Key Vault"""
|
||||
key_name = key_handle.decode().split("/")[-1]
|
||||
await self.key_client.begin_delete_key(key_name)
|
||||
return True
|
||||
|
||||
async def list_keys(self) -> List[str]:
|
||||
"""List keys in Azure Key Vault"""
|
||||
keys = []
|
||||
async for key in self.key_client.list_properties_of_keys():
|
||||
keys.append(key.name)
|
||||
return keys
|
||||
|
||||
|
||||
class AWSKMSProvider(HSMProvider):
|
||||
"""AWS KMS HSM provider for production"""
|
||||
|
||||
def __init__(self, region_name: str):
|
||||
import boto3
|
||||
self.kms = boto3.client('kms', region_name=region_name)
|
||||
|
||||
async def generate_key(self, key_id: str) -> Tuple[bytes, bytes]:
|
||||
"""Generate key pair in AWS KMS"""
|
||||
# Create CMK
|
||||
response = self.kms.create_key(
|
||||
Description=f"AITBC confidential transaction key for {key_id}",
|
||||
KeyUsage='ENCRYPT_DECRYPT',
|
||||
KeySpec='ECC_NIST_P256'
|
||||
)
|
||||
|
||||
# Get public key
|
||||
public_key = self.kms.get_public_key(KeyId=response['KeyMetadata']['KeyId'])
|
||||
|
||||
return public_key['PublicKey'], response['KeyMetadata']['KeyId'].encode()
|
||||
|
||||
async def sign_with_key(self, key_handle: bytes, data: bytes) -> bytes:
|
||||
"""Sign with AWS KMS"""
|
||||
response = self.kms.sign(
|
||||
KeyId=key_handle.decode(),
|
||||
Message=data,
|
||||
MessageType='RAW',
|
||||
SigningAlgorithm='ECDSA_SHA_256'
|
||||
)
|
||||
return response['Signature']
|
||||
|
||||
async def derive_shared_secret(self, key_handle: bytes, public_key: bytes) -> bytes:
|
||||
"""Derive shared secret (not directly supported in KMS)"""
|
||||
raise NotImplementedError("ECDH not supported in AWS KMS")
|
||||
|
||||
async def delete_key(self, key_handle: bytes) -> bool:
|
||||
"""Schedule key deletion in AWS KMS"""
|
||||
self.kms.schedule_key_deletion(KeyId=key_handle.decode())
|
||||
return True
|
||||
|
||||
async def list_keys(self) -> List[str]:
|
||||
"""List keys in AWS KMS"""
|
||||
keys = []
|
||||
paginator = self.kms.get_paginator('list_keys')
|
||||
for page in paginator.paginate():
|
||||
for key in page['Keys']:
|
||||
keys.append(key['KeyId'])
|
||||
return keys
|
||||
|
||||
|
||||
class HSMKeyManager:
|
||||
"""HSM-backed key manager for production"""
|
||||
|
||||
def __init__(self, hsm_provider: HSMProvider, key_repository: ParticipantKeyRepository):
|
||||
self.hsm = hsm_provider
|
||||
self.key_repo = key_repository
|
||||
self._master_key = None
|
||||
self._init_master_key()
|
||||
|
||||
def _init_master_key(self):
|
||||
"""Initialize master key for encrypting stored data"""
|
||||
# In production, this would come from HSM or KMS
|
||||
self._master_key = os.urandom(32)
|
||||
|
||||
async def generate_key_pair(self, participant_id: str) -> KeyPair:
|
||||
"""Generate key pair in HSM"""
|
||||
try:
|
||||
# Generate key in HSM
|
||||
hsm_key_id = f"aitbc-{participant_id}-{datetime.utcnow().timestamp()}"
|
||||
public_key_bytes, key_handle = await self.hsm.generate_key(hsm_key_id)
|
||||
|
||||
# Create key pair record
|
||||
key_pair = KeyPair(
|
||||
participant_id=participant_id,
|
||||
private_key=key_handle, # Store HSM handle, not actual private key
|
||||
public_key=public_key_bytes,
|
||||
algorithm="X25519",
|
||||
created_at=datetime.utcnow(),
|
||||
version=1
|
||||
)
|
||||
|
||||
# Store metadata in database
|
||||
await self.key_repo.create(
|
||||
await self._get_session(),
|
||||
key_pair
|
||||
)
|
||||
|
||||
logger.info(f"Generated HSM key pair for participant: {participant_id}")
|
||||
return key_pair
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate HSM key pair for {participant_id}: {e}")
|
||||
raise
|
||||
|
||||
async def rotate_keys(self, participant_id: str) -> KeyPair:
|
||||
"""Rotate keys in HSM"""
|
||||
# Get current key
|
||||
current_key = await self.key_repo.get_by_participant(
|
||||
await self._get_session(),
|
||||
participant_id
|
||||
)
|
||||
|
||||
if not current_key:
|
||||
raise ValueError(f"No existing keys for {participant_id}")
|
||||
|
||||
# Generate new key
|
||||
new_key_pair = await self.generate_key_pair(participant_id)
|
||||
|
||||
# Log rotation
|
||||
rotation_log = KeyRotationLog(
|
||||
participant_id=participant_id,
|
||||
old_version=current_key.version,
|
||||
new_version=new_key_pair.version,
|
||||
rotated_at=datetime.utcnow(),
|
||||
reason="scheduled_rotation"
|
||||
)
|
||||
|
||||
await self.key_repo.rotate(
|
||||
await self._get_session(),
|
||||
participant_id,
|
||||
new_key_pair
|
||||
)
|
||||
|
||||
# Delete old key from HSM
|
||||
await self.hsm.delete_key(current_key.private_key)
|
||||
|
||||
return new_key_pair
|
||||
|
||||
def get_public_key(self, participant_id: str) -> X25519PublicKey:
|
||||
"""Get public key for participant"""
|
||||
key = self.key_repo.get_by_participant_sync(participant_id)
|
||||
if not key:
|
||||
raise ValueError(f"No keys found for {participant_id}")
|
||||
|
||||
return X25519PublicKey.from_public_bytes(key.public_key)
|
||||
|
||||
async def get_private_key_handle(self, participant_id: str) -> bytes:
|
||||
"""Get HSM key handle for participant"""
|
||||
key = await self.key_repo.get_by_participant(
|
||||
await self._get_session(),
|
||||
participant_id
|
||||
)
|
||||
|
||||
if not key:
|
||||
raise ValueError(f"No keys found for {participant_id}")
|
||||
|
||||
return key.private_key # This is the HSM handle
|
||||
|
||||
async def derive_shared_secret(
|
||||
self,
|
||||
participant_id: str,
|
||||
peer_public_key: bytes
|
||||
) -> bytes:
|
||||
"""Derive shared secret using HSM"""
|
||||
key_handle = await self.get_private_key_handle(participant_id)
|
||||
return await self.hsm.derive_shared_secret(key_handle, peer_public_key)
|
||||
|
||||
async def sign_with_key(
|
||||
self,
|
||||
participant_id: str,
|
||||
data: bytes
|
||||
) -> bytes:
|
||||
"""Sign data using HSM-stored key"""
|
||||
key_handle = await self.get_private_key_handle(participant_id)
|
||||
return await self.hsm.sign_with_key(key_handle, data)
|
||||
|
||||
async def revoke_keys(self, participant_id: str, reason: str) -> bool:
|
||||
"""Revoke participant's keys"""
|
||||
# Get current key
|
||||
current_key = await self.key_repo.get_by_participant(
|
||||
await self._get_session(),
|
||||
participant_id
|
||||
)
|
||||
|
||||
if not current_key:
|
||||
return False
|
||||
|
||||
# Delete from HSM
|
||||
await self.hsm.delete_key(current_key.private_key)
|
||||
|
||||
# Mark as revoked in database
|
||||
return await self.key_repo.update_active(
|
||||
await self._get_session(),
|
||||
participant_id,
|
||||
False,
|
||||
reason
|
||||
)
|
||||
|
||||
async def create_audit_authorization(
|
||||
self,
|
||||
issuer: str,
|
||||
purpose: str,
|
||||
expires_in_hours: int = 24
|
||||
) -> str:
|
||||
"""Create audit authorization signed with HSM"""
|
||||
# Create authorization payload
|
||||
payload = {
|
||||
"issuer": issuer,
|
||||
"subject": "audit_access",
|
||||
"purpose": purpose,
|
||||
"created_at": datetime.utcnow().isoformat(),
|
||||
"expires_at": (datetime.utcnow() + timedelta(hours=expires_in_hours)).isoformat()
|
||||
}
|
||||
|
||||
# Sign with audit key
|
||||
audit_key_handle = await self.get_private_key_handle("audit")
|
||||
signature = await self.hsm.sign_with_key(
|
||||
audit_key_handle,
|
||||
json.dumps(payload).encode()
|
||||
)
|
||||
|
||||
payload["signature"] = signature.hex()
|
||||
|
||||
# Encode for transport
|
||||
import base64
|
||||
return base64.b64encode(json.dumps(payload).encode()).decode()
|
||||
|
||||
async def verify_audit_authorization(self, authorization: str) -> bool:
|
||||
"""Verify audit authorization"""
|
||||
try:
|
||||
# Decode authorization
|
||||
import base64
|
||||
auth_data = base64.b64decode(authorization).decode()
|
||||
auth_json = json.loads(auth_data)
|
||||
|
||||
# Check expiration
|
||||
expires_at = datetime.fromisoformat(auth_json["expires_at"])
|
||||
if datetime.utcnow() > expires_at:
|
||||
return False
|
||||
|
||||
# Verify signature with audit public key
|
||||
audit_public_key = self.get_public_key("audit")
|
||||
# In production, verify with proper cryptographic library
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to verify audit authorization: {e}")
|
||||
return False
|
||||
|
||||
async def _get_session(self):
|
||||
"""Get database session"""
|
||||
# In production, inject via dependency injection
|
||||
async for session in get_async_session():
|
||||
return session
|
||||
|
||||
|
||||
def create_hsm_key_manager() -> HSMKeyManager:
|
||||
"""Create HSM key manager based on configuration"""
|
||||
from ..repositories.confidential import ParticipantKeyRepository
|
||||
|
||||
# Get HSM provider from settings
|
||||
hsm_type = getattr(settings, 'HSM_PROVIDER', 'software')
|
||||
|
||||
if hsm_type == 'software':
|
||||
hsm = SoftwareHSMProvider()
|
||||
elif hsm_type == 'azure':
|
||||
vault_url = getattr(settings, 'AZURE_KEY_VAULT_URL')
|
||||
hsm = AzureKeyVaultProvider(vault_url)
|
||||
elif hsm_type == 'aws':
|
||||
region = getattr(settings, 'AWS_REGION', 'us-east-1')
|
||||
hsm = AWSKMSProvider(region)
|
||||
else:
|
||||
raise ValueError(f"Unknown HSM provider: {hsm_type}")
|
||||
|
||||
key_repo = ParticipantKeyRepository()
|
||||
|
||||
return HSMKeyManager(hsm, key_repo)
|
||||
466
apps/coordinator-api/src/app/services/key_management.py
Normal file
466
apps/coordinator-api/src/app/services/key_management.py
Normal file
@ -0,0 +1,466 @@
|
||||
"""
|
||||
Key management service for confidential transactions
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import base64
|
||||
from typing import Dict, Optional, List, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
|
||||
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat, PrivateFormat, NoEncryption
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
from ..models import KeyPair, KeyRotationLog, AuditAuthorization
|
||||
from ..settings import settings
|
||||
from ..logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class KeyManager:
|
||||
"""Manages encryption keys for confidential transactions"""
|
||||
|
||||
def __init__(self, storage_backend: "KeyStorageBackend"):
|
||||
self.storage = storage_backend
|
||||
self.backend = default_backend()
|
||||
self._key_cache = {}
|
||||
self._audit_key = None
|
||||
self._audit_key_rotation = timedelta(days=30)
|
||||
|
||||
async def generate_key_pair(self, participant_id: str) -> KeyPair:
|
||||
"""Generate X25519 key pair for participant"""
|
||||
try:
|
||||
# Generate new key pair
|
||||
private_key = X25519PrivateKey.generate()
|
||||
public_key = private_key.public_key()
|
||||
|
||||
# Create key pair object
|
||||
key_pair = KeyPair(
|
||||
participant_id=participant_id,
|
||||
private_key=private_key.private_bytes_raw(),
|
||||
public_key=public_key.public_bytes_raw(),
|
||||
algorithm="X25519",
|
||||
created_at=datetime.utcnow(),
|
||||
version=1
|
||||
)
|
||||
|
||||
# Store securely
|
||||
await self.storage.store_key_pair(key_pair)
|
||||
|
||||
# Cache public key
|
||||
self._key_cache[participant_id] = {
|
||||
"public_key": public_key,
|
||||
"version": key_pair.version
|
||||
}
|
||||
|
||||
logger.info(f"Generated key pair for participant: {participant_id}")
|
||||
return key_pair
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate key pair for {participant_id}: {e}")
|
||||
raise KeyManagementError(f"Key generation failed: {e}")
|
||||
|
||||
async def rotate_keys(self, participant_id: str) -> KeyPair:
|
||||
"""Rotate encryption keys for participant"""
|
||||
try:
|
||||
# Get current key pair
|
||||
current_key = await self.storage.get_key_pair(participant_id)
|
||||
if not current_key:
|
||||
raise KeyNotFoundError(f"No existing keys for {participant_id}")
|
||||
|
||||
# Generate new key pair
|
||||
new_key_pair = await self.generate_key_pair(participant_id)
|
||||
|
||||
# Log rotation
|
||||
rotation_log = KeyRotationLog(
|
||||
participant_id=participant_id,
|
||||
old_version=current_key.version,
|
||||
new_version=new_key_pair.version,
|
||||
rotated_at=datetime.utcnow(),
|
||||
reason="scheduled_rotation"
|
||||
)
|
||||
await self.storage.log_rotation(rotation_log)
|
||||
|
||||
# Re-encrypt active transactions (in production)
|
||||
await self._reencrypt_transactions(participant_id, current_key, new_key_pair)
|
||||
|
||||
logger.info(f"Rotated keys for participant: {participant_id}")
|
||||
return new_key_pair
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to rotate keys for {participant_id}: {e}")
|
||||
raise KeyManagementError(f"Key rotation failed: {e}")
|
||||
|
||||
def get_public_key(self, participant_id: str) -> X25519PublicKey:
|
||||
"""Get public key for participant"""
|
||||
# Check cache first
|
||||
if participant_id in self._key_cache:
|
||||
return self._key_cache[participant_id]["public_key"]
|
||||
|
||||
# Load from storage
|
||||
key_pair = self.storage.get_key_pair_sync(participant_id)
|
||||
if not key_pair:
|
||||
raise KeyNotFoundError(f"No keys found for participant: {participant_id}")
|
||||
|
||||
# Reconstruct public key
|
||||
public_key = X25519PublicKey.from_public_bytes(key_pair.public_key)
|
||||
|
||||
# Cache it
|
||||
self._key_cache[participant_id] = {
|
||||
"public_key": public_key,
|
||||
"version": key_pair.version
|
||||
}
|
||||
|
||||
return public_key
|
||||
|
||||
def get_private_key(self, participant_id: str) -> X25519PrivateKey:
|
||||
"""Get private key for participant (from secure storage)"""
|
||||
key_pair = self.storage.get_key_pair_sync(participant_id)
|
||||
if not key_pair:
|
||||
raise KeyNotFoundError(f"No keys found for participant: {participant_id}")
|
||||
|
||||
# Reconstruct private key
|
||||
private_key = X25519PrivateKey.from_private_bytes(key_pair.private_key)
|
||||
return private_key
|
||||
|
||||
async def get_audit_key(self) -> X25519PublicKey:
|
||||
"""Get public audit key for escrow"""
|
||||
if not self._audit_key or self._should_rotate_audit_key():
|
||||
await self._rotate_audit_key()
|
||||
|
||||
return self._audit_key
|
||||
|
||||
async def get_audit_private_key(self, authorization: str) -> X25519PrivateKey:
|
||||
"""Get private audit key with authorization"""
|
||||
# Verify authorization
|
||||
if not await self.verify_audit_authorization(authorization):
|
||||
raise AccessDeniedError("Invalid audit authorization")
|
||||
|
||||
# Load audit key from secure storage
|
||||
audit_key_data = await self.storage.get_audit_key()
|
||||
if not audit_key_data:
|
||||
raise KeyNotFoundError("Audit key not found")
|
||||
|
||||
return X25519PrivateKey.from_private_bytes(audit_key_data.private_key)
|
||||
|
||||
async def verify_audit_authorization(self, authorization: str) -> bool:
|
||||
"""Verify audit authorization token"""
|
||||
try:
|
||||
# Decode authorization
|
||||
auth_data = base64.b64decode(authorization).decode()
|
||||
auth_json = json.loads(auth_data)
|
||||
|
||||
# Check expiration
|
||||
expires_at = datetime.fromisoformat(auth_json["expires_at"])
|
||||
if datetime.utcnow() > expires_at:
|
||||
return False
|
||||
|
||||
# Verify signature (in production, use proper signature verification)
|
||||
# For now, just check format
|
||||
required_fields = ["issuer", "subject", "expires_at", "signature"]
|
||||
return all(field in auth_json for field in required_fields)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to verify audit authorization: {e}")
|
||||
return False
|
||||
|
||||
async def create_audit_authorization(
|
||||
self,
|
||||
issuer: str,
|
||||
purpose: str,
|
||||
expires_in_hours: int = 24
|
||||
) -> str:
|
||||
"""Create audit authorization token"""
|
||||
try:
|
||||
# Create authorization payload
|
||||
payload = {
|
||||
"issuer": issuer,
|
||||
"subject": "audit_access",
|
||||
"purpose": purpose,
|
||||
"created_at": datetime.utcnow().isoformat(),
|
||||
"expires_at": (datetime.utcnow() + timedelta(hours=expires_in_hours)).isoformat(),
|
||||
"signature": "placeholder" # In production, sign with issuer key
|
||||
}
|
||||
|
||||
# Encode and return
|
||||
auth_json = json.dumps(payload)
|
||||
return base64.b64encode(auth_json.encode()).decode()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create audit authorization: {e}")
|
||||
raise KeyManagementError(f"Authorization creation failed: {e}")
|
||||
|
||||
async def list_participants(self) -> List[str]:
|
||||
"""List all participants with keys"""
|
||||
return await self.storage.list_participants()
|
||||
|
||||
async def revoke_keys(self, participant_id: str, reason: str) -> bool:
|
||||
"""Revoke participant's keys"""
|
||||
try:
|
||||
# Mark keys as revoked
|
||||
success = await self.storage.revoke_keys(participant_id, reason)
|
||||
|
||||
if success:
|
||||
# Clear cache
|
||||
if participant_id in self._key_cache:
|
||||
del self._key_cache[participant_id]
|
||||
|
||||
logger.info(f"Revoked keys for participant: {participant_id}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to revoke keys for {participant_id}: {e}")
|
||||
return False
|
||||
|
||||
async def _rotate_audit_key(self):
|
||||
"""Rotate the audit escrow key"""
|
||||
try:
|
||||
# Generate new audit key pair
|
||||
audit_private = X25519PrivateKey.generate()
|
||||
audit_public = audit_private.public_key()
|
||||
|
||||
# Store securely
|
||||
audit_key_pair = KeyPair(
|
||||
participant_id="audit",
|
||||
private_key=audit_private.private_bytes_raw(),
|
||||
public_key=audit_public.public_bytes_raw(),
|
||||
algorithm="X25519",
|
||||
created_at=datetime.utcnow(),
|
||||
version=1
|
||||
)
|
||||
|
||||
await self.storage.store_audit_key(audit_key_pair)
|
||||
self._audit_key = audit_public
|
||||
|
||||
logger.info("Rotated audit escrow key")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to rotate audit key: {e}")
|
||||
raise KeyManagementError(f"Audit key rotation failed: {e}")
|
||||
|
||||
def _should_rotate_audit_key(self) -> bool:
|
||||
"""Check if audit key needs rotation"""
|
||||
# In production, check last rotation time
|
||||
return self._audit_key is None
|
||||
|
||||
async def _reencrypt_transactions(
|
||||
self,
|
||||
participant_id: str,
|
||||
old_key_pair: KeyPair,
|
||||
new_key_pair: KeyPair
|
||||
):
|
||||
"""Re-encrypt active transactions with new key"""
|
||||
# This would be implemented in production
|
||||
# For now, just log the action
|
||||
logger.info(f"Would re-encrypt transactions for {participant_id}")
|
||||
pass
|
||||
|
||||
|
||||
class KeyStorageBackend:
|
||||
"""Abstract base for key storage backends"""
|
||||
|
||||
async def store_key_pair(self, key_pair: KeyPair) -> bool:
|
||||
"""Store key pair securely"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_key_pair(self, participant_id: str) -> Optional[KeyPair]:
|
||||
"""Get key pair for participant"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_key_pair_sync(self, participant_id: str) -> Optional[KeyPair]:
|
||||
"""Synchronous get key pair"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def store_audit_key(self, key_pair: KeyPair) -> bool:
|
||||
"""Store audit key pair"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_audit_key(self) -> Optional[KeyPair]:
|
||||
"""Get audit key pair"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def list_participants(self) -> List[str]:
|
||||
"""List all participants"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def revoke_keys(self, participant_id: str, reason: str) -> bool:
|
||||
"""Revoke keys for participant"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def log_rotation(self, rotation_log: KeyRotationLog) -> bool:
|
||||
"""Log key rotation"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FileKeyStorage(KeyStorageBackend):
|
||||
"""File-based key storage for development"""
|
||||
|
||||
def __init__(self, storage_path: str):
|
||||
self.storage_path = storage_path
|
||||
os.makedirs(storage_path, exist_ok=True)
|
||||
|
||||
async def store_key_pair(self, key_pair: KeyPair) -> bool:
|
||||
"""Store key pair to file"""
|
||||
try:
|
||||
file_path = os.path.join(self.storage_path, f"{key_pair.participant_id}.json")
|
||||
|
||||
# Store private key in separate encrypted file
|
||||
private_path = os.path.join(self.storage_path, f"{key_pair.participant_id}.priv")
|
||||
|
||||
# In production, encrypt private key with master key
|
||||
with open(private_path, "wb") as f:
|
||||
f.write(key_pair.private_key)
|
||||
|
||||
# Store public metadata
|
||||
metadata = {
|
||||
"participant_id": key_pair.participant_id,
|
||||
"public_key": base64.b64encode(key_pair.public_key).decode(),
|
||||
"algorithm": key_pair.algorithm,
|
||||
"created_at": key_pair.created_at.isoformat(),
|
||||
"version": key_pair.version
|
||||
}
|
||||
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(metadata, f)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store key pair: {e}")
|
||||
return False
|
||||
|
||||
async def get_key_pair(self, participant_id: str) -> Optional[KeyPair]:
|
||||
"""Get key pair from file"""
|
||||
return self.get_key_pair_sync(participant_id)
|
||||
|
||||
def get_key_pair_sync(self, participant_id: str) -> Optional[KeyPair]:
|
||||
"""Synchronous get key pair"""
|
||||
try:
|
||||
file_path = os.path.join(self.storage_path, f"{participant_id}.json")
|
||||
private_path = os.path.join(self.storage_path, f"{participant_id}.priv")
|
||||
|
||||
if not os.path.exists(file_path) or not os.path.exists(private_path):
|
||||
return None
|
||||
|
||||
# Load metadata
|
||||
with open(file_path, "r") as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
# Load private key
|
||||
with open(private_path, "rb") as f:
|
||||
private_key = f.read()
|
||||
|
||||
return KeyPair(
|
||||
participant_id=metadata["participant_id"],
|
||||
private_key=private_key,
|
||||
public_key=base64.b64decode(metadata["public_key"]),
|
||||
algorithm=metadata["algorithm"],
|
||||
created_at=datetime.fromisoformat(metadata["created_at"]),
|
||||
version=metadata["version"]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get key pair: {e}")
|
||||
return None
|
||||
|
||||
async def store_audit_key(self, key_pair: KeyPair) -> bool:
|
||||
"""Store audit key"""
|
||||
audit_path = os.path.join(self.storage_path, "audit.json")
|
||||
audit_priv_path = os.path.join(self.storage_path, "audit.priv")
|
||||
|
||||
try:
|
||||
# Store private key
|
||||
with open(audit_priv_path, "wb") as f:
|
||||
f.write(key_pair.private_key)
|
||||
|
||||
# Store metadata
|
||||
metadata = {
|
||||
"participant_id": "audit",
|
||||
"public_key": base64.b64encode(key_pair.public_key).decode(),
|
||||
"algorithm": key_pair.algorithm,
|
||||
"created_at": key_pair.created_at.isoformat(),
|
||||
"version": key_pair.version
|
||||
}
|
||||
|
||||
with open(audit_path, "w") as f:
|
||||
json.dump(metadata, f)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store audit key: {e}")
|
||||
return False
|
||||
|
||||
async def get_audit_key(self) -> Optional[KeyPair]:
|
||||
"""Get audit key"""
|
||||
return self.get_key_pair_sync("audit")
|
||||
|
||||
async def list_participants(self) -> List[str]:
|
||||
"""List all participants"""
|
||||
participants = []
|
||||
for file in os.listdir(self.storage_path):
|
||||
if file.endswith(".json") and file != "audit.json":
|
||||
participant_id = file[:-5] # Remove .json
|
||||
participants.append(participant_id)
|
||||
return participants
|
||||
|
||||
async def revoke_keys(self, participant_id: str, reason: str) -> bool:
|
||||
"""Revoke keys by deleting files"""
|
||||
try:
|
||||
file_path = os.path.join(self.storage_path, f"{participant_id}.json")
|
||||
private_path = os.path.join(self.storage_path, f"{participant_id}.priv")
|
||||
|
||||
# Move to revoked folder instead of deleting
|
||||
revoked_path = os.path.join(self.storage_path, "revoked")
|
||||
os.makedirs(revoked_path, exist_ok=True)
|
||||
|
||||
if os.path.exists(file_path):
|
||||
os.rename(file_path, os.path.join(revoked_path, f"{participant_id}.json"))
|
||||
if os.path.exists(private_path):
|
||||
os.rename(private_path, os.path.join(revoked_path, f"{participant_id}.priv"))
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to revoke keys: {e}")
|
||||
return False
|
||||
|
||||
async def log_rotation(self, rotation_log: KeyRotationLog) -> bool:
|
||||
"""Log key rotation"""
|
||||
log_path = os.path.join(self.storage_path, "rotations.log")
|
||||
|
||||
try:
|
||||
with open(log_path, "a") as f:
|
||||
f.write(json.dumps({
|
||||
"participant_id": rotation_log.participant_id,
|
||||
"old_version": rotation_log.old_version,
|
||||
"new_version": rotation_log.new_version,
|
||||
"rotated_at": rotation_log.rotated_at.isoformat(),
|
||||
"reason": rotation_log.reason
|
||||
}) + "\n")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to log rotation: {e}")
|
||||
return False
|
||||
|
||||
|
||||
class KeyManagementError(Exception):
|
||||
"""Base exception for key management errors"""
|
||||
pass
|
||||
|
||||
|
||||
class KeyNotFoundError(KeyManagementError):
|
||||
"""Raised when key is not found"""
|
||||
pass
|
||||
|
||||
|
||||
class AccessDeniedError(KeyManagementError):
|
||||
"""Raised when access is denied"""
|
||||
pass
|
||||
526
apps/coordinator-api/src/app/services/quota_enforcement.py
Normal file
526
apps/coordinator-api/src/app/services/quota_enforcement.py
Normal file
@ -0,0 +1,526 @@
|
||||
"""
|
||||
Resource quota enforcement service for multi-tenant AITBC coordinator
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, Optional, List
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, update, and_, func
|
||||
from contextlib import asynccontextmanager
|
||||
import redis
|
||||
import json
|
||||
|
||||
from ..models.multitenant import TenantQuota, UsageRecord, Tenant
|
||||
from ..exceptions import QuotaExceededError, TenantError
|
||||
from ..middleware.tenant_context import get_current_tenant_id
|
||||
|
||||
|
||||
class QuotaEnforcementService:
|
||||
"""Service for enforcing tenant resource quotas"""
|
||||
|
||||
def __init__(self, db: Session, redis_client: Optional[redis.Redis] = None):
|
||||
self.db = db
|
||||
self.redis = redis_client
|
||||
self.logger = __import__('logging').getLogger(f"aitbc.{self.__class__.__name__}")
|
||||
|
||||
# Cache for quota lookups
|
||||
self._quota_cache = {}
|
||||
self._cache_ttl = 300 # 5 minutes
|
||||
|
||||
async def check_quota(
|
||||
self,
|
||||
resource_type: str,
|
||||
quantity: float,
|
||||
tenant_id: Optional[str] = None
|
||||
) -> bool:
|
||||
"""Check if tenant has sufficient quota for a resource"""
|
||||
|
||||
tenant_id = tenant_id or get_current_tenant_id()
|
||||
if not tenant_id:
|
||||
raise TenantError("No tenant context found")
|
||||
|
||||
# Get current quota and usage
|
||||
quota = await self._get_current_quota(tenant_id, resource_type)
|
||||
|
||||
if not quota:
|
||||
# No quota set, check if unlimited plan
|
||||
tenant = await self._get_tenant(tenant_id)
|
||||
if tenant and tenant.plan in ["enterprise", "unlimited"]:
|
||||
return True
|
||||
raise QuotaExceededError(f"No quota configured for {resource_type}")
|
||||
|
||||
# Check if adding quantity would exceed limit
|
||||
current_usage = await self._get_current_usage(tenant_id, resource_type)
|
||||
|
||||
if current_usage + quantity > quota.limit_value:
|
||||
# Log quota exceeded
|
||||
self.logger.warning(
|
||||
f"Quota exceeded for tenant {tenant_id}: "
|
||||
f"{resource_type} {current_usage + quantity}/{quota.limit_value}"
|
||||
)
|
||||
|
||||
raise QuotaExceededError(
|
||||
f"Quota exceeded for {resource_type}: "
|
||||
f"{current_usage + quantity}/{quota.limit_value}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
async def consume_quota(
|
||||
self,
|
||||
resource_type: str,
|
||||
quantity: float,
|
||||
resource_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tenant_id: Optional[str] = None
|
||||
) -> UsageRecord:
|
||||
"""Consume quota and record usage"""
|
||||
|
||||
tenant_id = tenant_id or get_current_tenant_id()
|
||||
if not tenant_id:
|
||||
raise TenantError("No tenant context found")
|
||||
|
||||
# Check quota first
|
||||
await self.check_quota(resource_type, quantity, tenant_id)
|
||||
|
||||
# Create usage record
|
||||
usage_record = UsageRecord(
|
||||
tenant_id=tenant_id,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
quantity=quantity,
|
||||
unit=self._get_unit_for_resource(resource_type),
|
||||
unit_price=await self._get_unit_price(resource_type),
|
||||
total_cost=await self._calculate_cost(resource_type, quantity),
|
||||
currency="USD",
|
||||
usage_start=datetime.utcnow(),
|
||||
usage_end=datetime.utcnow(),
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
self.db.add(usage_record)
|
||||
|
||||
# Update quota usage
|
||||
await self._update_quota_usage(tenant_id, resource_type, quantity)
|
||||
|
||||
# Update cache
|
||||
cache_key = f"quota_usage:{tenant_id}:{resource_type}"
|
||||
if self.redis:
|
||||
current = self.redis.get(cache_key)
|
||||
if current:
|
||||
self.redis.incrbyfloat(cache_key, quantity)
|
||||
self.redis.expire(cache_key, self._cache_ttl)
|
||||
|
||||
self.db.commit()
|
||||
self.logger.info(
|
||||
f"Consumed quota: tenant={tenant_id}, "
|
||||
f"resource={resource_type}, quantity={quantity}"
|
||||
)
|
||||
|
||||
return usage_record
|
||||
|
||||
async def release_quota(
|
||||
self,
|
||||
resource_type: str,
|
||||
quantity: float,
|
||||
usage_record_id: str,
|
||||
tenant_id: Optional[str] = None
|
||||
):
|
||||
"""Release quota (e.g., when job completes early)"""
|
||||
|
||||
tenant_id = tenant_id or get_current_tenant_id()
|
||||
if not tenant_id:
|
||||
raise TenantError("No tenant context found")
|
||||
|
||||
# Update usage record
|
||||
stmt = update(UsageRecord).where(
|
||||
and_(
|
||||
UsageRecord.id == usage_record_id,
|
||||
UsageRecord.tenant_id == tenant_id
|
||||
)
|
||||
).values(
|
||||
quantity=UsageRecord.quantity - quantity,
|
||||
total_cost=UsageRecord.total_cost - await self._calculate_cost(resource_type, quantity)
|
||||
)
|
||||
|
||||
result = self.db.execute(stmt)
|
||||
|
||||
if result.rowcount > 0:
|
||||
# Update quota usage
|
||||
await self._update_quota_usage(tenant_id, resource_type, -quantity)
|
||||
|
||||
# Update cache
|
||||
cache_key = f"quota_usage:{tenant_id}:{resource_type}"
|
||||
if self.redis:
|
||||
current = self.redis.get(cache_key)
|
||||
if current:
|
||||
self.redis.incrbyfloat(cache_key, -quantity)
|
||||
self.redis.expire(cache_key, self._cache_ttl)
|
||||
|
||||
self.db.commit()
|
||||
self.logger.info(
|
||||
f"Released quota: tenant={tenant_id}, "
|
||||
f"resource={resource_type}, quantity={quantity}"
|
||||
)
|
||||
|
||||
async def get_quota_status(
|
||||
self,
|
||||
resource_type: Optional[str] = None,
|
||||
tenant_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get current quota status for a tenant"""
|
||||
|
||||
tenant_id = tenant_id or get_current_tenant_id()
|
||||
if not tenant_id:
|
||||
raise TenantError("No tenant context found")
|
||||
|
||||
# Get all quotas for tenant
|
||||
stmt = select(TenantQuota).where(
|
||||
and_(
|
||||
TenantQuota.tenant_id == tenant_id,
|
||||
TenantQuota.is_active == True
|
||||
)
|
||||
)
|
||||
|
||||
if resource_type:
|
||||
stmt = stmt.where(TenantQuota.resource_type == resource_type)
|
||||
|
||||
quotas = self.db.execute(stmt).scalars().all()
|
||||
|
||||
status = {
|
||||
"tenant_id": tenant_id,
|
||||
"quotas": {},
|
||||
"summary": {
|
||||
"total_resources": len(quotas),
|
||||
"over_limit": 0,
|
||||
"near_limit": 0
|
||||
}
|
||||
}
|
||||
|
||||
for quota in quotas:
|
||||
current_usage = await self._get_current_usage(tenant_id, quota.resource_type)
|
||||
usage_percent = (current_usage / quota.limit_value) * 100 if quota.limit_value > 0 else 0
|
||||
|
||||
quota_status = {
|
||||
"limit": float(quota.limit_value),
|
||||
"used": float(current_usage),
|
||||
"remaining": float(quota.limit_value - current_usage),
|
||||
"usage_percent": round(usage_percent, 2),
|
||||
"period": quota.period_type,
|
||||
"period_start": quota.period_start.isoformat(),
|
||||
"period_end": quota.period_end.isoformat()
|
||||
}
|
||||
|
||||
status["quotas"][quota.resource_type] = quota_status
|
||||
|
||||
# Update summary
|
||||
if usage_percent >= 100:
|
||||
status["summary"]["over_limit"] += 1
|
||||
elif usage_percent >= 80:
|
||||
status["summary"]["near_limit"] += 1
|
||||
|
||||
return status
|
||||
|
||||
@asynccontextmanager
|
||||
async def quota_reservation(
|
||||
self,
|
||||
resource_type: str,
|
||||
quantity: float,
|
||||
timeout: int = 300, # 5 minutes
|
||||
tenant_id: Optional[str] = None
|
||||
):
|
||||
"""Context manager for temporary quota reservation"""
|
||||
|
||||
tenant_id = tenant_id or get_current_tenant_id()
|
||||
reservation_id = f"reserve:{tenant_id}:{resource_type}:{datetime.utcnow().timestamp()}"
|
||||
|
||||
try:
|
||||
# Reserve quota
|
||||
await self.check_quota(resource_type, quantity, tenant_id)
|
||||
|
||||
# Store reservation in Redis
|
||||
if self.redis:
|
||||
reservation_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"resource_type": resource_type,
|
||||
"quantity": quantity,
|
||||
"created_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
self.redis.setex(
|
||||
f"reservation:{reservation_id}",
|
||||
timeout,
|
||||
json.dumps(reservation_data)
|
||||
)
|
||||
|
||||
yield reservation_id
|
||||
|
||||
finally:
|
||||
# Clean up reservation
|
||||
if self.redis:
|
||||
self.redis.delete(f"reservation:{reservation_id}")
|
||||
|
||||
async def reset_quota_period(self, tenant_id: str, resource_type: str):
|
||||
"""Reset quota for a new period"""
|
||||
|
||||
# Get current quota
|
||||
stmt = select(TenantQuota).where(
|
||||
and_(
|
||||
TenantQuota.tenant_id == tenant_id,
|
||||
TenantQuota.resource_type == resource_type,
|
||||
TenantQuota.is_active == True
|
||||
)
|
||||
)
|
||||
|
||||
quota = self.db.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if not quota:
|
||||
return
|
||||
|
||||
# Calculate new period
|
||||
now = datetime.utcnow()
|
||||
if quota.period_type == "monthly":
|
||||
period_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
period_end = (period_start + timedelta(days=32)).replace(day=1) - timedelta(days=1)
|
||||
elif quota.period_type == "weekly":
|
||||
days_since_monday = now.weekday()
|
||||
period_start = (now - timedelta(days=days_since_monday)).replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
period_end = period_start + timedelta(days=6)
|
||||
else: # daily
|
||||
period_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
period_end = period_start + timedelta(days=1)
|
||||
|
||||
# Update quota
|
||||
quota.period_start = period_start
|
||||
quota.period_end = period_end
|
||||
quota.used_value = 0
|
||||
|
||||
self.db.commit()
|
||||
|
||||
# Clear cache
|
||||
cache_key = f"quota_usage:{tenant_id}:{resource_type}"
|
||||
if self.redis:
|
||||
self.redis.delete(cache_key)
|
||||
|
||||
self.logger.info(
|
||||
f"Reset quota period: tenant={tenant_id}, "
|
||||
f"resource={resource_type}, period={quota.period_type}"
|
||||
)
|
||||
|
||||
async def get_quota_alerts(self, tenant_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""Get quota alerts for tenants approaching or exceeding limits"""
|
||||
|
||||
tenant_id = tenant_id or get_current_tenant_id()
|
||||
if not tenant_id:
|
||||
raise TenantError("No tenant context found")
|
||||
|
||||
alerts = []
|
||||
status = await self.get_quota_status(tenant_id=tenant_id)
|
||||
|
||||
for resource_type, quota_status in status["quotas"].items():
|
||||
usage_percent = quota_status["usage_percent"]
|
||||
|
||||
if usage_percent >= 100:
|
||||
alerts.append({
|
||||
"severity": "critical",
|
||||
"resource_type": resource_type,
|
||||
"message": f"Quota exceeded for {resource_type}",
|
||||
"usage_percent": usage_percent,
|
||||
"used": quota_status["used"],
|
||||
"limit": quota_status["limit"]
|
||||
})
|
||||
elif usage_percent >= 90:
|
||||
alerts.append({
|
||||
"severity": "warning",
|
||||
"resource_type": resource_type,
|
||||
"message": f"Quota almost exceeded for {resource_type}",
|
||||
"usage_percent": usage_percent,
|
||||
"used": quota_status["used"],
|
||||
"limit": quota_status["limit"]
|
||||
})
|
||||
elif usage_percent >= 80:
|
||||
alerts.append({
|
||||
"severity": "info",
|
||||
"resource_type": resource_type,
|
||||
"message": f"Quota usage high for {resource_type}",
|
||||
"usage_percent": usage_percent,
|
||||
"used": quota_status["used"],
|
||||
"limit": quota_status["limit"]
|
||||
})
|
||||
|
||||
return alerts
|
||||
|
||||
# Private methods
|
||||
|
||||
async def _get_current_quota(self, tenant_id: str, resource_type: str) -> Optional[TenantQuota]:
|
||||
"""Get current quota for tenant and resource type"""
|
||||
|
||||
cache_key = f"quota:{tenant_id}:{resource_type}"
|
||||
|
||||
# Check cache first
|
||||
if self.redis:
|
||||
cached = self.redis.get(cache_key)
|
||||
if cached:
|
||||
quota_data = json.loads(cached)
|
||||
quota = TenantQuota(**quota_data)
|
||||
# Check if still valid
|
||||
if quota.period_end >= datetime.utcnow():
|
||||
return quota
|
||||
|
||||
# Query database
|
||||
stmt = select(TenantQuota).where(
|
||||
and_(
|
||||
TenantQuota.tenant_id == tenant_id,
|
||||
TenantQuota.resource_type == resource_type,
|
||||
TenantQuota.is_active == True,
|
||||
TenantQuota.period_start <= datetime.utcnow(),
|
||||
TenantQuota.period_end >= datetime.utcnow()
|
||||
)
|
||||
)
|
||||
|
||||
quota = self.db.execute(stmt).scalar_one_or_none()
|
||||
|
||||
# Cache result
|
||||
if quota and self.redis:
|
||||
quota_data = {
|
||||
"id": str(quota.id),
|
||||
"tenant_id": str(quota.tenant_id),
|
||||
"resource_type": quota.resource_type,
|
||||
"limit_value": float(quota.limit_value),
|
||||
"used_value": float(quota.used_value),
|
||||
"period_start": quota.period_start.isoformat(),
|
||||
"period_end": quota.period_end.isoformat()
|
||||
}
|
||||
self.redis.setex(
|
||||
cache_key,
|
||||
self._cache_ttl,
|
||||
json.dumps(quota_data)
|
||||
)
|
||||
|
||||
return quota
|
||||
|
||||
async def _get_current_usage(self, tenant_id: str, resource_type: str) -> float:
|
||||
"""Get current usage for tenant and resource type"""
|
||||
|
||||
cache_key = f"quota_usage:{tenant_id}:{resource_type}"
|
||||
|
||||
# Check cache first
|
||||
if self.redis:
|
||||
cached = self.redis.get(cache_key)
|
||||
if cached:
|
||||
return float(cached)
|
||||
|
||||
# Query database
|
||||
stmt = select(func.sum(UsageRecord.quantity)).where(
|
||||
and_(
|
||||
UsageRecord.tenant_id == tenant_id,
|
||||
UsageRecord.resource_type == resource_type,
|
||||
UsageRecord.usage_start >= func.date_trunc('month', func.current_date())
|
||||
)
|
||||
)
|
||||
|
||||
result = self.db.execute(stmt).scalar()
|
||||
usage = float(result) if result else 0.0
|
||||
|
||||
# Cache result
|
||||
if self.redis:
|
||||
self.redis.setex(cache_key, self._cache_ttl, str(usage))
|
||||
|
||||
return usage
|
||||
|
||||
async def _update_quota_usage(self, tenant_id: str, resource_type: str, quantity: float):
|
||||
"""Update quota usage in database"""
|
||||
|
||||
stmt = update(TenantQuota).where(
|
||||
and_(
|
||||
TenantQuota.tenant_id == tenant_id,
|
||||
TenantQuota.resource_type == resource_type,
|
||||
TenantQuota.is_active == True
|
||||
)
|
||||
).values(
|
||||
used_value=TenantQuota.used_value + quantity
|
||||
)
|
||||
|
||||
self.db.execute(stmt)
|
||||
|
||||
async def _get_tenant(self, tenant_id: str) -> Optional[Tenant]:
|
||||
"""Get tenant by ID"""
|
||||
stmt = select(Tenant).where(Tenant.id == tenant_id)
|
||||
return self.db.execute(stmt).scalar_one_or_none()
|
||||
|
||||
def _get_unit_for_resource(self, resource_type: str) -> str:
|
||||
"""Get unit for resource type"""
|
||||
unit_map = {
|
||||
"gpu_hours": "hours",
|
||||
"storage_gb": "gb",
|
||||
"api_calls": "calls",
|
||||
"bandwidth_gb": "gb",
|
||||
"compute_hours": "hours"
|
||||
}
|
||||
return unit_map.get(resource_type, "units")
|
||||
|
||||
async def _get_unit_price(self, resource_type: str) -> float:
|
||||
"""Get unit price for resource type"""
|
||||
# In a real implementation, this would come from a pricing table
|
||||
price_map = {
|
||||
"gpu_hours": 0.50, # $0.50 per hour
|
||||
"storage_gb": 0.02, # $0.02 per GB per month
|
||||
"api_calls": 0.0001, # $0.0001 per call
|
||||
"bandwidth_gb": 0.01, # $0.01 per GB
|
||||
"compute_hours": 0.30 # $0.30 per hour
|
||||
}
|
||||
return price_map.get(resource_type, 0.0)
|
||||
|
||||
async def _calculate_cost(self, resource_type: str, quantity: float) -> float:
|
||||
"""Calculate cost for resource usage"""
|
||||
unit_price = await self._get_unit_price(resource_type)
|
||||
return unit_price * quantity
|
||||
|
||||
|
||||
class QuotaMiddleware:
|
||||
"""Middleware to enforce quotas on API endpoints"""
|
||||
|
||||
def __init__(self, quota_service: QuotaEnforcementService):
|
||||
self.quota_service = quota_service
|
||||
self.logger = __import__('logging').getLogger(f"aitbc.{self.__class__.__name__}")
|
||||
|
||||
# Resource costs per endpoint
|
||||
self.endpoint_costs = {
|
||||
"/api/v1/jobs": {"resource": "compute_hours", "cost": 0.1},
|
||||
"/api/v1/models": {"resource": "storage_gb", "cost": 0.1},
|
||||
"/api/v1/data": {"resource": "storage_gb", "cost": 0.05},
|
||||
"/api/v1/analytics": {"resource": "api_calls", "cost": 1}
|
||||
}
|
||||
|
||||
async def check_endpoint_quota(self, endpoint: str, estimated_cost: float = 0):
|
||||
"""Check if endpoint call is within quota"""
|
||||
|
||||
resource_config = self.endpoint_costs.get(endpoint)
|
||||
if not resource_config:
|
||||
return # No quota check for this endpoint
|
||||
|
||||
try:
|
||||
await self.quota_service.check_quota(
|
||||
resource_config["resource"],
|
||||
resource_config["cost"] + estimated_cost
|
||||
)
|
||||
except QuotaExceededError as e:
|
||||
self.logger.warning(f"Quota exceeded for endpoint {endpoint}: {e}")
|
||||
raise
|
||||
|
||||
async def consume_endpoint_quota(self, endpoint: str, actual_cost: float = 0):
|
||||
"""Consume quota after endpoint execution"""
|
||||
|
||||
resource_config = self.endpoint_costs.get(endpoint)
|
||||
if not resource_config:
|
||||
return
|
||||
|
||||
try:
|
||||
await self.quota_service.consume_quota(
|
||||
resource_config["resource"],
|
||||
resource_config["cost"] + actual_cost
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to consume quota for {endpoint}: {e}")
|
||||
# Don't fail the request, just log the error
|
||||
@ -10,6 +10,7 @@ from sqlmodel import Session
|
||||
|
||||
from ..config import settings
|
||||
from ..domain import Job, JobReceipt
|
||||
from .zk_proofs import zk_proof_service
|
||||
|
||||
|
||||
class ReceiptService:
|
||||
@ -24,12 +25,13 @@ class ReceiptService:
|
||||
attest_bytes = bytes.fromhex(settings.receipt_attestation_key_hex)
|
||||
self._attestation_signer = ReceiptSigner(attest_bytes)
|
||||
|
||||
def create_receipt(
|
||||
async def create_receipt(
|
||||
self,
|
||||
job: Job,
|
||||
miner_id: str,
|
||||
job_result: Dict[str, Any] | None,
|
||||
result_metrics: Dict[str, Any] | None,
|
||||
privacy_level: Optional[str] = None,
|
||||
) -> Dict[str, Any] | None:
|
||||
if self._signer is None:
|
||||
return None
|
||||
@ -67,6 +69,32 @@ class ReceiptService:
|
||||
attestation_payload.pop("attestations", None)
|
||||
attestation_payload.pop("signature", None)
|
||||
payload["attestations"].append(self._attestation_signer.sign(attestation_payload))
|
||||
|
||||
# Generate ZK proof if privacy is requested
|
||||
if privacy_level and zk_proof_service.is_enabled():
|
||||
try:
|
||||
# Create receipt model for ZK proof generation
|
||||
receipt_model = JobReceipt(
|
||||
job_id=job.id,
|
||||
receipt_id=payload["receipt_id"],
|
||||
payload=payload
|
||||
)
|
||||
|
||||
# Generate ZK proof
|
||||
zk_proof = await zk_proof_service.generate_receipt_proof(
|
||||
receipt=receipt_model,
|
||||
job_result=job_result or {},
|
||||
privacy_level=privacy_level
|
||||
)
|
||||
|
||||
if zk_proof:
|
||||
payload["zk_proof"] = zk_proof
|
||||
payload["privacy_level"] = privacy_level
|
||||
|
||||
except Exception as e:
|
||||
# Log error but don't fail receipt creation
|
||||
print(f"Failed to generate ZK proof: {e}")
|
||||
|
||||
receipt_row = JobReceipt(job_id=job.id, receipt_id=payload["receipt_id"], payload=payload)
|
||||
self.session.add(receipt_row)
|
||||
return payload
|
||||
|
||||
690
apps/coordinator-api/src/app/services/tenant_management.py
Normal file
690
apps/coordinator-api/src/app/services/tenant_management.py
Normal file
@ -0,0 +1,690 @@
|
||||
"""
|
||||
Tenant management service for multi-tenant AITBC coordinator
|
||||
"""
|
||||
|
||||
import secrets
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any, List
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, update, delete, and_, or_, func
|
||||
|
||||
from ..models.multitenant import (
|
||||
Tenant, TenantUser, TenantQuota, TenantApiKey,
|
||||
TenantAuditLog, TenantStatus
|
||||
)
|
||||
from ..database import get_db
|
||||
from ..exceptions import TenantError, QuotaExceededError
|
||||
|
||||
|
||||
class TenantManagementService:
|
||||
"""Service for managing tenants in multi-tenant environment"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.logger = __import__('logging').getLogger(f"aitbc.{self.__class__.__name__}")
|
||||
|
||||
async def create_tenant(
|
||||
self,
|
||||
name: str,
|
||||
contact_email: str,
|
||||
plan: str = "trial",
|
||||
domain: Optional[str] = None,
|
||||
settings: Optional[Dict[str, Any]] = None,
|
||||
features: Optional[Dict[str, Any]] = None
|
||||
) -> Tenant:
|
||||
"""Create a new tenant"""
|
||||
|
||||
# Generate unique slug
|
||||
slug = self._generate_slug(name)
|
||||
if await self._tenant_exists(slug=slug):
|
||||
raise TenantError(f"Tenant with slug '{slug}' already exists")
|
||||
|
||||
# Check domain uniqueness if provided
|
||||
if domain and await self._tenant_exists(domain=domain):
|
||||
raise TenantError(f"Domain '{domain}' is already in use")
|
||||
|
||||
# Create tenant
|
||||
tenant = Tenant(
|
||||
name=name,
|
||||
slug=slug,
|
||||
domain=domain,
|
||||
contact_email=contact_email,
|
||||
plan=plan,
|
||||
status=TenantStatus.PENDING.value,
|
||||
settings=settings or {},
|
||||
features=features or {}
|
||||
)
|
||||
|
||||
self.db.add(tenant)
|
||||
self.db.flush()
|
||||
|
||||
# Create default quotas
|
||||
await self._create_default_quotas(tenant.id, plan)
|
||||
|
||||
# Log creation
|
||||
await self._log_audit_event(
|
||||
tenant_id=tenant.id,
|
||||
event_type="tenant_created",
|
||||
event_category="lifecycle",
|
||||
actor_id="system",
|
||||
actor_type="system",
|
||||
resource_type="tenant",
|
||||
resource_id=str(tenant.id),
|
||||
new_values={"name": name, "plan": plan}
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
self.logger.info(f"Created tenant: {tenant.id} ({name})")
|
||||
|
||||
return tenant
|
||||
|
||||
async def get_tenant(self, tenant_id: str) -> Optional[Tenant]:
|
||||
"""Get tenant by ID"""
|
||||
stmt = select(Tenant).where(Tenant.id == tenant_id)
|
||||
return self.db.execute(stmt).scalar_one_or_none()
|
||||
|
||||
async def get_tenant_by_slug(self, slug: str) -> Optional[Tenant]:
|
||||
"""Get tenant by slug"""
|
||||
stmt = select(Tenant).where(Tenant.slug == slug)
|
||||
return self.db.execute(stmt).scalar_one_or_none()
|
||||
|
||||
async def get_tenant_by_domain(self, domain: str) -> Optional[Tenant]:
|
||||
"""Get tenant by domain"""
|
||||
stmt = select(Tenant).where(Tenant.domain == domain)
|
||||
return self.db.execute(stmt).scalar_one_or_none()
|
||||
|
||||
async def update_tenant(
|
||||
self,
|
||||
tenant_id: str,
|
||||
updates: Dict[str, Any],
|
||||
actor_id: str,
|
||||
actor_type: str = "user"
|
||||
) -> Tenant:
|
||||
"""Update tenant information"""
|
||||
|
||||
tenant = await self.get_tenant(tenant_id)
|
||||
if not tenant:
|
||||
raise TenantError(f"Tenant not found: {tenant_id}")
|
||||
|
||||
# Store old values for audit
|
||||
old_values = {
|
||||
"name": tenant.name,
|
||||
"contact_email": tenant.contact_email,
|
||||
"billing_email": tenant.billing_email,
|
||||
"settings": tenant.settings,
|
||||
"features": tenant.features
|
||||
}
|
||||
|
||||
# Apply updates
|
||||
for key, value in updates.items():
|
||||
if hasattr(tenant, key):
|
||||
setattr(tenant, key, value)
|
||||
|
||||
tenant.updated_at = datetime.utcnow()
|
||||
|
||||
# Log update
|
||||
await self._log_audit_event(
|
||||
tenant_id=tenant.id,
|
||||
event_type="tenant_updated",
|
||||
event_category="lifecycle",
|
||||
actor_id=actor_id,
|
||||
actor_type=actor_type,
|
||||
resource_type="tenant",
|
||||
resource_id=str(tenant.id),
|
||||
old_values=old_values,
|
||||
new_values=updates
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
self.logger.info(f"Updated tenant: {tenant_id}")
|
||||
|
||||
return tenant
|
||||
|
||||
async def activate_tenant(
|
||||
self,
|
||||
tenant_id: str,
|
||||
actor_id: str,
|
||||
actor_type: str = "user"
|
||||
) -> Tenant:
|
||||
"""Activate a tenant"""
|
||||
|
||||
tenant = await self.get_tenant(tenant_id)
|
||||
if not tenant:
|
||||
raise TenantError(f"Tenant not found: {tenant_id}")
|
||||
|
||||
if tenant.status == TenantStatus.ACTIVE.value:
|
||||
return tenant
|
||||
|
||||
tenant.status = TenantStatus.ACTIVE.value
|
||||
tenant.activated_at = datetime.utcnow()
|
||||
tenant.updated_at = datetime.utcnow()
|
||||
|
||||
# Log activation
|
||||
await self._log_audit_event(
|
||||
tenant_id=tenant.id,
|
||||
event_type="tenant_activated",
|
||||
event_category="lifecycle",
|
||||
actor_id=actor_id,
|
||||
actor_type=actor_type,
|
||||
resource_type="tenant",
|
||||
resource_id=str(tenant.id),
|
||||
old_values={"status": "pending"},
|
||||
new_values={"status": "active"}
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
self.logger.info(f"Activated tenant: {tenant_id}")
|
||||
|
||||
return tenant
|
||||
|
||||
async def deactivate_tenant(
|
||||
self,
|
||||
tenant_id: str,
|
||||
reason: Optional[str] = None,
|
||||
actor_id: str = "system",
|
||||
actor_type: str = "system"
|
||||
) -> Tenant:
|
||||
"""Deactivate a tenant"""
|
||||
|
||||
tenant = await self.get_tenant(tenant_id)
|
||||
if not tenant:
|
||||
raise TenantError(f"Tenant not found: {tenant_id}")
|
||||
|
||||
if tenant.status == TenantStatus.INACTIVE.value:
|
||||
return tenant
|
||||
|
||||
old_status = tenant.status
|
||||
tenant.status = TenantStatus.INACTIVE.value
|
||||
tenant.deactivated_at = datetime.utcnow()
|
||||
tenant.updated_at = datetime.utcnow()
|
||||
|
||||
# Revoke all API keys
|
||||
await self._revoke_all_api_keys(tenant_id)
|
||||
|
||||
# Log deactivation
|
||||
await self._log_audit_event(
|
||||
tenant_id=tenant.id,
|
||||
event_type="tenant_deactivated",
|
||||
event_category="lifecycle",
|
||||
actor_id=actor_id,
|
||||
actor_type=actor_type,
|
||||
resource_type="tenant",
|
||||
resource_id=str(tenant.id),
|
||||
old_values={"status": old_status},
|
||||
new_values={"status": "inactive", "reason": reason}
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
self.logger.info(f"Deactivated tenant: {tenant_id} (reason: {reason})")
|
||||
|
||||
return tenant
|
||||
|
||||
async def suspend_tenant(
|
||||
self,
|
||||
tenant_id: str,
|
||||
reason: Optional[str] = None,
|
||||
actor_id: str = "system",
|
||||
actor_type: str = "system"
|
||||
) -> Tenant:
|
||||
"""Suspend a tenant temporarily"""
|
||||
|
||||
tenant = await self.get_tenant(tenant_id)
|
||||
if not tenant:
|
||||
raise TenantError(f"Tenant not found: {tenant_id}")
|
||||
|
||||
old_status = tenant.status
|
||||
tenant.status = TenantStatus.SUSPENDED.value
|
||||
tenant.updated_at = datetime.utcnow()
|
||||
|
||||
# Log suspension
|
||||
await self._log_audit_event(
|
||||
tenant_id=tenant.id,
|
||||
event_type="tenant_suspended",
|
||||
event_category="lifecycle",
|
||||
actor_id=actor_id,
|
||||
actor_type=actor_type,
|
||||
resource_type="tenant",
|
||||
resource_id=str(tenant.id),
|
||||
old_values={"status": old_status},
|
||||
new_values={"status": "suspended", "reason": reason}
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
self.logger.warning(f"Suspended tenant: {tenant_id} (reason: {reason})")
|
||||
|
||||
return tenant
|
||||
|
||||
async def add_user_to_tenant(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
role: str = "member",
|
||||
permissions: Optional[List[str]] = None,
|
||||
actor_id: str = "system"
|
||||
) -> TenantUser:
|
||||
"""Add a user to a tenant"""
|
||||
|
||||
# Check if user already exists
|
||||
stmt = select(TenantUser).where(
|
||||
and_(TenantUser.tenant_id == tenant_id, TenantUser.user_id == user_id)
|
||||
)
|
||||
existing = self.db.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
raise TenantError(f"User {user_id} already belongs to tenant {tenant_id}")
|
||||
|
||||
# Create tenant user
|
||||
tenant_user = TenantUser(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
role=role,
|
||||
permissions=permissions or [],
|
||||
joined_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
self.db.add(tenant_user)
|
||||
|
||||
# Log addition
|
||||
await self._log_audit_event(
|
||||
tenant_id=tenant_id,
|
||||
event_type="user_added",
|
||||
event_category="access",
|
||||
actor_id=actor_id,
|
||||
actor_type="system",
|
||||
resource_type="tenant_user",
|
||||
resource_id=str(tenant_user.id),
|
||||
new_values={"user_id": user_id, "role": role}
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
self.logger.info(f"Added user {user_id} to tenant {tenant_id}")
|
||||
|
||||
return tenant_user
|
||||
|
||||
async def remove_user_from_tenant(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
actor_id: str = "system"
|
||||
) -> bool:
|
||||
"""Remove a user from a tenant"""
|
||||
|
||||
stmt = select(TenantUser).where(
|
||||
and_(TenantUser.tenant_id == tenant_id, TenantUser.user_id == user_id)
|
||||
)
|
||||
tenant_user = self.db.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if not tenant_user:
|
||||
return False
|
||||
|
||||
# Store for audit
|
||||
old_values = {
|
||||
"user_id": user_id,
|
||||
"role": tenant_user.role,
|
||||
"permissions": tenant_user.permissions
|
||||
}
|
||||
|
||||
self.db.delete(tenant_user)
|
||||
|
||||
# Log removal
|
||||
await self._log_audit_event(
|
||||
tenant_id=tenant_id,
|
||||
event_type="user_removed",
|
||||
event_category="access",
|
||||
actor_id=actor_id,
|
||||
actor_type="system",
|
||||
resource_type="tenant_user",
|
||||
resource_id=str(tenant_user.id),
|
||||
old_values=old_values
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
self.logger.info(f"Removed user {user_id} from tenant {tenant_id}")
|
||||
|
||||
return True
|
||||
|
||||
async def create_api_key(
|
||||
self,
|
||||
tenant_id: str,
|
||||
name: str,
|
||||
permissions: Optional[List[str]] = None,
|
||||
rate_limit: Optional[int] = None,
|
||||
allowed_ips: Optional[List[str]] = None,
|
||||
expires_at: Optional[datetime] = None,
|
||||
created_by: str = "system"
|
||||
) -> TenantApiKey:
|
||||
"""Create a new API key for a tenant"""
|
||||
|
||||
# Generate secure key
|
||||
key_id = f"ak_{secrets.token_urlsafe(16)}"
|
||||
api_key = f"ask_{secrets.token_urlsafe(32)}"
|
||||
key_hash = hashlib.sha256(api_key.encode()).hexdigest()
|
||||
key_prefix = api_key[:8]
|
||||
|
||||
# Create API key record
|
||||
api_key_record = TenantApiKey(
|
||||
tenant_id=tenant_id,
|
||||
key_id=key_id,
|
||||
key_hash=key_hash,
|
||||
key_prefix=key_prefix,
|
||||
name=name,
|
||||
permissions=permissions or [],
|
||||
rate_limit=rate_limit,
|
||||
allowed_ips=allowed_ips,
|
||||
expires_at=expires_at,
|
||||
created_by=created_by
|
||||
)
|
||||
|
||||
self.db.add(api_key_record)
|
||||
self.db.flush()
|
||||
|
||||
# Log creation
|
||||
await self._log_audit_event(
|
||||
tenant_id=tenant_id,
|
||||
event_type="api_key_created",
|
||||
event_category="security",
|
||||
actor_id=created_by,
|
||||
actor_type="user",
|
||||
resource_type="api_key",
|
||||
resource_id=str(api_key_record.id),
|
||||
new_values={
|
||||
"key_id": key_id,
|
||||
"name": name,
|
||||
"permissions": permissions,
|
||||
"rate_limit": rate_limit
|
||||
}
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
self.logger.info(f"Created API key {key_id} for tenant {tenant_id}")
|
||||
|
||||
# Return the key (only time it's shown)
|
||||
api_key_record.api_key = api_key
|
||||
return api_key_record
|
||||
|
||||
async def revoke_api_key(
|
||||
self,
|
||||
tenant_id: str,
|
||||
key_id: str,
|
||||
actor_id: str = "system"
|
||||
) -> bool:
|
||||
"""Revoke an API key"""
|
||||
|
||||
stmt = select(TenantApiKey).where(
|
||||
and_(
|
||||
TenantApiKey.tenant_id == tenant_id,
|
||||
TenantApiKey.key_id == key_id,
|
||||
TenantApiKey.is_active == True
|
||||
)
|
||||
)
|
||||
api_key = self.db.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if not api_key:
|
||||
return False
|
||||
|
||||
api_key.is_active = False
|
||||
api_key.revoked_at = datetime.utcnow()
|
||||
|
||||
# Log revocation
|
||||
await self._log_audit_event(
|
||||
tenant_id=tenant_id,
|
||||
event_type="api_key_revoked",
|
||||
event_category="security",
|
||||
actor_id=actor_id,
|
||||
actor_type="user",
|
||||
resource_type="api_key",
|
||||
resource_id=str(api_key.id),
|
||||
old_values={"key_id": key_id, "is_active": True}
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
self.logger.info(f"Revoked API key {key_id} for tenant {tenant_id}")
|
||||
|
||||
return True
|
||||
|
||||
async def get_tenant_usage(
|
||||
self,
|
||||
tenant_id: str,
|
||||
resource_type: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get usage statistics for a tenant"""
|
||||
|
||||
from ..models.multitenant import UsageRecord
|
||||
|
||||
# Default to last 30 days
|
||||
if not end_date:
|
||||
end_date = datetime.utcnow()
|
||||
if not start_date:
|
||||
start_date = end_date - timedelta(days=30)
|
||||
|
||||
# Build query
|
||||
stmt = select(
|
||||
UsageRecord.resource_type,
|
||||
func.sum(UsageRecord.quantity).label("total_quantity"),
|
||||
func.sum(UsageRecord.total_cost).label("total_cost"),
|
||||
func.count(UsageRecord.id).label("record_count")
|
||||
).where(
|
||||
and_(
|
||||
UsageRecord.tenant_id == tenant_id,
|
||||
UsageRecord.usage_start >= start_date,
|
||||
UsageRecord.usage_end <= end_date
|
||||
)
|
||||
)
|
||||
|
||||
if resource_type:
|
||||
stmt = stmt.where(UsageRecord.resource_type == resource_type)
|
||||
|
||||
stmt = stmt.group_by(UsageRecord.resource_type)
|
||||
|
||||
results = self.db.execute(stmt).all()
|
||||
|
||||
# Format results
|
||||
usage = {
|
||||
"period": {
|
||||
"start": start_date.isoformat(),
|
||||
"end": end_date.isoformat()
|
||||
},
|
||||
"by_resource": {}
|
||||
}
|
||||
|
||||
for result in results:
|
||||
usage["by_resource"][result.resource_type] = {
|
||||
"quantity": float(result.total_quantity),
|
||||
"cost": float(result.total_cost),
|
||||
"records": result.record_count
|
||||
}
|
||||
|
||||
return usage
|
||||
|
||||
async def get_tenant_quotas(self, tenant_id: str) -> List[TenantQuota]:
|
||||
"""Get all quotas for a tenant"""
|
||||
|
||||
stmt = select(TenantQuota).where(
|
||||
and_(
|
||||
TenantQuota.tenant_id == tenant_id,
|
||||
TenantQuota.is_active == True
|
||||
)
|
||||
)
|
||||
|
||||
return self.db.execute(stmt).scalars().all()
|
||||
|
||||
async def check_quota(
|
||||
self,
|
||||
tenant_id: str,
|
||||
resource_type: str,
|
||||
quantity: float
|
||||
) -> bool:
|
||||
"""Check if tenant has sufficient quota for a resource"""
|
||||
|
||||
# Get current quota
|
||||
stmt = select(TenantQuota).where(
|
||||
and_(
|
||||
TenantQuota.tenant_id == tenant_id,
|
||||
TenantQuota.resource_type == resource_type,
|
||||
TenantQuota.is_active == True,
|
||||
TenantQuota.period_start <= datetime.utcnow(),
|
||||
TenantQuota.period_end >= datetime.utcnow()
|
||||
)
|
||||
)
|
||||
|
||||
quota = self.db.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if not quota:
|
||||
# No quota set, deny by default
|
||||
return False
|
||||
|
||||
# Check if usage + quantity exceeds limit
|
||||
if quota.used_value + quantity > quota.limit_value:
|
||||
raise QuotaExceededError(
|
||||
f"Quota exceeded for {resource_type}: "
|
||||
f"{quota.used_value + quantity}/{quota.limit_value}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
async def update_quota_usage(
|
||||
self,
|
||||
tenant_id: str,
|
||||
resource_type: str,
|
||||
quantity: float
|
||||
):
|
||||
"""Update quota usage for a tenant"""
|
||||
|
||||
# Get current quota
|
||||
stmt = select(TenantQuota).where(
|
||||
and_(
|
||||
TenantQuota.tenant_id == tenant_id,
|
||||
TenantQuota.resource_type == resource_type,
|
||||
TenantQuota.is_active == True,
|
||||
TenantQuota.period_start <= datetime.utcnow(),
|
||||
TenantQuota.period_end >= datetime.utcnow()
|
||||
)
|
||||
)
|
||||
|
||||
quota = self.db.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if quota:
|
||||
quota.used_value += quantity
|
||||
self.db.commit()
|
||||
|
||||
# Private methods
|
||||
|
||||
def _generate_slug(self, name: str) -> str:
|
||||
"""Generate a unique slug from name"""
|
||||
import re
|
||||
# Convert to lowercase and replace spaces with hyphens
|
||||
base = re.sub(r'[^a-z0-9]+', '-', name.lower()).strip('-')
|
||||
# Add random suffix for uniqueness
|
||||
suffix = secrets.token_urlsafe(4)
|
||||
return f"{base}-{suffix}"
|
||||
|
||||
async def _tenant_exists(self, slug: Optional[str] = None, domain: Optional[str] = None) -> bool:
|
||||
"""Check if tenant exists by slug or domain"""
|
||||
|
||||
conditions = []
|
||||
if slug:
|
||||
conditions.append(Tenant.slug == slug)
|
||||
if domain:
|
||||
conditions.append(Tenant.domain == domain)
|
||||
|
||||
if not conditions:
|
||||
return False
|
||||
|
||||
stmt = select(func.count(Tenant.id)).where(or_(*conditions))
|
||||
count = self.db.execute(stmt).scalar()
|
||||
|
||||
return count > 0
|
||||
|
||||
async def _create_default_quotas(self, tenant_id: str, plan: str):
|
||||
"""Create default quotas based on plan"""
|
||||
|
||||
# Define quota templates by plan
|
||||
quota_templates = {
|
||||
"trial": {
|
||||
"gpu_hours": {"limit": 100, "period": "monthly"},
|
||||
"storage_gb": {"limit": 10, "period": "monthly"},
|
||||
"api_calls": {"limit": 10000, "period": "monthly"}
|
||||
},
|
||||
"basic": {
|
||||
"gpu_hours": {"limit": 500, "period": "monthly"},
|
||||
"storage_gb": {"limit": 100, "period": "monthly"},
|
||||
"api_calls": {"limit": 100000, "period": "monthly"}
|
||||
},
|
||||
"pro": {
|
||||
"gpu_hours": {"limit": 2000, "period": "monthly"},
|
||||
"storage_gb": {"limit": 1000, "period": "monthly"},
|
||||
"api_calls": {"limit": 1000000, "period": "monthly"}
|
||||
},
|
||||
"enterprise": {
|
||||
"gpu_hours": {"limit": 10000, "period": "monthly"},
|
||||
"storage_gb": {"limit": 10000, "period": "monthly"},
|
||||
"api_calls": {"limit": 10000000, "period": "monthly"}
|
||||
}
|
||||
}
|
||||
|
||||
quotas = quota_templates.get(plan, quota_templates["trial"])
|
||||
|
||||
# Create quota records
|
||||
now = datetime.utcnow()
|
||||
period_end = now.replace(day=1) + timedelta(days=32) # Next month
|
||||
period_end = period_end.replace(day=1) - timedelta(days=1) # Last day of current month
|
||||
|
||||
for resource_type, config in quotas.items():
|
||||
quota = TenantQuota(
|
||||
tenant_id=tenant_id,
|
||||
resource_type=resource_type,
|
||||
limit_value=config["limit"],
|
||||
used_value=0,
|
||||
period_type=config["period"],
|
||||
period_start=now,
|
||||
period_end=period_end
|
||||
)
|
||||
self.db.add(quota)
|
||||
|
||||
async def _revoke_all_api_keys(self, tenant_id: str):
|
||||
"""Revoke all API keys for a tenant"""
|
||||
|
||||
stmt = update(TenantApiKey).where(
|
||||
and_(
|
||||
TenantApiKey.tenant_id == tenant_id,
|
||||
TenantApiKey.is_active == True
|
||||
)
|
||||
).values(
|
||||
is_active=False,
|
||||
revoked_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
self.db.execute(stmt)
|
||||
|
||||
async def _log_audit_event(
|
||||
self,
|
||||
tenant_id: str,
|
||||
event_type: str,
|
||||
event_category: str,
|
||||
actor_id: str,
|
||||
actor_type: str,
|
||||
resource_type: str,
|
||||
resource_id: Optional[str] = None,
|
||||
old_values: Optional[Dict[str, Any]] = None,
|
||||
new_values: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""Log an audit event"""
|
||||
|
||||
audit_log = TenantAuditLog(
|
||||
tenant_id=tenant_id,
|
||||
event_type=event_type,
|
||||
event_category=event_category,
|
||||
actor_id=actor_id,
|
||||
actor_type=actor_type,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
old_values=old_values,
|
||||
new_values=new_values,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
self.db.add(audit_log)
|
||||
654
apps/coordinator-api/src/app/services/usage_tracking.py
Normal file
654
apps/coordinator-api/src/app/services/usage_tracking.py
Normal file
@ -0,0 +1,654 @@
|
||||
"""
|
||||
Usage tracking and billing metrics service for multi-tenant AITBC coordinator
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, Optional, List, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, update, and_, or_, func, desc
|
||||
from dataclasses import dataclass, asdict
|
||||
from decimal import Decimal
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from ..models.multitenant import (
|
||||
UsageRecord, Invoice, Tenant, TenantQuota,
|
||||
TenantMetric
|
||||
)
|
||||
from ..exceptions import BillingError, TenantError
|
||||
from ..middleware.tenant_context import get_current_tenant_id
|
||||
|
||||
|
||||
@dataclass
|
||||
class UsageSummary:
|
||||
"""Usage summary for billing period"""
|
||||
tenant_id: str
|
||||
period_start: datetime
|
||||
period_end: datetime
|
||||
resources: Dict[str, Dict[str, Any]]
|
||||
total_cost: Decimal
|
||||
currency: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class BillingEvent:
|
||||
"""Billing event for processing"""
|
||||
tenant_id: str
|
||||
event_type: str # usage, quota_adjustment, credit, charge
|
||||
resource_type: Optional[str]
|
||||
quantity: Decimal
|
||||
unit_price: Decimal
|
||||
total_amount: Decimal
|
||||
currency: str
|
||||
timestamp: datetime
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class UsageTrackingService:
|
||||
"""Service for tracking usage and generating billing metrics"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.logger = __import__('logging').getLogger(f"aitbc.{self.__class__.__name__}")
|
||||
self.executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
# Pricing configuration
|
||||
self.pricing_config = {
|
||||
"gpu_hours": {"unit_price": Decimal("0.50"), "tiered": True},
|
||||
"storage_gb": {"unit_price": Decimal("0.02"), "tiered": True},
|
||||
"api_calls": {"unit_price": Decimal("0.0001"), "tiered": False},
|
||||
"bandwidth_gb": {"unit_price": Decimal("0.01"), "tiered": False},
|
||||
"compute_hours": {"unit_price": Decimal("0.30"), "tiered": True}
|
||||
}
|
||||
|
||||
# Tier pricing thresholds
|
||||
self.tier_thresholds = {
|
||||
"gpu_hours": [
|
||||
{"min": 0, "max": 100, "multiplier": 1.0},
|
||||
{"min": 101, "max": 500, "multiplier": 0.9},
|
||||
{"min": 501, "max": 2000, "multiplier": 0.8},
|
||||
{"min": 2001, "max": None, "multiplier": 0.7}
|
||||
],
|
||||
"storage_gb": [
|
||||
{"min": 0, "max": 100, "multiplier": 1.0},
|
||||
{"min": 101, "max": 1000, "multiplier": 0.85},
|
||||
{"min": 1001, "max": 10000, "multiplier": 0.75},
|
||||
{"min": 10001, "max": None, "multiplier": 0.65}
|
||||
],
|
||||
"compute_hours": [
|
||||
{"min": 0, "max": 200, "multiplier": 1.0},
|
||||
{"min": 201, "max": 1000, "multiplier": 0.9},
|
||||
{"min": 1001, "max": 5000, "multiplier": 0.8},
|
||||
{"min": 5001, "max": None, "multiplier": 0.7}
|
||||
]
|
||||
}
|
||||
|
||||
async def record_usage(
|
||||
self,
|
||||
tenant_id: str,
|
||||
resource_type: str,
|
||||
quantity: Decimal,
|
||||
unit_price: Optional[Decimal] = None,
|
||||
job_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> UsageRecord:
|
||||
"""Record usage for billing"""
|
||||
|
||||
# Calculate unit price if not provided
|
||||
if not unit_price:
|
||||
unit_price = await self._calculate_unit_price(resource_type, quantity)
|
||||
|
||||
# Calculate total cost
|
||||
total_cost = unit_price * quantity
|
||||
|
||||
# Create usage record
|
||||
usage_record = UsageRecord(
|
||||
tenant_id=tenant_id,
|
||||
resource_type=resource_type,
|
||||
quantity=quantity,
|
||||
unit=self._get_unit_for_resource(resource_type),
|
||||
unit_price=unit_price,
|
||||
total_cost=total_cost,
|
||||
currency="USD",
|
||||
usage_start=datetime.utcnow(),
|
||||
usage_end=datetime.utcnow(),
|
||||
job_id=job_id,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
self.db.add(usage_record)
|
||||
self.db.commit()
|
||||
|
||||
# Emit billing event
|
||||
await self._emit_billing_event(BillingEvent(
|
||||
tenant_id=tenant_id,
|
||||
event_type="usage",
|
||||
resource_type=resource_type,
|
||||
quantity=quantity,
|
||||
unit_price=unit_price,
|
||||
total_amount=total_cost,
|
||||
currency="USD",
|
||||
timestamp=datetime.utcnow(),
|
||||
metadata=metadata or {}
|
||||
))
|
||||
|
||||
self.logger.info(
|
||||
f"Recorded usage: tenant={tenant_id}, "
|
||||
f"resource={resource_type}, quantity={quantity}, cost={total_cost}"
|
||||
)
|
||||
|
||||
return usage_record
|
||||
|
||||
async def get_usage_summary(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
resource_type: Optional[str] = None
|
||||
) -> UsageSummary:
|
||||
"""Get usage summary for a billing period"""
|
||||
|
||||
# Build query
|
||||
stmt = select(
|
||||
UsageRecord.resource_type,
|
||||
func.sum(UsageRecord.quantity).label("total_quantity"),
|
||||
func.sum(UsageRecord.total_cost).label("total_cost"),
|
||||
func.count(UsageRecord.id).label("record_count"),
|
||||
func.avg(UsageRecord.unit_price).label("avg_unit_price")
|
||||
).where(
|
||||
and_(
|
||||
UsageRecord.tenant_id == tenant_id,
|
||||
UsageRecord.usage_start >= start_date,
|
||||
UsageRecord.usage_end <= end_date
|
||||
)
|
||||
)
|
||||
|
||||
if resource_type:
|
||||
stmt = stmt.where(UsageRecord.resource_type == resource_type)
|
||||
|
||||
stmt = stmt.group_by(UsageRecord.resource_type)
|
||||
|
||||
results = self.db.execute(stmt).all()
|
||||
|
||||
# Build summary
|
||||
resources = {}
|
||||
total_cost = Decimal("0")
|
||||
|
||||
for result in results:
|
||||
resources[result.resource_type] = {
|
||||
"quantity": float(result.total_quantity),
|
||||
"cost": float(result.total_cost),
|
||||
"records": result.record_count,
|
||||
"avg_unit_price": float(result.avg_unit_price)
|
||||
}
|
||||
total_cost += Decimal(str(result.total_cost))
|
||||
|
||||
return UsageSummary(
|
||||
tenant_id=tenant_id,
|
||||
period_start=start_date,
|
||||
period_end=end_date,
|
||||
resources=resources,
|
||||
total_cost=total_cost,
|
||||
currency="USD"
|
||||
)
|
||||
|
||||
async def generate_invoice(
|
||||
self,
|
||||
tenant_id: str,
|
||||
period_start: datetime,
|
||||
period_end: datetime,
|
||||
due_days: int = 30
|
||||
) -> Invoice:
|
||||
"""Generate invoice for billing period"""
|
||||
|
||||
# Check if invoice already exists
|
||||
existing = await self._get_existing_invoice(tenant_id, period_start, period_end)
|
||||
if existing:
|
||||
raise BillingError(f"Invoice already exists for period {period_start} to {period_end}")
|
||||
|
||||
# Get usage summary
|
||||
summary = await self.get_usage_summary(tenant_id, period_start, period_end)
|
||||
|
||||
# Generate invoice number
|
||||
invoice_number = await self._generate_invoice_number(tenant_id)
|
||||
|
||||
# Calculate line items
|
||||
line_items = []
|
||||
subtotal = Decimal("0")
|
||||
|
||||
for resource_type, usage in summary.resources.items():
|
||||
line_item = {
|
||||
"description": f"{resource_type.replace('_', ' ').title()} Usage",
|
||||
"quantity": usage["quantity"],
|
||||
"unit_price": usage["avg_unit_price"],
|
||||
"amount": usage["cost"]
|
||||
}
|
||||
line_items.append(line_item)
|
||||
subtotal += Decimal(str(usage["cost"]))
|
||||
|
||||
# Calculate tax (example: 10% for digital services)
|
||||
tax_rate = Decimal("0.10")
|
||||
tax_amount = subtotal * tax_rate
|
||||
total_amount = subtotal + tax_amount
|
||||
|
||||
# Create invoice
|
||||
invoice = Invoice(
|
||||
tenant_id=tenant_id,
|
||||
invoice_number=invoice_number,
|
||||
status="draft",
|
||||
period_start=period_start,
|
||||
period_end=period_end,
|
||||
due_date=period_end + timedelta(days=due_days),
|
||||
subtotal=subtotal,
|
||||
tax_amount=tax_amount,
|
||||
total_amount=total_amount,
|
||||
currency="USD",
|
||||
line_items=line_items
|
||||
)
|
||||
|
||||
self.db.add(invoice)
|
||||
self.db.commit()
|
||||
|
||||
self.logger.info(
|
||||
f"Generated invoice {invoice_number} for tenant {tenant_id}: "
|
||||
f"${total_amount}"
|
||||
)
|
||||
|
||||
return invoice
|
||||
|
||||
async def get_billing_metrics(
|
||||
self,
|
||||
tenant_id: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get billing metrics and analytics"""
|
||||
|
||||
# Default to last 30 days
|
||||
if not end_date:
|
||||
end_date = datetime.utcnow()
|
||||
if not start_date:
|
||||
start_date = end_date - timedelta(days=30)
|
||||
|
||||
# Build base query
|
||||
base_conditions = [
|
||||
UsageRecord.usage_start >= start_date,
|
||||
UsageRecord.usage_end <= end_date
|
||||
]
|
||||
|
||||
if tenant_id:
|
||||
base_conditions.append(UsageRecord.tenant_id == tenant_id)
|
||||
|
||||
# Total usage and cost
|
||||
stmt = select(
|
||||
func.sum(UsageRecord.quantity).label("total_quantity"),
|
||||
func.sum(UsageRecord.total_cost).label("total_cost"),
|
||||
func.count(UsageRecord.id).label("total_records"),
|
||||
func.count(func.distinct(UsageRecord.tenant_id)).label("active_tenants")
|
||||
).where(and_(*base_conditions))
|
||||
|
||||
totals = self.db.execute(stmt).first()
|
||||
|
||||
# Usage by resource type
|
||||
stmt = select(
|
||||
UsageRecord.resource_type,
|
||||
func.sum(UsageRecord.quantity).label("quantity"),
|
||||
func.sum(UsageRecord.total_cost).label("cost")
|
||||
).where(and_(*base_conditions)).group_by(UsageRecord.resource_type)
|
||||
|
||||
by_resource = self.db.execute(stmt).all()
|
||||
|
||||
# Top tenants by usage
|
||||
if not tenant_id:
|
||||
stmt = select(
|
||||
UsageRecord.tenant_id,
|
||||
func.sum(UsageRecord.total_cost).label("total_cost")
|
||||
).where(and_(*base_conditions)).group_by(
|
||||
UsageRecord.tenant_id
|
||||
).order_by(desc("total_cost")).limit(10)
|
||||
|
||||
top_tenants = self.db.execute(stmt).all()
|
||||
else:
|
||||
top_tenants = []
|
||||
|
||||
# Daily usage trend
|
||||
stmt = select(
|
||||
func.date(UsageRecord.usage_start).label("date"),
|
||||
func.sum(UsageRecord.total_cost).label("daily_cost")
|
||||
).where(and_(*base_conditions)).group_by(
|
||||
func.date(UsageRecord.usage_start)
|
||||
).order_by("date")
|
||||
|
||||
daily_trend = self.db.execute(stmt).all()
|
||||
|
||||
# Assemble metrics
|
||||
metrics = {
|
||||
"period": {
|
||||
"start": start_date.isoformat(),
|
||||
"end": end_date.isoformat()
|
||||
},
|
||||
"totals": {
|
||||
"quantity": float(totals.total_quantity or 0),
|
||||
"cost": float(totals.total_cost or 0),
|
||||
"records": totals.total_records or 0,
|
||||
"active_tenants": totals.active_tenants or 0
|
||||
},
|
||||
"by_resource": {
|
||||
r.resource_type: {
|
||||
"quantity": float(r.quantity),
|
||||
"cost": float(r.cost)
|
||||
}
|
||||
for r in by_resource
|
||||
},
|
||||
"top_tenants": [
|
||||
{
|
||||
"tenant_id": str(t.tenant_id),
|
||||
"cost": float(t.total_cost)
|
||||
}
|
||||
for t in top_tenants
|
||||
],
|
||||
"daily_trend": [
|
||||
{
|
||||
"date": d.date.isoformat(),
|
||||
"cost": float(d.daily_cost)
|
||||
}
|
||||
for d in daily_trend
|
||||
]
|
||||
}
|
||||
|
||||
return metrics
|
||||
|
||||
async def process_billing_events(self, events: List[BillingEvent]) -> bool:
|
||||
"""Process batch of billing events"""
|
||||
|
||||
try:
|
||||
for event in events:
|
||||
if event.event_type == "usage":
|
||||
# Already recorded in record_usage
|
||||
continue
|
||||
elif event.event_type == "credit":
|
||||
await self._apply_credit(event)
|
||||
elif event.event_type == "charge":
|
||||
await self._apply_charge(event)
|
||||
elif event.event_type == "quota_adjustment":
|
||||
await self._adjust_quota(event)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to process billing events: {e}")
|
||||
return False
|
||||
|
||||
async def export_usage_data(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
format: str = "csv"
|
||||
) -> str:
|
||||
"""Export usage data in specified format"""
|
||||
|
||||
# Get usage records
|
||||
stmt = select(UsageRecord).where(
|
||||
and_(
|
||||
UsageRecord.tenant_id == tenant_id,
|
||||
UsageRecord.usage_start >= start_date,
|
||||
UsageRecord.usage_end <= end_date
|
||||
)
|
||||
).order_by(UsageRecord.usage_start)
|
||||
|
||||
records = self.db.execute(stmt).scalars().all()
|
||||
|
||||
if format == "csv":
|
||||
return await self._export_csv(records)
|
||||
elif format == "json":
|
||||
return await self._export_json(records)
|
||||
else:
|
||||
raise BillingError(f"Unsupported export format: {format}")
|
||||
|
||||
# Private methods
|
||||
|
||||
async def _calculate_unit_price(
|
||||
self,
|
||||
resource_type: str,
|
||||
quantity: Decimal
|
||||
) -> Decimal:
|
||||
"""Calculate unit price with tiered pricing"""
|
||||
|
||||
config = self.pricing_config.get(resource_type)
|
||||
if not config:
|
||||
return Decimal("0")
|
||||
|
||||
base_price = config["unit_price"]
|
||||
|
||||
if not config.get("tiered", False):
|
||||
return base_price
|
||||
|
||||
# Find applicable tier
|
||||
tiers = self.tier_thresholds.get(resource_type, [])
|
||||
quantity_float = float(quantity)
|
||||
|
||||
for tier in tiers:
|
||||
if (tier["min"] is None or quantity_float >= tier["min"]) and \
|
||||
(tier["max"] is None or quantity_float <= tier["max"]):
|
||||
return base_price * Decimal(str(tier["multiplier"]))
|
||||
|
||||
# Default to highest tier
|
||||
return base_price * Decimal("0.5")
|
||||
|
||||
def _get_unit_for_resource(self, resource_type: str) -> str:
|
||||
"""Get unit for resource type"""
|
||||
unit_map = {
|
||||
"gpu_hours": "hours",
|
||||
"storage_gb": "gb",
|
||||
"api_calls": "calls",
|
||||
"bandwidth_gb": "gb",
|
||||
"compute_hours": "hours"
|
||||
}
|
||||
return unit_map.get(resource_type, "units")
|
||||
|
||||
async def _emit_billing_event(self, event: BillingEvent):
|
||||
"""Emit billing event for processing"""
|
||||
# In a real implementation, this would publish to a message queue
|
||||
# For now, we'll just log it
|
||||
self.logger.debug(f"Emitting billing event: {event}")
|
||||
|
||||
async def _get_existing_invoice(
|
||||
self,
|
||||
tenant_id: str,
|
||||
period_start: datetime,
|
||||
period_end: datetime
|
||||
) -> Optional[Invoice]:
|
||||
"""Check if invoice already exists for period"""
|
||||
|
||||
stmt = select(Invoice).where(
|
||||
and_(
|
||||
Invoice.tenant_id == tenant_id,
|
||||
Invoice.period_start == period_start,
|
||||
Invoice.period_end == period_end
|
||||
)
|
||||
)
|
||||
|
||||
return self.db.execute(stmt).scalar_one_or_none()
|
||||
|
||||
async def _generate_invoice_number(self, tenant_id: str) -> str:
|
||||
"""Generate unique invoice number"""
|
||||
|
||||
# Get tenant info
|
||||
stmt = select(Tenant).where(Tenant.id == tenant_id)
|
||||
tenant = self.db.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if not tenant:
|
||||
raise TenantError(f"Tenant not found: {tenant_id}")
|
||||
|
||||
# Generate number: INV-{tenant.slug}-{YYYYMMDD}-{seq}
|
||||
date_str = datetime.utcnow().strftime("%Y%m%d")
|
||||
|
||||
# Get sequence for today
|
||||
seq_key = f"invoice_seq:{tenant_id}:{date_str}"
|
||||
# In a real implementation, use Redis or sequence table
|
||||
# For now, use a simple counter
|
||||
stmt = select(func.count(Invoice.id)).where(
|
||||
and_(
|
||||
Invoice.tenant_id == tenant_id,
|
||||
func.date(Invoice.created_at) == func.current_date()
|
||||
)
|
||||
)
|
||||
seq = self.db.execute(stmt).scalar() + 1
|
||||
|
||||
return f"INV-{tenant.slug}-{date_str}-{seq:04d}"
|
||||
|
||||
async def _apply_credit(self, event: BillingEvent):
|
||||
"""Apply credit to tenant account"""
|
||||
# TODO: Implement credit application
|
||||
pass
|
||||
|
||||
async def _apply_charge(self, event: BillingEvent):
|
||||
"""Apply charge to tenant account"""
|
||||
# TODO: Implement charge application
|
||||
pass
|
||||
|
||||
async def _adjust_quota(self, event: BillingEvent):
|
||||
"""Adjust quota based on billing event"""
|
||||
# TODO: Implement quota adjustment
|
||||
pass
|
||||
|
||||
async def _export_csv(self, records: List[UsageRecord]) -> str:
|
||||
"""Export records to CSV"""
|
||||
import csv
|
||||
import io
|
||||
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
|
||||
# Header
|
||||
writer.writerow([
|
||||
"Timestamp", "Resource Type", "Quantity", "Unit",
|
||||
"Unit Price", "Total Cost", "Currency", "Job ID"
|
||||
])
|
||||
|
||||
# Data rows
|
||||
for record in records:
|
||||
writer.writerow([
|
||||
record.usage_start.isoformat(),
|
||||
record.resource_type,
|
||||
record.quantity,
|
||||
record.unit,
|
||||
record.unit_price,
|
||||
record.total_cost,
|
||||
record.currency,
|
||||
record.job_id or ""
|
||||
])
|
||||
|
||||
return output.getvalue()
|
||||
|
||||
async def _export_json(self, records: List[UsageRecord]) -> str:
|
||||
"""Export records to JSON"""
|
||||
import json
|
||||
|
||||
data = []
|
||||
for record in records:
|
||||
data.append({
|
||||
"timestamp": record.usage_start.isoformat(),
|
||||
"resource_type": record.resource_type,
|
||||
"quantity": float(record.quantity),
|
||||
"unit": record.unit,
|
||||
"unit_price": float(record.unit_price),
|
||||
"total_cost": float(record.total_cost),
|
||||
"currency": record.currency,
|
||||
"job_id": record.job_id,
|
||||
"metadata": record.metadata
|
||||
})
|
||||
|
||||
return json.dumps(data, indent=2)
|
||||
|
||||
|
||||
class BillingScheduler:
|
||||
"""Scheduler for automated billing processes"""
|
||||
|
||||
def __init__(self, usage_service: UsageTrackingService):
|
||||
self.usage_service = usage_service
|
||||
self.logger = __import__('logging').getLogger(f"aitbc.{self.__class__.__name__}")
|
||||
self.running = False
|
||||
|
||||
async def start(self):
|
||||
"""Start billing scheduler"""
|
||||
if self.running:
|
||||
return
|
||||
|
||||
self.running = True
|
||||
self.logger.info("Billing scheduler started")
|
||||
|
||||
# Schedule daily tasks
|
||||
asyncio.create_task(self._daily_tasks())
|
||||
|
||||
# Schedule monthly invoicing
|
||||
asyncio.create_task(self._monthly_invoicing())
|
||||
|
||||
async def stop(self):
|
||||
"""Stop billing scheduler"""
|
||||
self.running = False
|
||||
self.logger.info("Billing scheduler stopped")
|
||||
|
||||
async def _daily_tasks(self):
|
||||
"""Run daily billing tasks"""
|
||||
while self.running:
|
||||
try:
|
||||
# Reset quotas for new periods
|
||||
await self._reset_daily_quotas()
|
||||
|
||||
# Process pending billing events
|
||||
await self._process_pending_events()
|
||||
|
||||
# Wait until next day
|
||||
now = datetime.utcnow()
|
||||
next_day = (now + timedelta(days=1)).replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
sleep_seconds = (next_day - now).total_seconds()
|
||||
await asyncio.sleep(sleep_seconds)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in daily tasks: {e}")
|
||||
await asyncio.sleep(3600) # Retry in 1 hour
|
||||
|
||||
async def _monthly_invoicing(self):
|
||||
"""Generate monthly invoices"""
|
||||
while self.running:
|
||||
try:
|
||||
# Wait until first day of month
|
||||
now = datetime.utcnow()
|
||||
if now.day != 1:
|
||||
next_month = now.replace(day=1) + timedelta(days=32)
|
||||
next_month = next_month.replace(day=1)
|
||||
sleep_seconds = (next_month - now).total_seconds()
|
||||
await asyncio.sleep(sleep_seconds)
|
||||
continue
|
||||
|
||||
# Generate invoices for all active tenants
|
||||
await self._generate_monthly_invoices()
|
||||
|
||||
# Wait until next month
|
||||
next_month = now.replace(day=1) + timedelta(days=32)
|
||||
next_month = next_month.replace(day=1)
|
||||
sleep_seconds = (next_month - now).total_seconds()
|
||||
await asyncio.sleep(sleep_seconds)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in monthly invoicing: {e}")
|
||||
await asyncio.sleep(86400) # Retry in 1 day
|
||||
|
||||
async def _reset_daily_quotas(self):
|
||||
"""Reset daily quotas"""
|
||||
# TODO: Implement daily quota reset
|
||||
pass
|
||||
|
||||
async def _process_pending_events(self):
|
||||
"""Process pending billing events"""
|
||||
# TODO: Implement event processing
|
||||
pass
|
||||
|
||||
async def _generate_monthly_invoices(self):
|
||||
"""Generate invoices for all tenants"""
|
||||
# TODO: Implement monthly invoice generation
|
||||
pass
|
||||
269
apps/coordinator-api/src/app/services/zk_proofs.py
Normal file
269
apps/coordinator-api/src/app/services/zk_proofs.py
Normal file
@ -0,0 +1,269 @@
|
||||
"""
|
||||
ZK Proof generation service for privacy-preserving receipt attestation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, List
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
from ..models import Receipt, JobResult
|
||||
from ..settings import settings
|
||||
from ..logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ZKProofService:
|
||||
"""Service for generating zero-knowledge proofs for receipts"""
|
||||
|
||||
def __init__(self):
|
||||
self.circuits_dir = Path(__file__).parent.parent.parent.parent / "apps" / "zk-circuits"
|
||||
self.zkey_path = self.circuits_dir / "receipt_0001.zkey"
|
||||
self.wasm_path = self.circuits_dir / "receipt.wasm"
|
||||
self.vkey_path = self.circuits_dir / "verification_key.json"
|
||||
|
||||
# Verify circuit files exist
|
||||
if not all(p.exists() for p in [self.zkey_path, self.wasm_path, self.vkey_path]):
|
||||
logger.warning("ZK circuit files not found. Proof generation disabled.")
|
||||
self.enabled = False
|
||||
else:
|
||||
self.enabled = True
|
||||
|
||||
async def generate_receipt_proof(
|
||||
self,
|
||||
receipt: Receipt,
|
||||
job_result: JobResult,
|
||||
privacy_level: str = "basic"
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Generate a ZK proof for a receipt"""
|
||||
|
||||
if not self.enabled:
|
||||
logger.warning("ZK proof generation not available")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Prepare circuit inputs based on privacy level
|
||||
inputs = await self._prepare_inputs(receipt, job_result, privacy_level)
|
||||
|
||||
# Generate proof using snarkjs
|
||||
proof_data = await self._generate_proof(inputs)
|
||||
|
||||
# Return proof with verification data
|
||||
return {
|
||||
"proof": proof_data["proof"],
|
||||
"public_signals": proof_data["publicSignals"],
|
||||
"privacy_level": privacy_level,
|
||||
"circuit_hash": await self._get_circuit_hash()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate ZK proof: {e}")
|
||||
return None
|
||||
|
||||
async def _prepare_inputs(
|
||||
self,
|
||||
receipt: Receipt,
|
||||
job_result: JobResult,
|
||||
privacy_level: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Prepare circuit inputs based on privacy level"""
|
||||
|
||||
if privacy_level == "basic":
|
||||
# Hide computation details, reveal settlement amount
|
||||
return {
|
||||
"data": [
|
||||
str(receipt.job_id),
|
||||
str(receipt.miner_id),
|
||||
str(job_result.result_hash),
|
||||
str(receipt.pricing.rate)
|
||||
],
|
||||
"hash": await self._hash_receipt(receipt)
|
||||
}
|
||||
|
||||
elif privacy_level == "enhanced":
|
||||
# Hide all amounts, prove correctness
|
||||
return {
|
||||
"settlementAmount": receipt.settlement_amount,
|
||||
"timestamp": receipt.timestamp,
|
||||
"receipt": self._serialize_receipt(receipt),
|
||||
"computationResult": job_result.result_hash,
|
||||
"pricingRate": receipt.pricing.rate,
|
||||
"minerReward": receipt.miner_reward,
|
||||
"coordinatorFee": receipt.coordinator_fee
|
||||
}
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown privacy level: {privacy_level}")
|
||||
|
||||
async def _hash_receipt(self, receipt: Receipt) -> str:
|
||||
"""Hash receipt for public verification"""
|
||||
# In a real implementation, use Poseidon or the same hash as circuit
|
||||
import hashlib
|
||||
|
||||
receipt_data = {
|
||||
"job_id": receipt.job_id,
|
||||
"miner_id": receipt.miner_id,
|
||||
"timestamp": receipt.timestamp,
|
||||
"pricing": receipt.pricing.dict()
|
||||
}
|
||||
|
||||
receipt_str = json.dumps(receipt_data, sort_keys=True)
|
||||
return hashlib.sha256(receipt_str.encode()).hexdigest()
|
||||
|
||||
def _serialize_receipt(self, receipt: Receipt) -> List[str]:
|
||||
"""Serialize receipt for circuit input"""
|
||||
# Convert receipt to field elements for circuit
|
||||
return [
|
||||
str(receipt.job_id)[:32], # Truncate for field size
|
||||
str(receipt.miner_id)[:32],
|
||||
str(receipt.timestamp)[:32],
|
||||
str(receipt.settlement_amount)[:32],
|
||||
str(receipt.miner_reward)[:32],
|
||||
str(receipt.coordinator_fee)[:32],
|
||||
"0", "0" # Padding
|
||||
]
|
||||
|
||||
async def _generate_proof(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Generate proof using snarkjs"""
|
||||
|
||||
# Write inputs to temporary file
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||
json.dump(inputs, f)
|
||||
inputs_file = f.name
|
||||
|
||||
try:
|
||||
# Create Node.js script for proof generation
|
||||
script = f"""
|
||||
const snarkjs = require('snarkjs');
|
||||
const fs = require('fs');
|
||||
|
||||
async function main() {{
|
||||
try {{
|
||||
// Load inputs
|
||||
const inputs = JSON.parse(fs.readFileSync('{inputs_file}', 'utf8'));
|
||||
|
||||
// Load circuit
|
||||
const wasm = fs.readFileSync('{self.wasm_path}');
|
||||
const zkey = fs.readFileSync('{self.zkey_path}');
|
||||
|
||||
// Calculate witness
|
||||
const {{ witness }} = await snarkjs.wtns.calculate(inputs, wasm, wasm);
|
||||
|
||||
// Generate proof
|
||||
const {{ proof, publicSignals }} = await snarkjs.groth16.prove(zkey, witness);
|
||||
|
||||
// Output result
|
||||
console.log(JSON.stringify({{ proof, publicSignals }}));
|
||||
}} catch (error) {{
|
||||
console.error('Error:', error);
|
||||
process.exit(1);
|
||||
}}
|
||||
}}
|
||||
|
||||
main();
|
||||
"""
|
||||
|
||||
# Write script to temporary file
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.js', delete=False) as f:
|
||||
f.write(script)
|
||||
script_file = f.name
|
||||
|
||||
try:
|
||||
# Run script
|
||||
result = subprocess.run(
|
||||
["node", script_file],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=str(self.circuits_dir)
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise Exception(f"Proof generation failed: {result.stderr}")
|
||||
|
||||
# Parse result
|
||||
return json.loads(result.stdout)
|
||||
|
||||
finally:
|
||||
os.unlink(script_file)
|
||||
|
||||
finally:
|
||||
os.unlink(inputs_file)
|
||||
|
||||
async def _get_circuit_hash(self) -> str:
|
||||
"""Get hash of circuit for verification"""
|
||||
# In a real implementation, return the hash of the circuit
|
||||
# This ensures the proof is for the correct circuit version
|
||||
return "0x1234567890abcdef"
|
||||
|
||||
async def verify_proof(
|
||||
self,
|
||||
proof: Dict[str, Any],
|
||||
public_signals: List[str]
|
||||
) -> bool:
|
||||
"""Verify a ZK proof"""
|
||||
|
||||
if not self.enabled:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Load verification key
|
||||
with open(self.vkey_path) as f:
|
||||
vkey = json.load(f)
|
||||
|
||||
# Create verification script
|
||||
script = f"""
|
||||
const snarkjs = require('snarkjs');
|
||||
|
||||
async function main() {{
|
||||
try {{
|
||||
const vKey = {json.dumps(vkey)};
|
||||
const proof = {json.dumps(proof)};
|
||||
const publicSignals = {json.dumps(public_signals)};
|
||||
|
||||
const verified = await snarkjs.groth16.verify(vKey, publicSignals, proof);
|
||||
console.log(verified);
|
||||
}} catch (error) {{
|
||||
console.error('Error:', error);
|
||||
process.exit(1);
|
||||
}}
|
||||
}}
|
||||
|
||||
main();
|
||||
"""
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.js', delete=False) as f:
|
||||
f.write(script)
|
||||
script_file = f.name
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["node", script_file],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=str(self.circuits_dir)
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
logger.error(f"Proof verification failed: {result.stderr}")
|
||||
return False
|
||||
|
||||
return result.stdout.strip() == "true"
|
||||
|
||||
finally:
|
||||
os.unlink(script_file)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to verify proof: {e}")
|
||||
return False
|
||||
|
||||
def is_enabled(self) -> bool:
|
||||
"""Check if ZK proof generation is available"""
|
||||
return self.enabled
|
||||
|
||||
|
||||
# Global instance
|
||||
zk_proof_service = ZKProofService()
|
||||
505
apps/coordinator-api/tests/test_confidential_transactions.py
Normal file
505
apps/coordinator-api/tests/test_confidential_transactions.py
Normal file
@ -0,0 +1,505 @@
|
||||
"""
|
||||
Tests for confidential transaction functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
import base64
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
|
||||
from app.models import (
|
||||
ConfidentialTransaction,
|
||||
ConfidentialTransactionCreate,
|
||||
ConfidentialAccessRequest,
|
||||
KeyRegistrationRequest
|
||||
)
|
||||
from app.services.encryption import EncryptionService, EncryptedData
|
||||
from app.services.key_management import KeyManager, FileKeyStorage
|
||||
from app.services.access_control import AccessController, PolicyStore
|
||||
from app.services.audit_logging import AuditLogger
|
||||
|
||||
|
||||
class TestEncryptionService:
|
||||
"""Test encryption service functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
def key_manager(self):
|
||||
"""Create test key manager"""
|
||||
storage = FileKeyStorage("/tmp/test_keys")
|
||||
return KeyManager(storage)
|
||||
|
||||
@pytest.fixture
|
||||
def encryption_service(self, key_manager):
|
||||
"""Create test encryption service"""
|
||||
return EncryptionService(key_manager)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_encrypt_decrypt_success(self, encryption_service, key_manager):
|
||||
"""Test successful encryption and decryption"""
|
||||
# Generate test keys
|
||||
await key_manager.generate_key_pair("client-123")
|
||||
await key_manager.generate_key_pair("miner-456")
|
||||
|
||||
# Test data
|
||||
data = {
|
||||
"amount": "1000",
|
||||
"pricing": {"rate": "0.1", "currency": "AITBC"},
|
||||
"settlement_details": {"method": "crypto", "address": "0x123..."}
|
||||
}
|
||||
|
||||
participants = ["client-123", "miner-456"]
|
||||
|
||||
# Encrypt data
|
||||
encrypted = encryption_service.encrypt(
|
||||
data=data,
|
||||
participants=participants,
|
||||
include_audit=True
|
||||
)
|
||||
|
||||
assert encrypted.ciphertext is not None
|
||||
assert len(encrypted.encrypted_keys) == 3 # 2 participants + audit
|
||||
assert "client-123" in encrypted.encrypted_keys
|
||||
assert "miner-456" in encrypted.encrypted_keys
|
||||
assert "audit" in encrypted.encrypted_keys
|
||||
|
||||
# Decrypt for client
|
||||
decrypted = encryption_service.decrypt(
|
||||
encrypted_data=encrypted,
|
||||
participant_id="client-123",
|
||||
purpose="settlement"
|
||||
)
|
||||
|
||||
assert decrypted == data
|
||||
|
||||
# Decrypt for miner
|
||||
decrypted_miner = encryption_service.decrypt(
|
||||
encrypted_data=encrypted,
|
||||
participant_id="miner-456",
|
||||
purpose="settlement"
|
||||
)
|
||||
|
||||
assert decrypted_miner == data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_decrypt(self, encryption_service, key_manager):
|
||||
"""Test audit decryption"""
|
||||
# Generate keys
|
||||
await key_manager.generate_key_pair("client-123")
|
||||
|
||||
# Create audit authorization
|
||||
auth = await key_manager.create_audit_authorization(
|
||||
issuer="regulator",
|
||||
purpose="compliance"
|
||||
)
|
||||
|
||||
# Encrypt data
|
||||
data = {"amount": "1000", "secret": "hidden"}
|
||||
encrypted = encryption_service.encrypt(
|
||||
data=data,
|
||||
participants=["client-123"],
|
||||
include_audit=True
|
||||
)
|
||||
|
||||
# Decrypt with audit key
|
||||
decrypted = encryption_service.audit_decrypt(
|
||||
encrypted_data=encrypted,
|
||||
audit_authorization=auth,
|
||||
purpose="compliance"
|
||||
)
|
||||
|
||||
assert decrypted == data
|
||||
|
||||
def test_encrypt_no_participants(self, encryption_service):
|
||||
"""Test encryption with no participants"""
|
||||
data = {"test": "data"}
|
||||
|
||||
with pytest.raises(Exception):
|
||||
encryption_service.encrypt(
|
||||
data=data,
|
||||
participants=[],
|
||||
include_audit=True
|
||||
)
|
||||
|
||||
|
||||
class TestKeyManager:
|
||||
"""Test key management functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
def key_storage(self, tmp_path):
|
||||
"""Create test key storage"""
|
||||
return FileKeyStorage(str(tmp_path / "keys"))
|
||||
|
||||
@pytest.fixture
|
||||
def key_manager(self, key_storage):
|
||||
"""Create test key manager"""
|
||||
return KeyManager(key_storage)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_key_pair(self, key_manager):
|
||||
"""Test key pair generation"""
|
||||
key_pair = await key_manager.generate_key_pair("test-participant")
|
||||
|
||||
assert key_pair.participant_id == "test-participant"
|
||||
assert key_pair.algorithm == "X25519"
|
||||
assert key_pair.private_key is not None
|
||||
assert key_pair.public_key is not None
|
||||
assert key_pair.version == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_key_rotation(self, key_manager):
|
||||
"""Test key rotation"""
|
||||
# Generate initial key
|
||||
initial_key = await key_manager.generate_key_pair("test-participant")
|
||||
initial_version = initial_key.version
|
||||
|
||||
# Rotate keys
|
||||
new_key = await key_manager.rotate_keys("test-participant")
|
||||
|
||||
assert new_key.participant_id == "test-participant"
|
||||
assert new_key.version > initial_version
|
||||
assert new_key.private_key != initial_key.private_key
|
||||
assert new_key.public_key != initial_key.public_key
|
||||
|
||||
def test_get_public_key(self, key_manager):
|
||||
"""Test retrieving public key"""
|
||||
# This would need a key to be pre-generated
|
||||
with pytest.raises(Exception):
|
||||
key_manager.get_public_key("nonexistent")
|
||||
|
||||
|
||||
class TestAccessController:
|
||||
"""Test access control functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
def policy_store(self):
|
||||
"""Create test policy store"""
|
||||
return PolicyStore()
|
||||
|
||||
@pytest.fixture
|
||||
def access_controller(self, policy_store):
|
||||
"""Create test access controller"""
|
||||
return AccessController(policy_store)
|
||||
|
||||
def test_client_access_own_data(self, access_controller):
|
||||
"""Test client accessing own transaction"""
|
||||
request = ConfidentialAccessRequest(
|
||||
transaction_id="tx-123",
|
||||
requester="client-456",
|
||||
purpose="settlement"
|
||||
)
|
||||
|
||||
# Should allow access
|
||||
assert access_controller.verify_access(request) is True
|
||||
|
||||
def test_miner_access_assigned_data(self, access_controller):
|
||||
"""Test miner accessing assigned transaction"""
|
||||
request = ConfidentialAccessRequest(
|
||||
transaction_id="tx-123",
|
||||
requester="miner-789",
|
||||
purpose="settlement"
|
||||
)
|
||||
|
||||
# Should allow access
|
||||
assert access_controller.verify_access(request) is True
|
||||
|
||||
def test_unauthorized_access(self, access_controller):
|
||||
"""Test unauthorized access attempt"""
|
||||
request = ConfidentialAccessRequest(
|
||||
transaction_id="tx-123",
|
||||
requester="unauthorized-user",
|
||||
purpose="settlement"
|
||||
)
|
||||
|
||||
# Should deny access
|
||||
assert access_controller.verify_access(request) is False
|
||||
|
||||
def test_audit_access(self, access_controller):
|
||||
"""Test auditor access"""
|
||||
request = ConfidentialAccessRequest(
|
||||
transaction_id="tx-123",
|
||||
requester="auditor-001",
|
||||
purpose="compliance"
|
||||
)
|
||||
|
||||
# Should allow access during business hours
|
||||
assert access_controller.verify_access(request) is True
|
||||
|
||||
|
||||
class TestAuditLogger:
|
||||
"""Test audit logging functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
def audit_logger(self, tmp_path):
|
||||
"""Create test audit logger"""
|
||||
return AuditLogger(log_dir=str(tmp_path / "audit"))
|
||||
|
||||
def test_log_access(self, audit_logger):
|
||||
"""Test logging access events"""
|
||||
# Log access event
|
||||
audit_logger.log_access(
|
||||
participant_id="client-456",
|
||||
transaction_id="tx-123",
|
||||
action="decrypt",
|
||||
outcome="success",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="test-client"
|
||||
)
|
||||
|
||||
# Wait for background writer
|
||||
import time
|
||||
time.sleep(0.1)
|
||||
|
||||
# Query logs
|
||||
events = audit_logger.query_logs(
|
||||
participant_id="client-456",
|
||||
limit=10
|
||||
)
|
||||
|
||||
assert len(events) > 0
|
||||
assert events[0].participant_id == "client-456"
|
||||
assert events[0].transaction_id == "tx-123"
|
||||
assert events[0].action == "decrypt"
|
||||
assert events[0].outcome == "success"
|
||||
|
||||
def test_log_key_operation(self, audit_logger):
|
||||
"""Test logging key operations"""
|
||||
audit_logger.log_key_operation(
|
||||
participant_id="miner-789",
|
||||
operation="rotate",
|
||||
key_version=2,
|
||||
outcome="success"
|
||||
)
|
||||
|
||||
# Wait for background writer
|
||||
import time
|
||||
time.sleep(0.1)
|
||||
|
||||
# Query logs
|
||||
events = audit_logger.query_logs(
|
||||
event_type="key_operation",
|
||||
limit=10
|
||||
)
|
||||
|
||||
assert len(events) > 0
|
||||
assert events[0].event_type == "key_operation"
|
||||
assert events[0].action == "rotate"
|
||||
assert events[0].details["key_version"] == 2
|
||||
|
||||
def test_export_logs(self, audit_logger):
|
||||
"""Test log export functionality"""
|
||||
# Add some test events
|
||||
audit_logger.log_access(
|
||||
participant_id="test-user",
|
||||
transaction_id="tx-456",
|
||||
action="test",
|
||||
outcome="success"
|
||||
)
|
||||
|
||||
# Wait for background writer
|
||||
import time
|
||||
time.sleep(0.1)
|
||||
|
||||
# Export logs
|
||||
export_data = audit_logger.export_logs(
|
||||
start_time=datetime.utcnow() - timedelta(hours=1),
|
||||
end_time=datetime.utcnow(),
|
||||
format="json"
|
||||
)
|
||||
|
||||
# Parse export
|
||||
export = json.loads(export_data)
|
||||
|
||||
assert "export_metadata" in export
|
||||
assert "events" in export
|
||||
assert export["export_metadata"]["event_count"] > 0
|
||||
|
||||
|
||||
class TestConfidentialTransactionAPI:
|
||||
"""Test confidential transaction API endpoints"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_confidential_transaction(self):
|
||||
"""Test creating a confidential transaction"""
|
||||
from app.routers.confidential import create_confidential_transaction
|
||||
|
||||
request = ConfidentialTransactionCreate(
|
||||
job_id="job-123",
|
||||
amount="1000",
|
||||
pricing={"rate": "0.1"},
|
||||
confidential=True,
|
||||
participants=["client-456", "miner-789"]
|
||||
)
|
||||
|
||||
# Mock API key
|
||||
with patch('app.routers.confidential.get_api_key', return_value="test-key"):
|
||||
response = await create_confidential_transaction(request)
|
||||
|
||||
assert response.transaction_id.startswith("ctx-")
|
||||
assert response.job_id == "job-123"
|
||||
assert response.confidential is True
|
||||
assert response.has_encrypted_data is True
|
||||
assert response.amount is None # Should be encrypted
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_access_confidential_data(self):
|
||||
"""Test accessing confidential transaction data"""
|
||||
from app.routers.confidential import access_confidential_data
|
||||
|
||||
request = ConfidentialAccessRequest(
|
||||
transaction_id="tx-123",
|
||||
requester="client-456",
|
||||
purpose="settlement"
|
||||
)
|
||||
|
||||
# Mock dependencies
|
||||
with patch('app.routers.confidential.get_api_key', return_value="test-key"), \
|
||||
patch('app.routers.confidential.get_access_controller') as mock_ac, \
|
||||
patch('app.routers.confidential.get_encryption_service') as mock_es:
|
||||
|
||||
# Mock access control
|
||||
mock_ac.return_value.verify_access.return_value = True
|
||||
|
||||
# Mock encryption service
|
||||
mock_es.return_value.decrypt.return_value = {
|
||||
"amount": "1000",
|
||||
"pricing": {"rate": "0.1"}
|
||||
}
|
||||
|
||||
response = await access_confidential_data(request, "tx-123")
|
||||
|
||||
assert response.success is True
|
||||
assert response.data is not None
|
||||
assert response.data["amount"] == "1000"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_key(self):
|
||||
"""Test key registration"""
|
||||
from app.routers.confidential import register_encryption_key
|
||||
|
||||
# Generate test key pair
|
||||
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
|
||||
private_key = X25519PrivateKey.generate()
|
||||
public_key = private_key.public_key()
|
||||
public_key_bytes = public_key.public_bytes_raw()
|
||||
|
||||
request = KeyRegistrationRequest(
|
||||
participant_id="test-participant",
|
||||
public_key=base64.b64encode(public_key_bytes).decode()
|
||||
)
|
||||
|
||||
with patch('app.routers.confidential.get_api_key', return_value="test-key"):
|
||||
response = await register_encryption_key(request)
|
||||
|
||||
assert response.success is True
|
||||
assert response.participant_id == "test-participant"
|
||||
assert response.key_version >= 1
|
||||
|
||||
|
||||
# Integration Tests
|
||||
class TestConfidentialTransactionFlow:
|
||||
"""End-to-end tests for confidential transaction flow"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_confidential_flow(self):
|
||||
"""Test complete confidential transaction flow"""
|
||||
# Setup
|
||||
key_storage = FileKeyStorage("/tmp/integration_keys")
|
||||
key_manager = KeyManager(key_storage)
|
||||
encryption_service = EncryptionService(key_manager)
|
||||
access_controller = AccessController(PolicyStore())
|
||||
|
||||
# 1. Generate keys for participants
|
||||
await key_manager.generate_key_pair("client-123")
|
||||
await key_manager.generate_key_pair("miner-456")
|
||||
|
||||
# 2. Create confidential transaction
|
||||
transaction_data = {
|
||||
"amount": "1000",
|
||||
"pricing": {"rate": "0.1", "currency": "AITBC"},
|
||||
"settlement_details": {"method": "crypto"}
|
||||
}
|
||||
|
||||
participants = ["client-123", "miner-456"]
|
||||
|
||||
# 3. Encrypt data
|
||||
encrypted = encryption_service.encrypt(
|
||||
data=transaction_data,
|
||||
participants=participants,
|
||||
include_audit=True
|
||||
)
|
||||
|
||||
# 4. Store transaction (mock)
|
||||
transaction = ConfidentialTransaction(
|
||||
transaction_id="ctx-test-123",
|
||||
job_id="job-456",
|
||||
timestamp=datetime.utcnow(),
|
||||
status="created",
|
||||
confidential=True,
|
||||
participants=participants,
|
||||
encrypted_data=encrypted.to_dict()["ciphertext"],
|
||||
encrypted_keys=encrypted.to_dict()["encrypted_keys"],
|
||||
algorithm=encrypted.algorithm
|
||||
)
|
||||
|
||||
# 5. Client accesses data
|
||||
client_request = ConfidentialAccessRequest(
|
||||
transaction_id=transaction.transaction_id,
|
||||
requester="client-123",
|
||||
purpose="settlement"
|
||||
)
|
||||
|
||||
assert access_controller.verify_access(client_request) is True
|
||||
|
||||
client_data = encryption_service.decrypt(
|
||||
encrypted_data=encrypted,
|
||||
participant_id="client-123",
|
||||
purpose="settlement"
|
||||
)
|
||||
|
||||
assert client_data == transaction_data
|
||||
|
||||
# 6. Miner accesses data
|
||||
miner_request = ConfidentialAccessRequest(
|
||||
transaction_id=transaction.transaction_id,
|
||||
requester="miner-456",
|
||||
purpose="settlement"
|
||||
)
|
||||
|
||||
assert access_controller.verify_access(miner_request) is True
|
||||
|
||||
miner_data = encryption_service.decrypt(
|
||||
encrypted_data=encrypted,
|
||||
participant_id="miner-456",
|
||||
purpose="settlement"
|
||||
)
|
||||
|
||||
assert miner_data == transaction_data
|
||||
|
||||
# 7. Unauthorized access denied
|
||||
unauthorized_request = ConfidentialAccessRequest(
|
||||
transaction_id=transaction.transaction_id,
|
||||
requester="unauthorized",
|
||||
purpose="settlement"
|
||||
)
|
||||
|
||||
assert access_controller.verify_access(unauthorized_request) is False
|
||||
|
||||
# 8. Audit access
|
||||
audit_auth = await key_manager.create_audit_authorization(
|
||||
issuer="regulator",
|
||||
purpose="compliance"
|
||||
)
|
||||
|
||||
audit_data = encryption_service.audit_decrypt(
|
||||
encrypted_data=encrypted,
|
||||
audit_authorization=audit_auth,
|
||||
purpose="compliance"
|
||||
)
|
||||
|
||||
assert audit_data == transaction_data
|
||||
|
||||
# Cleanup
|
||||
import shutil
|
||||
shutil.rmtree("/tmp/integration_keys", ignore_errors=True)
|
||||
402
apps/coordinator-api/tests/test_zk_proofs.py
Normal file
402
apps/coordinator-api/tests/test_zk_proofs.py
Normal file
@ -0,0 +1,402 @@
|
||||
"""
|
||||
Tests for ZK proof generation and verification
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
from pathlib import Path
|
||||
|
||||
from app.services.zk_proofs import ZKProofService
|
||||
from app.models import JobReceipt, Job, JobResult
|
||||
from app.domain import ReceiptPayload
|
||||
|
||||
|
||||
class TestZKProofService:
|
||||
"""Test cases for ZK proof service"""
|
||||
|
||||
@pytest.fixture
|
||||
def zk_service(self):
|
||||
"""Create ZK proof service instance"""
|
||||
with patch('app.services.zk_proofs.settings'):
|
||||
service = ZKProofService()
|
||||
return service
|
||||
|
||||
@pytest.fixture
|
||||
def sample_job(self):
|
||||
"""Create sample job for testing"""
|
||||
return Job(
|
||||
id="test-job-123",
|
||||
client_id="client-456",
|
||||
payload={"type": "test"},
|
||||
constraints={},
|
||||
requested_at=None,
|
||||
completed=True
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_job_result(self):
|
||||
"""Create sample job result"""
|
||||
return {
|
||||
"result": "test-result",
|
||||
"result_hash": "0x1234567890abcdef",
|
||||
"units": 100,
|
||||
"unit_type": "gpu_seconds",
|
||||
"metrics": {"execution_time": 5.0}
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def sample_receipt(self, sample_job):
|
||||
"""Create sample receipt"""
|
||||
payload = ReceiptPayload(
|
||||
version="1.0",
|
||||
receipt_id="receipt-789",
|
||||
job_id=sample_job.id,
|
||||
provider="miner-001",
|
||||
client=sample_job.client_id,
|
||||
units=100,
|
||||
unit_type="gpu_seconds",
|
||||
price="0.1",
|
||||
started_at=1640995200,
|
||||
completed_at=1640995800,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
return JobReceipt(
|
||||
job_id=sample_job.id,
|
||||
receipt_id=payload.receipt_id,
|
||||
payload=payload.dict()
|
||||
)
|
||||
|
||||
def test_service_initialization_with_files(self):
|
||||
"""Test service initialization when circuit files exist"""
|
||||
with patch('app.services.zk_proofs.Path') as mock_path:
|
||||
# Mock file existence
|
||||
mock_path.return_value.exists.return_value = True
|
||||
|
||||
service = ZKProofService()
|
||||
assert service.enabled is True
|
||||
|
||||
def test_service_initialization_without_files(self):
|
||||
"""Test service initialization when circuit files are missing"""
|
||||
with patch('app.services.zk_proofs.Path') as mock_path:
|
||||
# Mock file non-existence
|
||||
mock_path.return_value.exists.return_value = False
|
||||
|
||||
service = ZKProofService()
|
||||
assert service.enabled is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_proof_basic_privacy(self, zk_service, sample_receipt, sample_job_result):
|
||||
"""Test generating proof with basic privacy level"""
|
||||
if not zk_service.enabled:
|
||||
pytest.skip("ZK circuits not available")
|
||||
|
||||
# Mock subprocess calls
|
||||
with patch('subprocess.run') as mock_run:
|
||||
# Mock successful proof generation
|
||||
mock_run.return_value.returncode = 0
|
||||
mock_run.return_value.stdout = json.dumps({
|
||||
"proof": {"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
|
||||
"publicSignals": ["0x1234", "1000", "1640995800"]
|
||||
})
|
||||
|
||||
# Generate proof
|
||||
proof = await zk_service.generate_receipt_proof(
|
||||
receipt=sample_receipt,
|
||||
job_result=sample_job_result,
|
||||
privacy_level="basic"
|
||||
)
|
||||
|
||||
assert proof is not None
|
||||
assert "proof" in proof
|
||||
assert "public_signals" in proof
|
||||
assert proof["privacy_level"] == "basic"
|
||||
assert "circuit_hash" in proof
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_proof_enhanced_privacy(self, zk_service, sample_receipt, sample_job_result):
|
||||
"""Test generating proof with enhanced privacy level"""
|
||||
if not zk_service.enabled:
|
||||
pytest.skip("ZK circuits not available")
|
||||
|
||||
with patch('subprocess.run') as mock_run:
|
||||
mock_run.return_value.returncode = 0
|
||||
mock_run.return_value.stdout = json.dumps({
|
||||
"proof": {"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
|
||||
"publicSignals": ["1000", "1640995800"]
|
||||
})
|
||||
|
||||
proof = await zk_service.generate_receipt_proof(
|
||||
receipt=sample_receipt,
|
||||
job_result=sample_job_result,
|
||||
privacy_level="enhanced"
|
||||
)
|
||||
|
||||
assert proof is not None
|
||||
assert proof["privacy_level"] == "enhanced"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_proof_service_disabled(self, zk_service, sample_receipt, sample_job_result):
|
||||
"""Test proof generation when service is disabled"""
|
||||
zk_service.enabled = False
|
||||
|
||||
proof = await zk_service.generate_receipt_proof(
|
||||
receipt=sample_receipt,
|
||||
job_result=sample_job_result,
|
||||
privacy_level="basic"
|
||||
)
|
||||
|
||||
assert proof is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_proof_invalid_privacy_level(self, zk_service, sample_receipt, sample_job_result):
|
||||
"""Test proof generation with invalid privacy level"""
|
||||
if not zk_service.enabled:
|
||||
pytest.skip("ZK circuits not available")
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown privacy level"):
|
||||
await zk_service.generate_receipt_proof(
|
||||
receipt=sample_receipt,
|
||||
job_result=sample_job_result,
|
||||
privacy_level="invalid"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_proof_success(self, zk_service):
|
||||
"""Test successful proof verification"""
|
||||
if not zk_service.enabled:
|
||||
pytest.skip("ZK circuits not available")
|
||||
|
||||
with patch('subprocess.run') as mock_run, \
|
||||
patch('builtins.open', mock_open(read_data='{"key": "value"}')):
|
||||
|
||||
mock_run.return_value.returncode = 0
|
||||
mock_run.return_value.stdout = "true"
|
||||
|
||||
result = await zk_service.verify_proof(
|
||||
proof={"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
|
||||
public_signals=["0x1234", "1000"]
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_proof_failure(self, zk_service):
|
||||
"""Test proof verification failure"""
|
||||
if not zk_service.enabled:
|
||||
pytest.skip("ZK circuits not available")
|
||||
|
||||
with patch('subprocess.run') as mock_run, \
|
||||
patch('builtins.open', mock_open(read_data='{"key": "value"}')):
|
||||
|
||||
mock_run.return_value.returncode = 1
|
||||
mock_run.return_value.stderr = "Verification failed"
|
||||
|
||||
result = await zk_service.verify_proof(
|
||||
proof={"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
|
||||
public_signals=["0x1234", "1000"]
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_proof_service_disabled(self, zk_service):
|
||||
"""Test proof verification when service is disabled"""
|
||||
zk_service.enabled = False
|
||||
|
||||
result = await zk_service.verify_proof(
|
||||
proof={"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
|
||||
public_signals=["0x1234", "1000"]
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_hash_receipt(self, zk_service, sample_receipt):
|
||||
"""Test receipt hashing"""
|
||||
receipt_hash = zk_service._hash_receipt(sample_receipt)
|
||||
|
||||
assert isinstance(receipt_hash, str)
|
||||
assert len(receipt_hash) == 64 # SHA256 hex length
|
||||
assert all(c in '0123456789abcdef' for c in receipt_hash)
|
||||
|
||||
def test_serialize_receipt(self, zk_service, sample_receipt):
|
||||
"""Test receipt serialization for circuit"""
|
||||
serialized = zk_service._serialize_receipt(sample_receipt)
|
||||
|
||||
assert isinstance(serialized, list)
|
||||
assert len(serialized) == 8
|
||||
assert all(isinstance(x, str) for x in serialized)
|
||||
|
||||
|
||||
class TestZKProofIntegration:
|
||||
"""Integration tests for ZK proof system"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receipt_creation_with_zk_proof(self):
|
||||
"""Test receipt creation with ZK proof generation"""
|
||||
from app.services.receipts import ReceiptService
|
||||
from sqlmodel import Session
|
||||
|
||||
# Create mock session
|
||||
session = Mock(spec=Session)
|
||||
|
||||
# Create receipt service
|
||||
receipt_service = ReceiptService(session)
|
||||
|
||||
# Create sample job
|
||||
job = Job(
|
||||
id="test-job-123",
|
||||
client_id="client-456",
|
||||
payload={"type": "test"},
|
||||
constraints={},
|
||||
requested_at=None,
|
||||
completed=True
|
||||
)
|
||||
|
||||
# Mock ZK proof service
|
||||
with patch('app.services.receipts.zk_proof_service') as mock_zk:
|
||||
mock_zk.is_enabled.return_value = True
|
||||
mock_zk.generate_receipt_proof = AsyncMock(return_value={
|
||||
"proof": {"a": ["1", "2"]},
|
||||
"public_signals": ["0x1234"],
|
||||
"privacy_level": "basic"
|
||||
})
|
||||
|
||||
# Create receipt with privacy
|
||||
receipt = await receipt_service.create_receipt(
|
||||
job=job,
|
||||
miner_id="miner-001",
|
||||
job_result={"result": "test"},
|
||||
result_metrics={"units": 100},
|
||||
privacy_level="basic"
|
||||
)
|
||||
|
||||
assert receipt is not None
|
||||
assert "zk_proof" in receipt
|
||||
assert receipt["privacy_level"] == "basic"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_settlement_with_zk_proof(self):
|
||||
"""Test cross-chain settlement with ZK proof"""
|
||||
from aitbc.settlement.hooks import SettlementHook
|
||||
from aitbc.settlement.manager import BridgeManager
|
||||
|
||||
# Create mock bridge manager
|
||||
bridge_manager = Mock(spec=BridgeManager)
|
||||
|
||||
# Create settlement hook
|
||||
settlement_hook = SettlementHook(bridge_manager)
|
||||
|
||||
# Create sample job with ZK proof
|
||||
job = Job(
|
||||
id="test-job-123",
|
||||
client_id="client-456",
|
||||
payload={"type": "test"},
|
||||
constraints={},
|
||||
requested_at=None,
|
||||
completed=True,
|
||||
target_chain=2
|
||||
)
|
||||
|
||||
# Create receipt with ZK proof
|
||||
receipt_payload = {
|
||||
"version": "1.0",
|
||||
"receipt_id": "receipt-789",
|
||||
"job_id": job.id,
|
||||
"provider": "miner-001",
|
||||
"client": job.client_id,
|
||||
"zk_proof": {
|
||||
"proof": {"a": ["1", "2"]},
|
||||
"public_signals": ["0x1234"]
|
||||
}
|
||||
}
|
||||
|
||||
job.receipt = JobReceipt(
|
||||
job_id=job.id,
|
||||
receipt_id=receipt_payload["receipt_id"],
|
||||
payload=receipt_payload
|
||||
)
|
||||
|
||||
# Test settlement message creation
|
||||
message = await settlement_hook._create_settlement_message(
|
||||
job,
|
||||
options={"use_zk_proof": True, "privacy_level": "basic"}
|
||||
)
|
||||
|
||||
assert message.zk_proof is not None
|
||||
assert message.privacy_level == "basic"
|
||||
|
||||
|
||||
# Helper function for mocking file operations
|
||||
def mock_open(read_data=""):
|
||||
"""Mock open function for file operations"""
|
||||
from unittest.mock import mock_open
|
||||
return mock_open(read_data=read_data)
|
||||
|
||||
|
||||
# Benchmark tests
|
||||
class TestZKProofPerformance:
|
||||
"""Performance benchmarks for ZK proof operations"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proof_generation_time(self):
|
||||
"""Benchmark proof generation time"""
|
||||
import time
|
||||
|
||||
if not Path("apps/zk-circuits/receipt.wasm").exists():
|
||||
pytest.skip("ZK circuits not built")
|
||||
|
||||
service = ZKProofService()
|
||||
if not service.enabled:
|
||||
pytest.skip("ZK service not enabled")
|
||||
|
||||
# Create test data
|
||||
receipt = JobReceipt(
|
||||
job_id="benchmark-job",
|
||||
receipt_id="benchmark-receipt",
|
||||
payload={"test": "data"}
|
||||
)
|
||||
|
||||
job_result = {"result": "benchmark"}
|
||||
|
||||
# Measure proof generation time
|
||||
start_time = time.time()
|
||||
proof = await service.generate_receipt_proof(
|
||||
receipt=receipt,
|
||||
job_result=job_result,
|
||||
privacy_level="basic"
|
||||
)
|
||||
end_time = time.time()
|
||||
|
||||
generation_time = end_time - start_time
|
||||
|
||||
assert proof is not None
|
||||
assert generation_time < 30 # Should complete within 30 seconds
|
||||
|
||||
print(f"Proof generation time: {generation_time:.2f} seconds")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proof_verification_time(self):
|
||||
"""Benchmark proof verification time"""
|
||||
import time
|
||||
|
||||
service = ZKProofService()
|
||||
if not service.enabled:
|
||||
pytest.skip("ZK service not enabled")
|
||||
|
||||
# Create test proof
|
||||
proof = {"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]}
|
||||
public_signals = ["0x1234", "1000"]
|
||||
|
||||
# Measure verification time
|
||||
start_time = time.time()
|
||||
result = await service.verify_proof(proof, public_signals)
|
||||
end_time = time.time()
|
||||
|
||||
verification_time = end_time - start_time
|
||||
|
||||
assert isinstance(result, bool)
|
||||
assert verification_time < 1 # Should complete within 1 second
|
||||
|
||||
print(f"Proof verification time: {verification_time:.3f} seconds")
|
||||
Reference in New Issue
Block a user