feat: add marketplace metrics, privacy features, and service registry endpoints

- Add Prometheus metrics for marketplace API throughput and error rates with new dashboard panels
- Implement confidential transaction models with encryption support and access control
- Add key management system with registration, rotation, and audit logging
- Create services and registry routers for service discovery and management
- Integrate ZK proof generation for privacy-preserving receipts
- Add metrics instrumentation
This commit is contained in:
oib
2025-12-22 10:33:23 +01:00
parent d98b2c7772
commit c8be9d7414
260 changed files with 59033 additions and 351 deletions

View File

@ -0,0 +1,406 @@
"""
API endpoints for cross-chain settlements
"""
from typing import Dict, Any, Optional, List
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from pydantic import BaseModel, Field
import asyncio
from ...settlement.hooks import SettlementHook
from ...settlement.manager import BridgeManager
from ...settlement.bridges.base import SettlementResult
from ...auth import get_api_key
from ...models.job import Job
router = APIRouter(prefix="/settlement", tags=["settlement"])
class CrossChainSettlementRequest(BaseModel):
"""Request model for cross-chain settlement"""
job_id: str = Field(..., description="ID of the job to settle")
target_chain_id: int = Field(..., description="Target blockchain chain ID")
bridge_name: Optional[str] = Field(None, description="Specific bridge to use")
priority: str = Field("cost", description="Settlement priority: 'cost' or 'speed'")
privacy_level: Optional[str] = Field(None, description="Privacy level: 'basic' or 'enhanced'")
use_zk_proof: bool = Field(False, description="Use zero-knowledge proof for privacy")
class SettlementEstimateRequest(BaseModel):
"""Request model for settlement cost estimation"""
job_id: str = Field(..., description="ID of the job")
target_chain_id: int = Field(..., description="Target blockchain chain ID")
bridge_name: Optional[str] = Field(None, description="Specific bridge to use")
class BatchSettlementRequest(BaseModel):
"""Request model for batch settlement"""
job_ids: List[str] = Field(..., description="List of job IDs to settle")
target_chain_id: int = Field(..., description="Target blockchain chain ID")
bridge_name: Optional[str] = Field(None, description="Specific bridge to use")
class SettlementResponse(BaseModel):
"""Response model for settlement operations"""
message_id: str = Field(..., description="Settlement message ID")
status: str = Field(..., description="Settlement status")
transaction_hash: Optional[str] = Field(None, description="Transaction hash")
bridge_name: str = Field(..., description="Bridge used")
estimated_completion: Optional[str] = Field(None, description="Estimated completion time")
error_message: Optional[str] = Field(None, description="Error message if failed")
class CostEstimateResponse(BaseModel):
"""Response model for cost estimates"""
bridge_costs: Dict[str, Dict[str, Any]] = Field(..., description="Costs by bridge")
recommended_bridge: str = Field(..., description="Recommended bridge")
total_estimates: Dict[str, float] = Field(..., description="Min/Max/Average costs")
def get_settlement_hook() -> SettlementHook:
"""Dependency injection for settlement hook"""
# This would be properly injected in the app setup
from ...main import settlement_hook
return settlement_hook
def get_bridge_manager() -> BridgeManager:
"""Dependency injection for bridge manager"""
# This would be properly injected in the app setup
from ...main import bridge_manager
return bridge_manager
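# --- Illustrative wiring sketch (assumption, not part of this commit) ---
# The two dependencies above expect `settlement_hook` and `bridge_manager` to be
# created during application startup in `main.py`. A minimal sketch of that setup,
# assuming an in-memory storage backend and placeholder LayerZero config values
# (a real LayerZero adapter additionally needs a configured Web3 provider):
#
#   from .settlement.manager import BridgeManager
#   from .settlement.hooks import SettlementHook
#   from .settlement.storage import InMemorySettlementStorage
#   from .settlement.bridges.base import BridgeConfig
#
#   bridge_manager = BridgeManager(storage=InMemorySettlementStorage())
#   settlement_hook = SettlementHook(bridge_manager)
#
#   async def on_startup() -> None:
#       await bridge_manager.initialize({
#           "layerzero": BridgeConfig(
#               name="layerzero",
#               enabled=True,
#               endpoint_address="0x0000000000000000000000000000000000000000",
#               supported_chains=[1, 137, 42161],
#               default_fee="0",
#               max_message_size=10_000,
#           )
#       })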
@router.post("/cross-chain", response_model=SettlementResponse)
async def initiate_cross_chain_settlement(
request: CrossChainSettlementRequest,
background_tasks: BackgroundTasks,
settlement_hook: SettlementHook = Depends(get_settlement_hook)
):
"""
Initiate cross-chain settlement for a completed job
This endpoint settles job receipts and payments across different blockchains
using various bridge protocols (LayerZero, Chainlink CCIP, etc.).
"""
try:
# Validate job exists and is completed
job = await Job.get(request.job_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")
if not job.completed:
raise HTTPException(status_code=400, detail="Job is not completed")
if job.cross_chain_settlement_id:
raise HTTPException(
status_code=409,
detail=f"Job already has settlement {job.cross_chain_settlement_id}"
)
# Initiate settlement
settlement_options = {}
if request.use_zk_proof:
settlement_options["privacy_level"] = request.privacy_level or "basic"
settlement_options["use_zk_proof"] = True
result = await settlement_hook.initiate_manual_settlement(
job_id=request.job_id,
target_chain_id=request.target_chain_id,
bridge_name=request.bridge_name,
options=settlement_options
)
# Add background task to monitor settlement
background_tasks.add_task(
monitor_settlement_completion,
result.message_id,
request.job_id
)
return SettlementResponse(
message_id=result.message_id,
status=result.status.value,
transaction_hash=result.transaction_hash,
bridge_name=await get_bridge_from_tx(result.transaction_hash) if result.transaction_hash else "",
estimated_completion=estimate_completion_time(result.status),
error_message=result.error_message
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Settlement failed: {str(e)}")
@router.get("/{message_id}/status", response_model=SettlementResponse)
async def get_settlement_status(
message_id: str,
settlement_hook: SettlementHook = Depends(get_settlement_hook)
):
"""Get the current status of a cross-chain settlement"""
try:
result = await settlement_hook.get_settlement_status(message_id)
# Get job info if available
job_id = None
if result.transaction_hash:
job_id = await get_job_id_from_settlement(message_id)
return SettlementResponse(
message_id=message_id,
status=result.status.value,
transaction_hash=result.transaction_hash,
bridge_name=await get_bridge_from_job(job_id) if job_id else "",
estimated_completion=estimate_completion_time(result.status),
error_message=result.error_message
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to get status: {str(e)}")
@router.post("/estimate-cost", response_model=CostEstimateResponse)
async def estimate_settlement_cost(
request: SettlementEstimateRequest,
settlement_hook: SettlementHook = Depends(get_settlement_hook)
):
"""Estimate the cost of cross-chain settlement"""
try:
# Get cost estimates
estimates = await settlement_hook.estimate_settlement_cost(
job_id=request.job_id,
target_chain_id=request.target_chain_id,
bridge_name=request.bridge_name
)
# Calculate totals and recommendations
valid_estimates = {
name: cost for name, cost in estimates.items()
if 'error' not in cost
}
if not valid_estimates:
raise HTTPException(
status_code=400,
detail="No bridges available for this settlement"
)
# Find cheapest option
cheapest_bridge = min(valid_estimates.items(), key=lambda x: x[1]['total'])
# Calculate statistics
costs = [est['total'] for est in valid_estimates.values()]
total_estimates = {
"min": min(costs),
"max": max(costs),
"average": sum(costs) / len(costs)
}
return CostEstimateResponse(
bridge_costs=estimates,
recommended_bridge=cheapest_bridge[0],
total_estimates=total_estimates
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Estimation failed: {str(e)}")
@router.post("/batch", response_model=List[SettlementResponse])
async def batch_settle(
request: BatchSettlementRequest,
background_tasks: BackgroundTasks,
settlement_hook: SettlementHook = Depends(get_settlement_hook)
):
"""Settle multiple jobs in a batch"""
try:
# Validate all jobs exist and are completed
jobs = []
for job_id in request.job_ids:
job = await Job.get(job_id)
if not job:
raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
if not job.completed:
raise HTTPException(
status_code=400,
detail=f"Job {job_id} is not completed"
)
jobs.append(job)
# Process batch settlement
results = []
for job in jobs:
try:
result = await settlement_hook.initiate_manual_settlement(
job_id=job.id,
target_chain_id=request.target_chain_id,
bridge_name=request.bridge_name
)
# Add monitoring task
background_tasks.add_task(
monitor_settlement_completion,
result.message_id,
job.id
)
results.append(SettlementResponse(
message_id=result.message_id,
status=result.status.value,
transaction_hash=result.transaction_hash,
bridge_name=await get_bridge_from_tx(result.transaction_hash) if result.transaction_hash else "",
estimated_completion=estimate_completion_time(result.status),
error_message=result.error_message
))
except Exception as e:
results.append(SettlementResponse(
message_id="",
status="failed",
transaction_hash=None,
bridge_name="",
estimated_completion=None,
error_message=str(e)
))
return results
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Batch settlement failed: {str(e)}")
@router.get("/bridges", response_model=Dict[str, Any])
async def list_supported_bridges(
settlement_hook: SettlementHook = Depends(get_settlement_hook)
):
"""List all supported bridges and their capabilities"""
try:
return await settlement_hook.list_supported_bridges()
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to list bridges: {str(e)}")
@router.get("/chains", response_model=Dict[str, List[int]])
async def list_supported_chains(
settlement_hook: SettlementHook = Depends(get_settlement_hook)
):
"""List all supported chains by bridge"""
try:
return await settlement_hook.list_supported_chains()
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to list chains: {str(e)}")
@router.post("/{message_id}/refund")
async def refund_settlement(
message_id: str,
bridge_manager: BridgeManager = Depends(get_bridge_manager)
):
"""Attempt to refund a failed settlement"""
try:
result = await bridge_manager.refund_failed_settlement(message_id)
return {
"message_id": message_id,
"status": result.status.value,
"refund_transaction": result.transaction_hash,
"error_message": result.error_message
}
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Refund failed: {str(e)}")
@router.get("/job/{job_id}/settlements")
async def get_job_settlements(
job_id: str,
bridge_manager: BridgeManager = Depends(get_bridge_manager)
):
"""Get all cross-chain settlements for a job"""
try:
# Validate job exists
job = await Job.get(job_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")
# Get settlements from storage
settlements = await bridge_manager.storage.get_settlements_by_job(job_id)
return {
"job_id": job_id,
"settlements": settlements,
"total_count": len(settlements)
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to get settlements: {str(e)}")
# Helper functions
async def monitor_settlement_completion(message_id: str, job_id: str):
"""Background task to monitor settlement completion"""
settlement_hook = get_settlement_hook()
# Monitor for up to 1 hour
max_wait = 3600
start_time = asyncio.get_event_loop().time()
while asyncio.get_event_loop().time() - start_time < max_wait:
result = await settlement_hook.get_settlement_status(message_id)
# Update job status
job = await Job.get(job_id)
if job:
job.cross_chain_settlement_status = result.status.value
await job.save()
# If completed or failed, stop monitoring
if result.status.value in ['completed', 'failed']:
break
# Wait before checking again
await asyncio.sleep(30)
def estimate_completion_time(status) -> Optional[str]:
"""Estimate completion time based on status"""
if status.value == 'completed':
return None
elif status.value == 'pending':
return "5-10 minutes"
elif status.value == 'in_progress':
return "2-5 minutes"
else:
return None
async def get_bridge_from_tx(tx_hash: str) -> str:
"""Get bridge name from transaction hash"""
# This would look up the bridge from the transaction
# For now, return placeholder
return "layerzero"
async def get_bridge_from_job(job_id: str) -> str:
"""Get bridge name from job"""
# This would look up the bridge from the job
# For now, return placeholder
return "layerzero"
async def get_job_id_from_settlement(message_id: str) -> Optional[str]:
"""Get job ID from settlement message ID"""
# This would look up the job ID from storage
# For now, return None
return None
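# --- Illustrative client usage (assumption, not part of this commit) ---
# A hedged sketch of calling the cross-chain settlement endpoint above with httpx.
# The base URL, job ID, and chain ID are placeholders for illustration only.
#
#   import httpx
#
#   async def settle_job_example() -> None:
#       async with httpx.AsyncClient(base_url="http://localhost:8000/v1") as client:
#           resp = await client.post(
#               "/settlement/cross-chain",
#               json={
#                   "job_id": "job-123",
#                   "target_chain_id": 137,
#                   "priority": "cost",
#                   "use_zk_proof": False,
#               },
#           )
#           resp.raise_for_status()
#           body = resp.json()
#           print(body["message_id"], body["status"])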

View File

@ -0,0 +1,21 @@
"""
Cross-chain settlement module for AITBC
"""
from .manager import BridgeManager
from .hooks import SettlementHook, BatchSettlementHook, SettlementMonitor
from .storage import SettlementStorage, InMemorySettlementStorage
from .bridges.base import BridgeAdapter, BridgeConfig, SettlementMessage, SettlementResult
__all__ = [
"BridgeManager",
"SettlementHook",
"BatchSettlementHook",
"SettlementMonitor",
"SettlementStorage",
"InMemorySettlementStorage",
"BridgeAdapter",
"BridgeConfig",
"SettlementMessage",
"SettlementResult",
]

View File

@ -0,0 +1,23 @@
"""
Bridge adapters for cross-chain settlements
"""
from .base import (
BridgeAdapter,
BridgeConfig,
SettlementMessage,
SettlementResult,
BridgeStatus,
BridgeError
)
from .layerzero import LayerZeroAdapter
__all__ = [
"BridgeAdapter",
"BridgeConfig",
"SettlementMessage",
"SettlementResult",
"BridgeStatus",
"BridgeError",
"LayerZeroAdapter",
]

View File

@ -0,0 +1,172 @@
"""
Base interfaces for cross-chain settlement bridges
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
from enum import Enum
import json
from datetime import datetime
class BridgeStatus(Enum):
"""Bridge operation status"""
PENDING = "pending"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
FAILED = "failed"
REFUNDED = "refunded"
@dataclass
class BridgeConfig:
"""Bridge configuration"""
name: str
enabled: bool
endpoint_address: str
supported_chains: List[int]
default_fee: str
max_message_size: int
timeout: int = 3600
@dataclass
class SettlementMessage:
"""Message to be settled across chains"""
source_chain_id: int
target_chain_id: int
job_id: str
receipt_hash: str
proof_data: Dict[str, Any]
payment_amount: int
payment_token: str
nonce: int
signature: str
# Optional privacy extras populated by SettlementHook when ZK proofs are requested
zk_proof: Optional[Dict[str, Any]] = None
privacy_level: Optional[str] = None
gas_limit: Optional[int] = None
created_at: Optional[datetime] = None
def __post_init__(self):
if self.created_at is None:
self.created_at = datetime.utcnow()
@dataclass
class SettlementResult:
"""Result of settlement operation"""
message_id: str
status: BridgeStatus
transaction_hash: Optional[str] = None
error_message: Optional[str] = None
gas_used: Optional[int] = None
fee_paid: Optional[int] = None
created_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
def __post_init__(self):
if self.created_at is None:
self.created_at = datetime.utcnow()
class BridgeAdapter(ABC):
"""Abstract interface for bridge adapters"""
def __init__(self, config: BridgeConfig):
self.config = config
self.name = config.name
@abstractmethod
async def initialize(self) -> None:
"""Initialize the bridge adapter"""
pass
@abstractmethod
async def send_message(self, message: SettlementMessage) -> SettlementResult:
"""Send message to target chain"""
pass
@abstractmethod
async def verify_delivery(self, message_id: str) -> bool:
"""Verify message was delivered"""
pass
@abstractmethod
async def get_message_status(self, message_id: str) -> SettlementResult:
"""Get current status of message"""
pass
@abstractmethod
async def estimate_cost(self, message: SettlementMessage) -> Dict[str, int]:
"""Estimate bridge fees"""
pass
@abstractmethod
async def refund_failed_message(self, message_id: str) -> SettlementResult:
"""Refund failed message if supported"""
pass
def get_supported_chains(self) -> List[int]:
"""Get list of supported target chains"""
return self.config.supported_chains
def get_max_message_size(self) -> int:
"""Get maximum message size in bytes"""
return self.config.max_message_size
async def validate_message(self, message: SettlementMessage) -> bool:
"""Validate message before sending"""
# Check if target chain is supported
if message.target_chain_id not in self.get_supported_chains():
raise ValueError(f"Chain {message.target_chain_id} not supported")
# Check message size
message_size = len(json.dumps(message.proof_data).encode())
if message_size > self.get_max_message_size():
raise ValueError(f"Message too large: {message_size} > {self.get_max_message_size()}")
# Validate signature
if not await self._verify_signature(message):
raise ValueError("Invalid signature")
return True
async def _verify_signature(self, message: SettlementMessage) -> bool:
"""Verify message signature - to be implemented by subclass"""
# This would verify the cryptographic signature
# Implementation depends on the signature scheme used
return True
def _encode_payload(self, message: SettlementMessage) -> bytes:
"""Encode message payload - to be implemented by subclass"""
# Each bridge may have different encoding requirements
raise NotImplementedError("Subclass must implement _encode_payload")
async def _get_gas_estimate(self, message: SettlementMessage) -> int:
"""Get gas estimate for message - to be implemented by subclass"""
# Each bridge has different gas requirements
raise NotImplementedError("Subclass must implement _get_gas_estimate")
class BridgeError(Exception):
"""Base exception for bridge errors"""
pass
class BridgeNotSupportedError(BridgeError):
"""Raised when operation is not supported by bridge"""
pass
class BridgeTimeoutError(BridgeError):
"""Raised when bridge operation times out"""
pass
class BridgeInsufficientFundsError(BridgeError):
"""Raised when insufficient funds for bridge operation"""
pass
class BridgeMessageTooLargeError(BridgeError):
"""Raised when message exceeds bridge limits"""
pass
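# --- Illustrative adapter sketch (assumption, not part of this commit) ---
# A minimal concrete adapter built on the interface above; the "no-op" bridge is
# hypothetical and simply reports every message as delivered, which can be handy
# for exercising BridgeManager locally.
#
#   class NoOpBridgeAdapter(BridgeAdapter):
#       """Toy adapter that settles every message immediately."""
#
#       async def initialize(self) -> None:
#           pass
#
#       async def send_message(self, message: SettlementMessage) -> SettlementResult:
#           await self.validate_message(message)
#           return SettlementResult(
#               message_id=f"noop-{message.nonce}",
#               status=BridgeStatus.COMPLETED,
#           )
#
#       async def verify_delivery(self, message_id: str) -> bool:
#           return True
#
#       async def get_message_status(self, message_id: str) -> SettlementResult:
#           return SettlementResult(message_id=message_id, status=BridgeStatus.COMPLETED)
#
#       async def estimate_cost(self, message: SettlementMessage) -> Dict[str, int]:
#           return {"total": 0}
#
#       async def refund_failed_message(self, message_id: str) -> SettlementResult:
#           raise BridgeNotSupportedError("NoOp bridge has nothing to refund")
#
#       def _encode_payload(self, message: SettlementMessage) -> bytes:
#           return json.dumps(message.proof_data).encode()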

View File

@ -0,0 +1,288 @@
"""
LayerZero bridge adapter implementation
"""
from typing import Dict, Any, List, Optional
import json
import asyncio
from web3 import Web3
from web3.contract import Contract
from eth_utils import to_checksum_address, encode_hex
from .base import (
BridgeAdapter,
BridgeConfig,
SettlementMessage,
SettlementResult,
BridgeStatus,
BridgeError,
BridgeNotSupportedError,
BridgeTimeoutError,
BridgeInsufficientFundsError
)
class LayerZeroAdapter(BridgeAdapter):
"""LayerZero bridge adapter for cross-chain settlements"""
# LayerZero chain IDs
CHAIN_IDS = {
1: 101, # Ethereum
137: 109, # Polygon
56: 102, # BSC
42161: 110, # Arbitrum
10: 111, # Optimism
43114: 106 # Avalanche
}
def __init__(self, config: BridgeConfig, web3: Web3):
super().__init__(config)
self.web3 = web3
self.endpoint: Optional[Contract] = None
self.ultra_light_node: Optional[Contract] = None
async def initialize(self) -> None:
"""Initialize LayerZero contracts"""
# Load LayerZero endpoint ABI
endpoint_abi = await self._load_abi("LayerZeroEndpoint")
self.endpoint = self.web3.eth.contract(
address=to_checksum_address(self.config.endpoint_address),
abi=endpoint_abi
)
# Load Ultra Light Node ABI for fee estimation
uln_abi = await self._load_abi("UltraLightNode")
uln_address = await self.endpoint.functions.ultraLightNode().call()
self.ultra_light_node = self.web3.eth.contract(
address=to_checksum_address(uln_address),
abi=uln_abi
)
async def send_message(self, message: SettlementMessage) -> SettlementResult:
"""Send message via LayerZero"""
try:
# Validate message
await self.validate_message(message)
# Get target address on destination chain
target_address = await self._get_target_address(message.target_chain_id)
# Encode payload
payload = self._encode_payload(message)
# Estimate fees
fees = await self.estimate_cost(message)
# Get gas limit
gas_limit = message.gas_limit or await self._get_gas_estimate(message)
# Build transaction
tx_params = {
'from': await self._get_signer_address(),
'gas': gas_limit,
'value': fees['layerZeroFee'],
'nonce': await self.web3.eth.get_transaction_count(
await self._get_signer_address()
)
}
# Send transaction
tx_hash = await self.endpoint.functions.send(
self.CHAIN_IDS[message.target_chain_id], # dstChainId
target_address, # destination address
payload, # payload
message.payment_amount, # value (optional)
[0, 0, 0], # address and parameters for adapterParams
message.nonce # nonce (placeholder; a production call would pass a refund address here)
).transact(tx_params)
# Wait for confirmation
receipt = await self.web3.eth.wait_for_transaction_receipt(tx_hash)
return SettlementResult(
message_id=tx_hash.hex(),
status=BridgeStatus.IN_PROGRESS,
transaction_hash=tx_hash.hex(),
gas_used=receipt.gasUsed,
fee_paid=fees['layerZeroFee']
)
except Exception as e:
return SettlementResult(
message_id="",
status=BridgeStatus.FAILED,
error_message=str(e)
)
async def verify_delivery(self, message_id: str) -> bool:
"""Verify message was delivered"""
try:
# Get transaction receipt
receipt = await self.web3.eth.get_transaction_receipt(message_id)
# Check for Delivered event
delivered_logs = self.endpoint.events.Delivered().processReceipt(receipt)
return len(delivered_logs) > 0
except Exception:
return False
async def get_message_status(self, message_id: str) -> SettlementResult:
"""Get current status of message"""
try:
# Get transaction receipt
receipt = await self.web3.eth.get_transaction_receipt(message_id)
if receipt.status == 0:
return SettlementResult(
message_id=message_id,
status=BridgeStatus.FAILED,
transaction_hash=message_id,
completed_at=receipt['blockTimestamp']
)
# Check if delivered
if await self.verify_delivery(message_id):
return SettlementResult(
message_id=message_id,
status=BridgeStatus.COMPLETED,
transaction_hash=message_id,
completed_at=receipt['blockTimestamp']
)
# Still in progress
return SettlementResult(
message_id=message_id,
status=BridgeStatus.IN_PROGRESS,
transaction_hash=message_id
)
except Exception as e:
return SettlementResult(
message_id=message_id,
status=BridgeStatus.FAILED,
error_message=str(e)
)
async def estimate_cost(self, message: SettlementMessage) -> Dict[str, int]:
"""Estimate LayerZero fees"""
try:
# Get destination chain ID
dst_chain_id = self.CHAIN_IDS[message.target_chain_id]
# Get target address
target_address = await self._get_target_address(message.target_chain_id)
# Encode payload
payload = self._encode_payload(message)
# Estimate fee using LayerZero endpoint
(native_fee, zro_fee) = await self.endpoint.functions.estimateFees(
dst_chain_id,
target_address,
payload,
False, # payInZRO
[0, 0, 0] # adapterParams
).call()
return {
'layerZeroFee': native_fee,
'zroFee': zro_fee,
'total': native_fee + zro_fee
}
except Exception as e:
raise BridgeError(f"Failed to estimate fees: {str(e)}")
async def refund_failed_message(self, message_id: str) -> SettlementResult:
"""LayerZero doesn't support direct refunds"""
raise BridgeNotSupportedError("LayerZero does not support message refunds")
def _encode_payload(self, message: SettlementMessage) -> bytes:
"""Encode settlement message for LayerZero"""
# Use ABI encoding for structured data via the adapter's web3 codec
# Define the payload structure
payload_types = [
'uint256', # job_id
'bytes32', # receipt_hash
'bytes', # proof_data (JSON)
'uint256', # payment_amount
'address', # payment_token
'uint256', # nonce
'bytes' # signature
]
payload_values = [
int(message.job_id),
bytes.fromhex(message.receipt_hash),
json.dumps(message.proof_data).encode(),
message.payment_amount,
to_checksum_address(message.payment_token),
message.nonce,
bytes.fromhex(message.signature)
]
# Encode the payload
encoded = self.web3.codec.encode(payload_types, payload_values)
return encoded
async def _get_target_address(self, target_chain_id: int) -> str:
"""Get target contract address on destination chain"""
# This would look up the target address from configuration
# For now, return a placeholder
target_addresses = {
1: "0x...", # Ethereum
137: "0x...", # Polygon
56: "0x...", # BSC
42161: "0x..." # Arbitrum
}
if target_chain_id not in target_addresses:
raise ValueError(f"No target address configured for chain {target_chain_id}")
return target_addresses[target_chain_id]
async def _get_gas_estimate(self, message: SettlementMessage) -> int:
"""Estimate gas for LayerZero transaction"""
try:
# Get target address
target_address = await self._get_target_address(message.target_chain_id)
# Encode payload
payload = self._encode_payload(message)
# Estimate gas
gas_estimate = await self.endpoint.functions.send(
self.CHAIN_IDS[message.target_chain_id],
target_address,
payload,
message.payment_amount,
[0, 0, 0],
message.nonce
).estimateGas({'from': await self._get_signer_address()})
# Add 20% buffer
return int(gas_estimate * 1.2)
except Exception:
# Return default estimate
return 300000
async def _get_signer_address(self) -> str:
"""Get the signer address for transactions"""
# This would get the address from the wallet/key management system
# For now, return a placeholder
return "0x..."
async def _load_abi(self, contract_name: str) -> List[Dict]:
"""Load contract ABI from file or registry"""
# This would load the ABI from a file or contract registry
# For now, return empty list
return []
async def _verify_signature(self, message: SettlementMessage) -> bool:
"""Verify LayerZero message signature"""
# Implement signature verification specific to LayerZero
# This would verify the message signature using the appropriate scheme
return True

View File

@ -0,0 +1,327 @@
"""
Settlement hooks for coordinator API integration
"""
from typing import Dict, Any, Optional, List
from datetime import datetime
import asyncio
import logging
from .manager import BridgeManager
from .bridges.base import (
SettlementMessage,
SettlementResult,
BridgeStatus
)
from ..models.job import Job
from ..models.receipt import Receipt
logger = logging.getLogger(__name__)
class SettlementHook:
"""Settlement hook for coordinator to handle cross-chain settlements"""
def __init__(self, bridge_manager: BridgeManager):
self.bridge_manager = bridge_manager
self._enabled = True
async def on_job_completed(self, job: Job) -> None:
"""Called when a job completes successfully"""
if not self._enabled:
return
try:
# Check if cross-chain settlement is required
if await self._requires_cross_chain_settlement(job):
await self._initiate_settlement(job)
except Exception as e:
logger.error(f"Failed to handle job completion for {job.id}: {e}")
# Don't fail the job, just log the error
await self._handle_settlement_error(job, e)
async def on_job_failed(self, job: Job, error: Exception) -> None:
"""Called when a job fails"""
# For failed jobs, we might want to refund any cross-chain payments
if job.cross_chain_payment_id:
try:
await self._refund_cross_chain_payment(job)
except Exception as e:
logger.error(f"Failed to refund cross-chain payment for {job.id}: {e}")
async def initiate_manual_settlement(
self,
job_id: str,
target_chain_id: int,
bridge_name: Optional[str] = None,
options: Optional[Dict[str, Any]] = None
) -> SettlementResult:
"""Manually initiate cross-chain settlement for a job"""
# Get job
job = await Job.get(job_id)
if not job:
raise ValueError(f"Job {job_id} not found")
if not job.completed:
raise ValueError(f"Job {job_id} is not completed")
# Override target chain if specified
if target_chain_id:
job.target_chain = target_chain_id
# Create settlement message
message = await self._create_settlement_message(job, options)
# Send settlement
result = await self.bridge_manager.settle_cross_chain(
message,
bridge_name=bridge_name
)
# Update job with settlement info
job.cross_chain_settlement_id = result.message_id
job.cross_chain_bridge = bridge_name or self.bridge_manager.default_adapter
await job.save()
return result
async def get_settlement_status(self, settlement_id: str) -> SettlementResult:
"""Get status of a cross-chain settlement"""
return await self.bridge_manager.get_settlement_status(settlement_id)
async def estimate_settlement_cost(
self,
job_id: str,
target_chain_id: int,
bridge_name: Optional[str] = None
) -> Dict[str, Any]:
"""Estimate cost for cross-chain settlement"""
# Get job
job = await Job.get(job_id)
if not job:
raise ValueError(f"Job {job_id} not found")
# Create mock settlement message for estimation
message = SettlementMessage(
source_chain_id=await self._get_current_chain_id(),
target_chain_id=target_chain_id,
job_id=job.id,
receipt_hash=job.receipt.hash if job.receipt else "",
proof_data=job.receipt.proof if job.receipt else {},
payment_amount=job.payment_amount or 0,
payment_token=job.payment_token or "AITBC",
nonce=await self._generate_nonce(),
signature="" # Not needed for estimation
)
return await self.bridge_manager.estimate_settlement_cost(
message,
bridge_name=bridge_name
)
async def list_supported_bridges(self) -> Dict[str, Any]:
"""List all supported bridges and their capabilities"""
return self.bridge_manager.get_bridge_info()
async def list_supported_chains(self) -> Dict[str, List[int]]:
"""List all supported chains by bridge"""
return self.bridge_manager.get_supported_chains()
async def enable(self) -> None:
"""Enable settlement hooks"""
self._enabled = True
logger.info("Settlement hooks enabled")
async def disable(self) -> None:
"""Disable settlement hooks"""
self._enabled = False
logger.info("Settlement hooks disabled")
async def _requires_cross_chain_settlement(self, job: Job) -> bool:
"""Check if job requires cross-chain settlement"""
# Check if job has target chain different from current
if job.target_chain and job.target_chain != await self._get_current_chain_id():
return True
# Check if job explicitly requests cross-chain settlement
if job.requires_cross_chain_settlement:
return True
# Check if payment is on different chain
if job.payment_chain and job.payment_chain != await self._get_current_chain_id():
return True
return False
async def _initiate_settlement(self, job: Job) -> None:
"""Initiate cross-chain settlement for a job"""
try:
# Create settlement message
message = await self._create_settlement_message(job)
# Get optimal bridge if not specified
bridge_name = job.preferred_bridge or await self.bridge_manager.get_optimal_bridge(
message,
priority=job.settlement_priority or 'cost'
)
# Send settlement
result = await self.bridge_manager.settle_cross_chain(
message,
bridge_name=bridge_name
)
# Update job with settlement info
job.cross_chain_settlement_id = result.message_id
job.cross_chain_bridge = bridge_name
job.cross_chain_settlement_status = result.status.value
await job.save()
logger.info(f"Initiated cross-chain settlement for job {job.id}: {result.message_id}")
except Exception as e:
logger.error(f"Failed to initiate settlement for job {job.id}: {e}")
await self._handle_settlement_error(job, e)
async def _create_settlement_message(self, job: Job, options: Optional[Dict[str, Any]] = None) -> SettlementMessage:
"""Create settlement message from job"""
# Get current chain ID
source_chain_id = await self._get_current_chain_id()
# Get receipt data
receipt_hash = ""
proof_data = {}
zk_proof = None
if job.receipt:
receipt_hash = job.receipt.hash
proof_data = job.receipt.proof or {}
# Check if ZK proof is included in receipt
if options and options.get("use_zk_proof"):
zk_proof = job.receipt.payload.get("zk_proof")
if not zk_proof:
logger.warning(f"ZK proof requested but not found in receipt for job {job.id}")
# Sign the settlement message
signature = await self._sign_settlement_message(job)
return SettlementMessage(
source_chain_id=source_chain_id,
target_chain_id=job.target_chain or source_chain_id,
job_id=job.id,
receipt_hash=receipt_hash,
proof_data=proof_data,
zk_proof=zk_proof,
payment_amount=job.payment_amount or 0,
payment_token=job.payment_token or "AITBC",
nonce=await self._generate_nonce(),
signature=signature,
gas_limit=job.settlement_gas_limit,
privacy_level=options.get("privacy_level") if options else None
)
async def _get_current_chain_id(self) -> int:
"""Get the current blockchain chain ID"""
# This would get the chain ID from the blockchain node
# For now, return a placeholder
return 1 # Ethereum mainnet
async def _generate_nonce(self) -> int:
"""Generate a unique nonce for settlement"""
# This would generate a unique nonce
# For now, use timestamp
return int(datetime.utcnow().timestamp())
async def _sign_settlement_message(self, job: Job) -> str:
"""Sign the settlement message"""
# This would sign the message with the appropriate key
# For now, return a placeholder
return "0x..." * 20
async def _handle_settlement_error(self, job: Job, error: Exception) -> None:
"""Handle settlement errors"""
# Update job with error info
job.cross_chain_settlement_error = str(error)
job.cross_chain_settlement_status = BridgeStatus.FAILED.value
await job.save()
# Notify monitoring system
await self._notify_settlement_failure(job, error)
async def _refund_cross_chain_payment(self, job: Job) -> None:
"""Refund a cross-chain payment if possible"""
if not job.cross_chain_payment_id:
return
try:
result = await self.bridge_manager.refund_failed_settlement(
job.cross_chain_payment_id
)
# Update job with refund info
job.cross_chain_refund_id = result.message_id
job.cross_chain_refund_status = result.status.value
await job.save()
except Exception as e:
logger.error(f"Failed to refund cross-chain payment for {job.id}: {e}")
async def _notify_settlement_failure(self, job: Job, error: Exception) -> None:
"""Notify monitoring system of settlement failure"""
# This would send alerts to the monitoring system
logger.error(f"Settlement failure for job {job.id}: {error}")
class BatchSettlementHook:
"""Hook for handling batch settlements"""
def __init__(self, bridge_manager: BridgeManager):
self.bridge_manager = bridge_manager
self.batch_size = 10
self.batch_timeout = 300 # 5 minutes
async def add_to_batch(self, job: Job) -> None:
"""Add job to batch settlement queue"""
# This would add the job to a batch queue
pass
async def process_batch(self) -> List[SettlementResult]:
"""Process a batch of settlements"""
# This would process queued jobs in batches
# For now, return empty list
return []
class SettlementMonitor:
"""Monitor for cross-chain settlements"""
def __init__(self, bridge_manager: BridgeManager):
self.bridge_manager = bridge_manager
self._monitoring = False
async def start_monitoring(self) -> None:
"""Start monitoring settlements"""
self._monitoring = True
while self._monitoring:
try:
# Get pending settlements
pending = await self.bridge_manager.storage.get_pending_settlements()
# Check status of each
for settlement in pending:
await self.bridge_manager.get_settlement_status(
settlement['message_id']
)
# Wait before next check
await asyncio.sleep(30)
except Exception as e:
logger.error(f"Error in settlement monitoring: {e}")
await asyncio.sleep(60)
async def stop_monitoring(self) -> None:
"""Stop monitoring settlements"""
self._monitoring = False
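# --- Illustrative integration sketch (assumption, not part of this commit) ---
# How the coordinator's job pipeline might invoke the hook once a job finishes;
# `complete_job` and its surroundings are hypothetical names used for illustration.
#
#   async def complete_job(job: Job, settlement_hook: SettlementHook) -> None:
#       job.completed = True
#       await job.save()
#       # Fire-and-forget settlement; errors are logged inside the hook and never
#       # fail the job itself (see on_job_completed above).
#       await settlement_hook.on_job_completed(job)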

View File

@ -0,0 +1,337 @@
"""
Bridge manager for cross-chain settlements
"""
from typing import Dict, Any, List, Optional, Type
import asyncio
import json
from datetime import datetime, timedelta
from dataclasses import asdict
from .bridges.base import (
BridgeAdapter,
BridgeConfig,
SettlementMessage,
SettlementResult,
BridgeStatus,
BridgeError
)
from .bridges.layerzero import LayerZeroAdapter
from .storage import SettlementStorage
class BridgeManager:
"""Manages multiple bridge adapters for cross-chain settlements"""
def __init__(self, storage: SettlementStorage):
self.adapters: Dict[str, BridgeAdapter] = {}
self.default_adapter: Optional[str] = None
self.storage = storage
self._initialized = False
async def initialize(self, configs: Dict[str, BridgeConfig]) -> None:
"""Initialize all bridge adapters"""
for name, config in configs.items():
if config.enabled:
adapter = await self._create_adapter(config)
await adapter.initialize()
self.adapters[name] = adapter
# Set first enabled adapter as default
if self.default_adapter is None:
self.default_adapter = name
self._initialized = True
async def register_adapter(self, name: str, adapter: BridgeAdapter) -> None:
"""Register a bridge adapter"""
await adapter.initialize()
self.adapters[name] = adapter
if self.default_adapter is None:
self.default_adapter = name
async def settle_cross_chain(
self,
message: SettlementMessage,
bridge_name: Optional[str] = None,
retry_on_failure: bool = True
) -> SettlementResult:
"""Settle message across chains"""
if not self._initialized:
raise BridgeError("Bridge manager not initialized")
# Get adapter
adapter = self._get_adapter(bridge_name)
# Validate message
await adapter.validate_message(message)
# Store initial settlement record
await self.storage.store_settlement(
message_id="pending",
message=message,
bridge_name=adapter.name,
status=BridgeStatus.PENDING
)
# Attempt settlement with retries
max_retries = 3 if retry_on_failure else 1
last_error = None
for attempt in range(max_retries):
try:
# Send message
result = await adapter.send_message(message)
# Update storage with result
await self.storage.update_settlement(
message_id=result.message_id,
status=result.status,
transaction_hash=result.transaction_hash,
error_message=result.error_message
)
# Start monitoring for completion
asyncio.create_task(self._monitor_settlement(result.message_id))
return result
except Exception as e:
last_error = e
if attempt < max_retries - 1:
# Wait before retry
await asyncio.sleep(2 ** attempt) # Exponential backoff
continue
else:
# Final attempt failed
result = SettlementResult(
message_id="",
status=BridgeStatus.FAILED,
error_message=str(e)
)
await self.storage.update_settlement(
message_id="",
status=BridgeStatus.FAILED,
error_message=str(e)
)
return result
async def get_settlement_status(self, message_id: str) -> SettlementResult:
"""Get current status of settlement"""
# Get from storage first
stored = await self.storage.get_settlement(message_id)
if not stored:
raise ValueError(f"Settlement {message_id} not found")
# If completed or failed, return the stored result (status is stored as a string)
if stored['status'] in [BridgeStatus.COMPLETED.value, BridgeStatus.FAILED.value]:
return SettlementResult(
message_id=stored['message_id'],
status=BridgeStatus(stored['status']),
transaction_hash=stored.get('transaction_hash'),
error_message=stored.get('error_message'),
completed_at=stored.get('completed_at')
)
# Otherwise check with bridge
adapter = self.adapters.get(stored['bridge_name'])
if not adapter:
raise BridgeError(f"Bridge {stored['bridge_name']} not found")
# Get current status from bridge
result = await adapter.get_message_status(message_id)
# Update storage if status changed
if result.status.value != stored['status']:
await self.storage.update_settlement(
message_id=message_id,
status=result.status,
completed_at=result.completed_at
)
return result
async def estimate_settlement_cost(
self,
message: SettlementMessage,
bridge_name: Optional[str] = None
) -> Dict[str, Any]:
"""Estimate cost for settlement across different bridges"""
results = {}
if bridge_name:
# Estimate for specific bridge
adapter = self._get_adapter(bridge_name)
results[bridge_name] = await adapter.estimate_cost(message)
else:
# Estimate for all bridges
for name, adapter in self.adapters.items():
try:
await adapter.validate_message(message)
results[name] = await adapter.estimate_cost(message)
except Exception as e:
results[name] = {'error': str(e)}
return results
async def get_optimal_bridge(
self,
message: SettlementMessage,
priority: str = 'cost' # 'cost' or 'speed'
) -> str:
"""Get optimal bridge for settlement"""
if len(self.adapters) == 1:
return list(self.adapters.keys())[0]
# Get estimates for all bridges
estimates = await self.estimate_settlement_cost(message)
# Filter out failed estimates
valid_estimates = {
name: est for name, est in estimates.items()
if 'error' not in est
}
if not valid_estimates:
raise BridgeError("No bridges available for settlement")
# Select based on priority
if priority == 'cost':
# Select cheapest
optimal = min(valid_estimates.items(), key=lambda x: x[1]['total'])
else:
# Select fastest (based on historical data); fall back to the default
# bridge if it produced a valid estimate, otherwise take the cheapest
if self.default_adapter in valid_estimates:
optimal = (self.default_adapter, valid_estimates[self.default_adapter])
else:
optimal = min(valid_estimates.items(), key=lambda x: x[1]['total'])
return optimal[0]
async def batch_settle(
self,
messages: List[SettlementMessage],
bridge_name: Optional[str] = None
) -> List[SettlementResult]:
"""Settle multiple messages"""
results = []
# Process in parallel with rate limiting
semaphore = asyncio.Semaphore(5) # Max 5 concurrent settlements
async def settle_single(message):
async with semaphore:
return await self.settle_cross_chain(message, bridge_name)
tasks = [settle_single(msg) for msg in messages]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Convert exceptions to failed results
processed_results = []
for result in results:
if isinstance(result, Exception):
processed_results.append(SettlementResult(
message_id="",
status=BridgeStatus.FAILED,
error_message=str(result)
))
else:
processed_results.append(result)
return processed_results
async def refund_failed_settlement(self, message_id: str) -> SettlementResult:
"""Attempt to refund a failed settlement"""
# Get settlement details
stored = await self.storage.get_settlement(message_id)
if not stored:
raise ValueError(f"Settlement {message_id} not found")
# Check if it's actually failed
if stored['status'] != BridgeStatus.FAILED.value:
raise ValueError(f"Settlement {message_id} is not in failed state")
# Get adapter
adapter = self.adapters.get(stored['bridge_name'])
if not adapter:
raise BridgeError(f"Bridge {stored['bridge_name']} not found")
# Attempt refund
result = await adapter.refund_failed_message(message_id)
# Update storage
await self.storage.update_settlement(
message_id=message_id,
status=result.status,
error_message=result.error_message
)
return result
def get_supported_chains(self) -> Dict[str, List[int]]:
"""Get all supported chains by bridge"""
chains = {}
for name, adapter in self.adapters.items():
chains[name] = adapter.get_supported_chains()
return chains
def get_bridge_info(self) -> Dict[str, Dict[str, Any]]:
"""Get information about all bridges"""
info = {}
for name, adapter in self.adapters.items():
info[name] = {
'name': adapter.name,
'supported_chains': adapter.get_supported_chains(),
'max_message_size': adapter.get_max_message_size(),
'config': asdict(adapter.config)
}
return info
async def _monitor_settlement(self, message_id: str) -> None:
"""Monitor settlement until completion"""
max_wait_time = timedelta(hours=1)
start_time = datetime.utcnow()
while datetime.utcnow() - start_time < max_wait_time:
# Check status
result = await self.get_settlement_status(message_id)
# If completed or failed, stop monitoring
if result.status in [BridgeStatus.COMPLETED, BridgeStatus.FAILED]:
break
# Wait before checking again
await asyncio.sleep(30) # Check every 30 seconds
# If still pending or in progress after timeout, mark as failed
if result.status in (BridgeStatus.PENDING, BridgeStatus.IN_PROGRESS):
await self.storage.update_settlement(
message_id=message_id,
status=BridgeStatus.FAILED,
error_message="Settlement timed out"
)
def _get_adapter(self, bridge_name: Optional[str] = None) -> BridgeAdapter:
"""Get bridge adapter"""
if bridge_name:
if bridge_name not in self.adapters:
raise BridgeError(f"Bridge {bridge_name} not found")
return self.adapters[bridge_name]
if self.default_adapter is None:
raise BridgeError("No default bridge configured")
return self.adapters[self.default_adapter]
async def _create_adapter(self, config: BridgeConfig) -> BridgeAdapter:
"""Create adapter instance based on config"""
# Import web3 here to avoid circular imports
from web3 import Web3
# Get web3 instance (this would be injected or configured)
web3 = Web3() # Placeholder
if config.name == "layerzero":
return LayerZeroAdapter(config, web3)
# Add other adapters as they're implemented
# elif config.name == "chainlink_ccip":
# return ChainlinkCCIPAdapter(config, web3)
else:
raise BridgeError(f"Unknown bridge type: {config.name}")

View File

@ -0,0 +1,367 @@
"""
Storage layer for cross-chain settlements
"""
from typing import Dict, Any, Optional, List
from datetime import datetime, timedelta
import json
import asyncio
from dataclasses import asdict
from .bridges.base import (
SettlementMessage,
SettlementResult,
BridgeStatus
)
class SettlementStorage:
"""Storage interface for settlement data"""
def __init__(self, db_connection):
self.db = db_connection
async def store_settlement(
self,
message_id: str,
message: SettlementMessage,
bridge_name: str,
status: BridgeStatus
) -> None:
"""Store a new settlement record"""
query = """
INSERT INTO settlements (
message_id, job_id, source_chain_id, target_chain_id,
receipt_hash, proof_data, payment_amount, payment_token,
nonce, signature, bridge_name, status, created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13
)
"""
await self.db.execute(query, (
message_id,
message.job_id,
message.source_chain_id,
message.target_chain_id,
message.receipt_hash,
json.dumps(message.proof_data),
message.payment_amount,
message.payment_token,
message.nonce,
message.signature,
bridge_name,
status.value,
message.created_at or datetime.utcnow()
))
async def update_settlement(
self,
message_id: str,
status: Optional[BridgeStatus] = None,
transaction_hash: Optional[str] = None,
error_message: Optional[str] = None,
completed_at: Optional[datetime] = None
) -> None:
"""Update settlement record"""
updates = []
params = []
param_count = 1
if status is not None:
updates.append(f"status = ${param_count}")
params.append(status.value)
param_count += 1
if transaction_hash is not None:
updates.append(f"transaction_hash = ${param_count}")
params.append(transaction_hash)
param_count += 1
if error_message is not None:
updates.append(f"error_message = ${param_count}")
params.append(error_message)
param_count += 1
if completed_at is not None:
updates.append(f"completed_at = ${param_count}")
params.append(completed_at)
param_count += 1
if not updates:
return
updates.append(f"updated_at = ${param_count}")
params.append(datetime.utcnow())
param_count += 1
params.append(message_id)
query = f"""
UPDATE settlements
SET {', '.join(updates)}
WHERE message_id = ${param_count}
"""
await self.db.execute(query, params)
async def get_settlement(self, message_id: str) -> Optional[Dict[str, Any]]:
"""Get settlement by message ID"""
query = """
SELECT * FROM settlements WHERE message_id = $1
"""
result = await self.db.fetchrow(query, message_id)
if not result:
return None
# Convert to dict
settlement = dict(result)
# Parse JSON fields
if settlement['proof_data']:
settlement['proof_data'] = json.loads(settlement['proof_data'])
return settlement
async def get_settlements_by_job(self, job_id: str) -> List[Dict[str, Any]]:
"""Get all settlements for a job"""
query = """
SELECT * FROM settlements
WHERE job_id = $1
ORDER BY created_at DESC
"""
results = await self.db.fetch(query, job_id)
settlements = []
for result in results:
settlement = dict(result)
if settlement['proof_data']:
settlement['proof_data'] = json.loads(settlement['proof_data'])
settlements.append(settlement)
return settlements
async def get_pending_settlements(self, bridge_name: Optional[str] = None) -> List[Dict[str, Any]]:
"""Get all pending settlements"""
query = """
SELECT * FROM settlements
WHERE status IN ('pending', 'in_progress')
"""
params = []
if bridge_name:
query += " AND bridge_name = $1"
params.append(bridge_name)
query += " ORDER BY created_at ASC"
results = await self.db.fetch(query, *params)
settlements = []
for result in results:
settlement = dict(result)
if settlement['proof_data']:
settlement['proof_data'] = json.loads(settlement['proof_data'])
settlements.append(settlement)
return settlements
async def get_settlement_stats(
self,
bridge_name: Optional[str] = None,
time_range: Optional[int] = None # hours
) -> Dict[str, Any]:
"""Get settlement statistics"""
conditions = []
params = []
param_count = 1
if bridge_name:
conditions.append(f"bridge_name = ${param_count}")
params.append(bridge_name)
param_count += 1
if time_range:
conditions.append(f"created_at > NOW() - INTERVAL '${param_count} hours'")
params.append(time_range)
param_count += 1
where_clause = "WHERE " + " AND ".join(conditions) if conditions else ""
query = f"""
SELECT
bridge_name,
status,
COUNT(*) as count,
AVG(payment_amount) as avg_amount,
SUM(payment_amount) as total_amount
FROM settlements
{where_clause}
GROUP BY bridge_name, status
"""
results = await self.db.fetch(query, *params)
stats = {}
for result in results:
bridge = result['bridge_name']
if bridge not in stats:
stats[bridge] = {}
stats[bridge][result['status']] = {
'count': result['count'],
'avg_amount': float(result['avg_amount']) if result['avg_amount'] else 0,
'total_amount': float(result['total_amount']) if result['total_amount'] else 0
}
return stats
async def cleanup_old_settlements(self, days: int = 30) -> int:
"""Clean up old completed settlements"""
query = """
DELETE FROM settlements
WHERE status IN ('completed', 'failed')
AND created_at < NOW() - make_interval(days => $1)
"""
result = await self.db.execute(query, days)
return int(result.split()[-1]) # Command status is a string like "DELETE <count>"
# In-memory implementation for testing
class InMemorySettlementStorage(SettlementStorage):
"""In-memory storage implementation for testing"""
def __init__(self):
self.settlements: Dict[str, Dict[str, Any]] = {}
self._lock = asyncio.Lock()
async def store_settlement(
self,
message_id: str,
message: SettlementMessage,
bridge_name: str,
status: BridgeStatus
) -> None:
async with self._lock:
self.settlements[message_id] = {
'message_id': message_id,
'job_id': message.job_id,
'source_chain_id': message.source_chain_id,
'target_chain_id': message.target_chain_id,
'receipt_hash': message.receipt_hash,
'proof_data': message.proof_data,
'payment_amount': message.payment_amount,
'payment_token': message.payment_token,
'nonce': message.nonce,
'signature': message.signature,
'bridge_name': bridge_name,
'status': status.value,
'created_at': message.created_at or datetime.utcnow(),
'updated_at': datetime.utcnow()
}
async def update_settlement(
self,
message_id: str,
status: Optional[BridgeStatus] = None,
transaction_hash: Optional[str] = None,
error_message: Optional[str] = None,
completed_at: Optional[datetime] = None
) -> None:
async with self._lock:
if message_id not in self.settlements:
return
settlement = self.settlements[message_id]
if status is not None:
settlement['status'] = status.value
if transaction_hash is not None:
settlement['transaction_hash'] = transaction_hash
if error_message is not None:
settlement['error_message'] = error_message
if completed_at is not None:
settlement['completed_at'] = completed_at
settlement['updated_at'] = datetime.utcnow()
async def get_settlement(self, message_id: str) -> Optional[Dict[str, Any]]:
async with self._lock:
return self.settlements.get(message_id)
async def get_settlements_by_job(self, job_id: str) -> List[Dict[str, Any]]:
async with self._lock:
return [
s for s in self.settlements.values()
if s['job_id'] == job_id
]
async def get_pending_settlements(self, bridge_name: Optional[str] = None) -> List[Dict[str, Any]]:
async with self._lock:
pending = [
s for s in self.settlements.values()
if s['status'] in ['pending', 'in_progress']
]
if bridge_name:
pending = [s for s in pending if s['bridge_name'] == bridge_name]
return pending
async def get_settlement_stats(
self,
bridge_name: Optional[str] = None,
time_range: Optional[int] = None
) -> Dict[str, Any]:
async with self._lock:
stats = {}
for settlement in self.settlements.values():
if bridge_name and settlement['bridge_name'] != bridge_name:
continue
# TODO: Implement time range filtering
bridge = settlement['bridge_name']
if bridge not in stats:
stats[bridge] = {}
status = settlement['status']
if status not in stats[bridge]:
stats[bridge][status] = {
'count': 0,
'avg_amount': 0,
'total_amount': 0
}
stats[bridge][status]['count'] += 1
stats[bridge][status]['total_amount'] += settlement['payment_amount']
# Calculate averages
for bridge_data in stats.values():
for status_data in bridge_data.values():
if status_data['count'] > 0:
status_data['avg_amount'] = status_data['total_amount'] / status_data['count']
return stats
async def cleanup_old_settlements(self, days: int = 30) -> int:
async with self._lock:
cutoff = datetime.utcnow() - timedelta(days=days)
to_delete = [
msg_id for msg_id, settlement in self.settlements.items()
if (
settlement['status'] in ['completed', 'failed'] and
settlement['created_at'] < cutoff
)
]
for msg_id in to_delete:
del self.settlements[msg_id]
return len(to_delete)
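# --- Illustrative test sketch (assumption, not part of this commit) ---
# Exercising the in-memory backend in a unit test (assumes pytest-asyncio is
# installed); the message values are placeholders.
#
#   import pytest
#
#   @pytest.mark.asyncio
#   async def test_store_and_update_settlement():
#       storage = InMemorySettlementStorage()
#       message = SettlementMessage(
#           source_chain_id=1, target_chain_id=137, job_id="job-1",
#           receipt_hash="", proof_data={}, payment_amount=10,
#           payment_token="AITBC", nonce=1, signature="",
#       )
#       await storage.store_settlement("msg-1", message, "layerzero", BridgeStatus.PENDING)
#       await storage.update_settlement("msg-1", status=BridgeStatus.COMPLETED)
#       stored = await storage.get_settlement("msg-1")
#       assert stored["status"] == "completed"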

View File

@ -0,0 +1,75 @@
"""Add settlements table for cross-chain settlements
Revision ID: 2024_01_10_add_settlements_table
Revises: 2024_01_05_add_receipts_table
Create Date: 2025-01-10 10:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '2024_01_10_add_settlements_table'
down_revision = '2024_01_05_add_receipts_table'
branch_labels = None
depends_on = None
def upgrade():
# Create settlements table
op.create_table(
'settlements',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('message_id', sa.String(length=255), nullable=False),
sa.Column('job_id', sa.String(length=255), nullable=False),
sa.Column('source_chain_id', sa.Integer(), nullable=False),
sa.Column('target_chain_id', sa.Integer(), nullable=False),
sa.Column('receipt_hash', sa.String(length=66), nullable=True),
sa.Column('proof_data', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('payment_amount', sa.Numeric(precision=36, scale=18), nullable=True),
sa.Column('payment_token', sa.String(length=42), nullable=True),
sa.Column('nonce', sa.BigInteger(), nullable=False),
sa.Column('signature', sa.String(length=132), nullable=True),
sa.Column('bridge_name', sa.String(length=50), nullable=False),
sa.Column('status', sa.String(length=20), nullable=False),
sa.Column('transaction_hash', sa.String(length=66), nullable=True),
sa.Column('gas_used', sa.BigInteger(), nullable=True),
sa.Column('fee_paid', sa.Numeric(precision=36, scale=18), nullable=True),
sa.Column('error_message', sa.Text(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('completed_at', sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('message_id')
)
# Create indexes
op.create_index('ix_settlements_job_id', 'settlements', ['job_id'])
op.create_index('ix_settlements_status', 'settlements', ['status'])
op.create_index('ix_settlements_bridge_name', 'settlements', ['bridge_name'])
op.create_index('ix_settlements_created_at', 'settlements', ['created_at'])
op.create_index('ix_settlements_message_id', 'settlements', ['message_id'])
# Add foreign key constraint for jobs table
op.create_foreign_key(
'fk_settlements_job_id',
'settlements', 'jobs',
['job_id'], ['id'],
ondelete='CASCADE'
)
def downgrade():
# Drop foreign key
op.drop_constraint('fk_settlements_job_id', 'settlements', type_='foreignkey')
# Drop indexes
op.drop_index('ix_settlements_message_id', table_name='settlements')
op.drop_index('ix_settlements_created_at', table_name='settlements')
op.drop_index('ix_settlements_bridge_name', table_name='settlements')
op.drop_index('ix_settlements_status', table_name='settlements')
op.drop_index('ix_settlements_job_id', table_name='settlements')
# Drop table
op.drop_table('settlements')

View File

@ -21,6 +21,7 @@ python-dotenv = "^1.0.1"
slowapi = "^0.1.8"
orjson = "^3.10.0"
gunicorn = "^22.0.0"
prometheus-client = "^0.19.0"
aitbc-crypto = {path = "../../packages/py/aitbc-crypto"}
[tool.poetry.group.dev.dependencies]

View File

@ -0,0 +1,83 @@
"""
Exception classes for AITBC coordinator
"""
class AITBCError(Exception):
"""Base exception for all AITBC errors"""
pass
class AuthenticationError(AITBCError):
"""Raised when authentication fails"""
pass
class RateLimitError(AITBCError):
"""Raised when rate limit is exceeded"""
def __init__(self, message: str, retry_after: int = None):
super().__init__(message)
self.retry_after = retry_after
class APIError(AITBCError):
"""Raised when API request fails"""
def __init__(self, message: str, status_code: int = None, response: dict = None):
super().__init__(message)
self.status_code = status_code
self.response = response
class ConfigurationError(AITBCError):
"""Raised when configuration is invalid"""
pass
class ConnectorError(AITBCError):
"""Raised when connector operation fails"""
pass
class PaymentError(ConnectorError):
"""Raised when payment operation fails"""
pass
class ValidationError(AITBCError):
"""Raised when data validation fails"""
pass
class WebhookError(AITBCError):
"""Raised when webhook processing fails"""
pass
class ERPError(ConnectorError):
"""Raised when ERP operation fails"""
pass
class SyncError(ConnectorError):
"""Raised when synchronization fails"""
pass
class TimeoutError(AITBCError):
"""Raised when operation times out"""
pass
class TenantError(ConnectorError):
"""Raised when tenant operation fails"""
pass
class QuotaExceededError(ConnectorError):
"""Raised when resource quota is exceeded"""
pass
class BillingError(ConnectorError):
"""Raised when billing operation fails"""
pass
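# Illustrative usage sketch (not part of this module): how a caller might act on
# the metadata these exceptions carry. `fetch_quote` and `call` are hypothetical.
import time
def fetch_quote(call):
    """Invoke `call`, retrying once on rate limiting and wrapping API failures."""
    try:
        return call()
    except RateLimitError as exc:
        time.sleep(exc.retry_after or 1)  # honour the advertised back-off, default 1s
        return call()
    except APIError as exc:
        raise ConnectorError(
            f"upstream request failed with {exc.status_code}: {exc.response}"
        ) from exc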


@@ -1,8 +1,9 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from prometheus_client import make_asgi_app
from .config import settings
from .routers import client, miner, admin, marketplace, explorer
from .routers import client, miner, admin, marketplace, explorer, services, registry
def create_app() -> FastAPI:
@@ -25,6 +26,12 @@ def create_app() -> FastAPI:
app.include_router(admin, prefix="/v1")
app.include_router(marketplace, prefix="/v1")
app.include_router(explorer, prefix="/v1")
app.include_router(services, prefix="/v1")
app.include_router(registry, prefix="/v1")
# Add Prometheus metrics endpoint
metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)
@app.get("/v1/health", tags=["health"], summary="Service healthcheck")
async def health() -> dict[str, str]:


@@ -0,0 +1,16 @@
"""Prometheus metrics for the AITBC Coordinator API."""
from prometheus_client import Counter
# Marketplace API metrics
marketplace_requests_total = Counter(
'marketplace_requests_total',
'Total number of marketplace API requests',
['endpoint', 'method']
)
marketplace_errors_total = Counter(
'marketplace_errors_total',
'Total number of marketplace API errors',
['endpoint', 'method', 'error_type']
)
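# Illustrative sketch (not part of this module): one way a marketplace route could
# record the counters above. The decorator below is an assumption; handlers may
# simply call .labels(...).inc() inline instead.
import functools
def track_marketplace(endpoint: str, method: str):
    """Wrap an async handler and count its requests and errors."""
    def decorator(handler):
        @functools.wraps(handler)
        async def wrapper(*args, **kwargs):
            marketplace_requests_total.labels(endpoint=endpoint, method=method).inc()
            try:
                return await handler(*args, **kwargs)
            except Exception as exc:
                marketplace_errors_total.labels(
                    endpoint=endpoint, method=method, error_type=type(exc).__name__
                ).inc()
                raise
        return wrapper
    return decorator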


@@ -0,0 +1,292 @@
"""
Tenant context middleware for multi-tenant isolation
"""
import hashlib
import logging
from contextlib import contextmanager
from datetime import datetime
from functools import wraps
from typing import Optional, Callable
from fastapi import Request, HTTPException, status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from sqlalchemy.orm import Session
from sqlalchemy import event, select, and_, text
from contextvars import ContextVar
from ..database import get_db
from ..models.multitenant import Tenant, TenantApiKey
from ..services.tenant_management import TenantManagementService
from ..exceptions import TenantError
# Context variable for current tenant
current_tenant: ContextVar[Optional[Tenant]] = ContextVar('current_tenant', default=None)
current_tenant_id: ContextVar[Optional[str]] = ContextVar('current_tenant_id', default=None)
def get_current_tenant() -> Optional[Tenant]:
"""Get the current tenant from context"""
return current_tenant.get()
def get_current_tenant_id() -> Optional[str]:
"""Get the current tenant ID from context"""
return current_tenant_id.get()
class TenantContextMiddleware(BaseHTTPMiddleware):
"""Middleware to extract and set tenant context"""
def __init__(self, app, excluded_paths: Optional[list] = None):
super().__init__(app)
self.excluded_paths = excluded_paths or [
"/health",
"/metrics",
"/docs",
"/openapi.json",
"/favicon.ico",
"/static"
]
        self.logger = logging.getLogger(f"aitbc.{self.__class__.__name__}")
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# Skip tenant extraction for excluded paths
if self._should_exclude(request.url.path):
return await call_next(request)
# Extract tenant from request
tenant = await self._extract_tenant(request)
if not tenant:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Tenant not found or invalid"
)
# Check tenant status
if tenant.status not in ["active", "trial"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Tenant is {tenant.status}"
)
# Set tenant context
current_tenant.set(tenant)
current_tenant_id.set(str(tenant.id))
# Add tenant to request state for easy access
request.state.tenant = tenant
request.state.tenant_id = str(tenant.id)
        try:
            # Process request
            response = await call_next(request)
        finally:
            # Always clear the context, even if the handler raised
            current_tenant.set(None)
            current_tenant_id.set(None)
        return response
def _should_exclude(self, path: str) -> bool:
"""Check if path should be excluded from tenant extraction"""
for excluded in self.excluded_paths:
if path.startswith(excluded):
return True
return False
async def _extract_tenant(self, request: Request) -> Optional[Tenant]:
"""Extract tenant from request using various methods"""
# Method 1: Subdomain
tenant = await self._extract_from_subdomain(request)
if tenant:
return tenant
# Method 2: Custom header
tenant = await self._extract_from_header(request)
if tenant:
return tenant
# Method 3: API key
tenant = await self._extract_from_api_key(request)
if tenant:
return tenant
# Method 4: JWT token (if using OAuth)
tenant = await self._extract_from_token(request)
if tenant:
return tenant
return None
async def _extract_from_subdomain(self, request: Request) -> Optional[Tenant]:
"""Extract tenant from subdomain"""
host = request.headers.get("host", "").split(":")[0]
# Split hostname to get subdomain
parts = host.split(".")
if len(parts) > 2:
subdomain = parts[0]
# Skip common subdomains
if subdomain in ["www", "api", "admin", "app"]:
return None
# Look up tenant by subdomain/slug
db = next(get_db())
try:
service = TenantManagementService(db)
return await service.get_tenant_by_slug(subdomain)
finally:
db.close()
return None
async def _extract_from_header(self, request: Request) -> Optional[Tenant]:
"""Extract tenant from custom header"""
tenant_id = request.headers.get("X-Tenant-ID")
if not tenant_id:
return None
db = next(get_db())
try:
service = TenantManagementService(db)
return await service.get_tenant(tenant_id)
finally:
db.close()
async def _extract_from_api_key(self, request: Request) -> Optional[Tenant]:
"""Extract tenant from API key"""
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
return None
api_key = auth_header[7:] # Remove "Bearer "
# Hash the key to compare with stored hash
key_hash = hashlib.sha256(api_key.encode()).hexdigest()
db = next(get_db())
try:
# Look up API key
stmt = select(TenantApiKey).where(
and_(
TenantApiKey.key_hash == key_hash,
TenantApiKey.is_active == True
)
)
api_key_record = db.execute(stmt).scalar_one_or_none()
if not api_key_record:
return None
# Check if key has expired
if api_key_record.expires_at and api_key_record.expires_at < datetime.utcnow():
return None
# Update last used timestamp
api_key_record.last_used_at = datetime.utcnow()
db.commit()
# Get tenant
service = TenantManagementService(db)
return await service.get_tenant(str(api_key_record.tenant_id))
finally:
db.close()
async def _extract_from_token(self, request: Request) -> Optional[Tenant]:
"""Extract tenant from JWT token"""
# TODO: Implement JWT token extraction
# This would decode the JWT and extract tenant_id from claims
return None
class TenantRowLevelSecurity:
"""Row-level security implementation for tenant isolation"""
def __init__(self, db: Session):
self.db = db
        self.logger = logging.getLogger(f"aitbc.{self.__class__.__name__}")
def enable_rls(self):
"""Enable row-level security for the session"""
tenant_id = get_current_tenant_id()
if not tenant_id:
raise TenantError("No tenant context found")
        # Set the session variable that PostgreSQL RLS policies read.
        # SET cannot take bind parameters on every driver, so use set_config().
        self.db.execute(
            text("SELECT set_config('aitbc.current_tenant_id', :tenant_id, false)"),
            {"tenant_id": tenant_id}
        )
self.logger.debug(f"Enabled RLS for tenant: {tenant_id}")
def disable_rls(self):
"""Disable row-level security for the session"""
self.db.execute("RESET aitbc.current_tenant_id")
self.logger.debug("Disabled RLS")
# Database event listeners for automatic RLS
@event.listens_for(Session, "after_begin")
def on_session_begin(session, transaction):
"""Enable RLS when session begins"""
try:
tenant_id = get_current_tenant_id()
if tenant_id:
            session.execute(
                text("SELECT set_config('aitbc.current_tenant_id', :tenant_id, false)"),
                {"tenant_id": tenant_id}
            )
except Exception as e:
# Log error but don't fail
        logger = logging.getLogger(__name__)
logger.error(f"Failed to set tenant context: {e}")
# Decorator for tenant-aware endpoints
def requires_tenant(func):
    """Decorator to ensure tenant context is present"""
    @wraps(func)  # preserve the wrapped handler's signature and metadata
    async def wrapper(*args, **kwargs):
        tenant = get_current_tenant()
        if not tenant:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Tenant context required"
            )
        return await func(*args, **kwargs)
    return wrapper
# Dependency for FastAPI
async def get_current_tenant_dependency(request: Request) -> Tenant:
"""FastAPI dependency to get current tenant"""
tenant = getattr(request.state, "tenant", None)
if not tenant:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Tenant not found"
)
return tenant
# Utility functions
@contextmanager
def with_tenant_context(tenant_id: str):
    """Execute code with a specific tenant context"""
    token = current_tenant_id.set(tenant_id)
    try:
        yield
    finally:
        current_tenant_id.reset(token)
def is_tenant_admin(user_permissions: list) -> bool:
"""Check if user has tenant admin permissions"""
return "tenant:admin" in user_permissions or "admin" in user_permissions
def has_tenant_permission(permission: str, user_permissions: list) -> bool:
"""Check if user has specific tenant permission"""
return permission in user_permissions or "tenant:admin" in user_permissions
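# Illustrative wiring sketch (not part of this module): registering the middleware
# and reading the tenant inside a route. The application factory and route below
# are assumptions for the example, not endpoints added by this commit.
from fastapi import APIRouter, Depends, FastAPI
def build_app() -> FastAPI:
    app = FastAPI()
    # Tenant resolution runs before every request except the excluded paths.
    app.add_middleware(TenantContextMiddleware, excluded_paths=["/health", "/metrics"])
    return app
router = APIRouter()
@router.get("/whoami")
async def whoami(tenant: Tenant = Depends(get_current_tenant_dependency)) -> dict:
    # The middleware has already validated the tenant and stored it on request.state.
    return {"tenant_id": str(tenant.id), "plan": tenant.plan}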


@@ -2,7 +2,8 @@ from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, List
from base64 import b64encode, b64decode
from pydantic import BaseModel, Field, ConfigDict
@@ -170,3 +171,176 @@ class ReceiptListResponse(BaseModel):
jobId: str
items: list[ReceiptSummary]
# Confidential Transaction Models
class ConfidentialTransaction(BaseModel):
"""Transaction with optional confidential fields"""
# Public fields (always visible)
transaction_id: str
job_id: str
timestamp: datetime
status: str
# Confidential fields (encrypted when opt-in)
amount: Optional[str] = None
pricing: Optional[Dict[str, Any]] = None
settlement_details: Optional[Dict[str, Any]] = None
# Encryption metadata
confidential: bool = False
encrypted_data: Optional[str] = None # Base64 encoded
encrypted_keys: Optional[Dict[str, str]] = None # Base64 encoded
algorithm: Optional[str] = None
# Access control
participants: List[str] = []
access_policies: Dict[str, Any] = {}
model_config = ConfigDict(populate_by_name=True)
class ConfidentialTransactionCreate(BaseModel):
"""Request to create confidential transaction"""
job_id: str
amount: Optional[str] = None
pricing: Optional[Dict[str, Any]] = None
settlement_details: Optional[Dict[str, Any]] = None
# Privacy options
confidential: bool = False
participants: List[str] = []
# Access policies
access_policies: Dict[str, Any] = {}
class ConfidentialTransactionView(BaseModel):
"""Response for confidential transaction view"""
transaction_id: str
job_id: str
timestamp: datetime
status: str
# Decrypted fields (only if authorized)
amount: Optional[str] = None
pricing: Optional[Dict[str, Any]] = None
settlement_details: Optional[Dict[str, Any]] = None
# Metadata
confidential: bool
participants: List[str]
has_encrypted_data: bool
class ConfidentialAccessRequest(BaseModel):
"""Request to access confidential transaction data"""
transaction_id: str
requester: str
purpose: str
justification: Optional[str] = None
class ConfidentialAccessResponse(BaseModel):
"""Response for confidential data access"""
success: bool
data: Optional[Dict[str, Any]] = None
error: Optional[str] = None
access_id: Optional[str] = None
# Key Management Models
class KeyPair(BaseModel):
"""Encryption key pair for participant"""
participant_id: str
private_key: bytes
public_key: bytes
algorithm: str = "X25519"
created_at: datetime
version: int = 1
model_config = ConfigDict(arbitrary_types_allowed=True)
class KeyRotationLog(BaseModel):
"""Log of key rotation events"""
participant_id: str
old_version: int
new_version: int
rotated_at: datetime
reason: str
class AuditAuthorization(BaseModel):
"""Authorization for audit access"""
issuer: str
subject: str
purpose: str
created_at: datetime
expires_at: datetime
signature: str
class KeyRegistrationRequest(BaseModel):
"""Request to register encryption keys"""
participant_id: str
public_key: str # Base64 encoded
algorithm: str = "X25519"
class KeyRegistrationResponse(BaseModel):
"""Response for key registration"""
success: bool
participant_id: str
key_version: int
registered_at: datetime
error: Optional[str] = None
# Access Log Models
class ConfidentialAccessLog(BaseModel):
"""Audit log for confidential data access"""
transaction_id: Optional[str]
participant_id: str
purpose: str
timestamp: datetime
authorized_by: str
data_accessed: List[str]
success: bool
error: Optional[str] = None
ip_address: Optional[str] = None
user_agent: Optional[str] = None
class AccessLogQuery(BaseModel):
"""Query for access logs"""
transaction_id: Optional[str] = None
participant_id: Optional[str] = None
purpose: Optional[str] = None
start_time: Optional[datetime] = None
end_time: Optional[datetime] = None
limit: int = 100
offset: int = 0
class AccessLogResponse(BaseModel):
"""Response for access log query"""
logs: List[ConfidentialAccessLog]
total_count: int
has_more: bool
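# Illustrative sketch (not part of this module): a client opting a settlement
# into confidentiality. All values below are placeholders.
example_create = ConfidentialTransactionCreate(
    job_id="job-123",
    amount="42.5",
    pricing={"model": "per_hour", "unit_price": "2.5"},
    confidential=True,
    participants=["client-abc", "miner-xyz"],
    access_policies={"auditor": {"purpose": "compliance", "max_age_days": 90}},
)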


@@ -0,0 +1,169 @@
"""
Database models for confidential transactions
"""
from datetime import datetime
from typing import Optional, Dict, Any, List
from sqlalchemy import Column, String, DateTime, Boolean, Text, JSON, Integer, LargeBinary
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.sql import func
import uuid
from ..database import Base
class ConfidentialTransactionDB(Base):
"""Database model for confidential transactions"""
__tablename__ = "confidential_transactions"
# Primary key
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Public fields (always visible)
transaction_id = Column(String(255), unique=True, nullable=False, index=True)
job_id = Column(String(255), nullable=False, index=True)
timestamp = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
status = Column(String(50), nullable=False, default="created")
# Encryption metadata
confidential = Column(Boolean, nullable=False, default=False)
algorithm = Column(String(50), nullable=True)
# Encrypted data (stored as binary)
encrypted_data = Column(LargeBinary, nullable=True)
encrypted_nonce = Column(LargeBinary, nullable=True)
encrypted_tag = Column(LargeBinary, nullable=True)
# Encrypted keys for participants (JSON encoded)
encrypted_keys = Column(JSON, nullable=True)
participants = Column(JSON, nullable=True)
# Access policies
access_policies = Column(JSON, nullable=True)
# Audit fields
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
created_by = Column(String(255), nullable=True)
# Indexes for performance
__table_args__ = (
{'schema': 'aitbc'}
)
class ParticipantKeyDB(Base):
"""Database model for participant encryption keys"""
__tablename__ = "participant_keys"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
participant_id = Column(String(255), unique=True, nullable=False, index=True)
# Key data (encrypted at rest)
encrypted_private_key = Column(LargeBinary, nullable=False)
public_key = Column(LargeBinary, nullable=False)
# Key metadata
algorithm = Column(String(50), nullable=False, default="X25519")
version = Column(Integer, nullable=False, default=1)
# Status
active = Column(Boolean, nullable=False, default=True)
revoked_at = Column(DateTime(timezone=True), nullable=True)
revoke_reason = Column(String(255), nullable=True)
# Audit fields
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
rotated_at = Column(DateTime(timezone=True), nullable=True)
__table_args__ = (
{'schema': 'aitbc'}
)
class ConfidentialAccessLogDB(Base):
"""Database model for confidential data access logs"""
__tablename__ = "confidential_access_logs"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Access details
transaction_id = Column(String(255), nullable=True, index=True)
participant_id = Column(String(255), nullable=False, index=True)
purpose = Column(String(100), nullable=False)
# Request details
action = Column(String(100), nullable=False)
resource = Column(String(100), nullable=False)
outcome = Column(String(50), nullable=False)
# Additional data
details = Column(JSON, nullable=True)
data_accessed = Column(JSON, nullable=True)
# Metadata
ip_address = Column(String(45), nullable=True)
user_agent = Column(Text, nullable=True)
authorization_id = Column(String(255), nullable=True)
# Integrity
signature = Column(String(128), nullable=True) # SHA-512 hash
# Timestamps
timestamp = Column(DateTime(timezone=True), server_default=func.now(), nullable=False, index=True)
__table_args__ = (
{'schema': 'aitbc'}
)
class KeyRotationLogDB(Base):
"""Database model for key rotation logs"""
__tablename__ = "key_rotation_logs"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
participant_id = Column(String(255), nullable=False, index=True)
old_version = Column(Integer, nullable=False)
new_version = Column(Integer, nullable=False)
# Rotation details
rotated_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
reason = Column(String(255), nullable=False)
# Who performed the rotation
rotated_by = Column(String(255), nullable=True)
__table_args__ = (
{'schema': 'aitbc'}
)
class AuditAuthorizationDB(Base):
"""Database model for audit authorizations"""
__tablename__ = "audit_authorizations"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Authorization details
issuer = Column(String(255), nullable=False)
subject = Column(String(255), nullable=False)
purpose = Column(String(100), nullable=False)
# Validity period
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
expires_at = Column(DateTime(timezone=True), nullable=False, index=True)
# Authorization data
signature = Column(String(512), nullable=False)
    extra_metadata = Column("metadata", JSON, nullable=True)  # "metadata" is reserved on SQLAlchemy declarative models
# Status
active = Column(Boolean, nullable=False, default=True)
revoked_at = Column(DateTime(timezone=True), nullable=True)
used_at = Column(DateTime(timezone=True), nullable=True)
__table_args__ = (
{'schema': 'aitbc'}
)
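# Illustrative sketch (not part of this module): persisting an access-log row
# after a confidential read. Session handling and field values are assumptions.
from sqlalchemy.orm import Session
def record_access(db: Session, *, transaction_id: str, participant_id: str, purpose: str, granted: bool) -> None:
    entry = ConfidentialAccessLogDB(
        transaction_id=transaction_id,
        participant_id=participant_id,
        purpose=purpose,
        action="decrypt",
        resource="confidential_transaction",
        outcome="granted" if granted else "denied",
    )
    db.add(entry)
    db.commit()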


@@ -0,0 +1,340 @@
"""
Multi-tenant data models for AITBC coordinator
"""
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List
from enum import Enum
from sqlalchemy import Column, String, DateTime, Boolean, Integer, Text, JSON, ForeignKey, Index, Numeric
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
import uuid
from ..database import Base
class TenantStatus(Enum):
"""Tenant status enumeration"""
ACTIVE = "active"
INACTIVE = "inactive"
SUSPENDED = "suspended"
PENDING = "pending"
TRIAL = "trial"
class Tenant(Base):
"""Tenant model for multi-tenancy"""
__tablename__ = "tenants"
# Primary key
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Tenant information
name = Column(String(255), nullable=False, index=True)
slug = Column(String(100), unique=True, nullable=False, index=True)
domain = Column(String(255), unique=True, nullable=True, index=True)
# Status and configuration
status = Column(String(50), nullable=False, default=TenantStatus.PENDING.value)
plan = Column(String(50), nullable=False, default="trial")
# Contact information
contact_email = Column(String(255), nullable=False)
billing_email = Column(String(255), nullable=True)
# Configuration
settings = Column(JSON, nullable=False, default={})
features = Column(JSON, nullable=False, default={})
# Timestamps
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
activated_at = Column(DateTime(timezone=True), nullable=True)
deactivated_at = Column(DateTime(timezone=True), nullable=True)
# Relationships
users = relationship("TenantUser", back_populates="tenant", cascade="all, delete-orphan")
quotas = relationship("TenantQuota", back_populates="tenant", cascade="all, delete-orphan")
usage_records = relationship("UsageRecord", back_populates="tenant", cascade="all, delete-orphan")
# Indexes
__table_args__ = (
Index('idx_tenant_status', 'status'),
Index('idx_tenant_plan', 'plan'),
{'schema': 'aitbc'}
)
class TenantUser(Base):
"""Association between users and tenants"""
__tablename__ = "tenant_users"
# Primary key
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Foreign keys
tenant_id = Column(UUID(as_uuid=True), ForeignKey('aitbc.tenants.id'), nullable=False)
user_id = Column(String(255), nullable=False) # User ID from auth system
# Role and permissions
role = Column(String(50), nullable=False, default="member")
permissions = Column(JSON, nullable=False, default=[])
# Status
is_active = Column(Boolean, nullable=False, default=True)
invited_at = Column(DateTime(timezone=True), nullable=True)
joined_at = Column(DateTime(timezone=True), nullable=True)
# Metadata
    extra_metadata = Column("metadata", JSON, nullable=True)  # "metadata" is reserved on SQLAlchemy declarative models
# Relationships
tenant = relationship("Tenant", back_populates="users")
# Indexes
__table_args__ = (
Index('idx_tenant_user', 'tenant_id', 'user_id'),
Index('idx_user_tenants', 'user_id'),
{'schema': 'aitbc'}
)
class TenantQuota(Base):
"""Resource quotas for tenants"""
__tablename__ = "tenant_quotas"
# Primary key
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Foreign key
tenant_id = Column(UUID(as_uuid=True), ForeignKey('aitbc.tenants.id'), nullable=False)
# Quota definitions
resource_type = Column(String(100), nullable=False) # gpu_hours, storage_gb, api_calls
limit_value = Column(Numeric(20, 4), nullable=False) # Maximum allowed
used_value = Column(Numeric(20, 4), nullable=False, default=0) # Current usage
# Time period
period_type = Column(String(50), nullable=False, default="monthly") # daily, weekly, monthly
period_start = Column(DateTime(timezone=True), nullable=False)
period_end = Column(DateTime(timezone=True), nullable=False)
# Status
is_active = Column(Boolean, nullable=False, default=True)
# Relationships
tenant = relationship("Tenant", back_populates="quotas")
# Indexes
__table_args__ = (
Index('idx_tenant_quota', 'tenant_id', 'resource_type', 'period_start'),
Index('idx_quota_period', 'period_start', 'period_end'),
{'schema': 'aitbc'}
)
class UsageRecord(Base):
"""Usage tracking records for billing"""
__tablename__ = "usage_records"
# Primary key
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Foreign key
tenant_id = Column(UUID(as_uuid=True), ForeignKey('aitbc.tenants.id'), nullable=False)
# Usage details
resource_type = Column(String(100), nullable=False) # gpu_hours, storage_gb, api_calls
resource_id = Column(String(255), nullable=True) # Specific resource ID
quantity = Column(Numeric(20, 4), nullable=False)
unit = Column(String(50), nullable=False) # hours, gb, calls
# Cost information
unit_price = Column(Numeric(10, 4), nullable=False)
total_cost = Column(Numeric(20, 4), nullable=False)
currency = Column(String(10), nullable=False, default="USD")
# Time tracking
usage_start = Column(DateTime(timezone=True), nullable=False)
usage_end = Column(DateTime(timezone=True), nullable=False)
recorded_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
# Metadata
job_id = Column(String(255), nullable=True) # Associated job if applicable
    extra_metadata = Column("metadata", JSON, nullable=True)  # "metadata" is reserved on SQLAlchemy declarative models
# Relationships
tenant = relationship("Tenant", back_populates="usage_records")
# Indexes
__table_args__ = (
Index('idx_tenant_usage', 'tenant_id', 'usage_start'),
Index('idx_usage_type', 'resource_type', 'usage_start'),
Index('idx_usage_job', 'job_id'),
{'schema': 'aitbc'}
)
class Invoice(Base):
"""Billing invoices for tenants"""
__tablename__ = "invoices"
# Primary key
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Foreign key
tenant_id = Column(UUID(as_uuid=True), ForeignKey('aitbc.tenants.id'), nullable=False)
# Invoice details
invoice_number = Column(String(100), unique=True, nullable=False, index=True)
status = Column(String(50), nullable=False, default="draft")
# Period
period_start = Column(DateTime(timezone=True), nullable=False)
period_end = Column(DateTime(timezone=True), nullable=False)
due_date = Column(DateTime(timezone=True), nullable=False)
# Amounts
subtotal = Column(Numeric(20, 4), nullable=False)
tax_amount = Column(Numeric(20, 4), nullable=False, default=0)
total_amount = Column(Numeric(20, 4), nullable=False)
currency = Column(String(10), nullable=False, default="USD")
# Breakdown
line_items = Column(JSON, nullable=False, default=[])
# Payment
paid_at = Column(DateTime(timezone=True), nullable=True)
payment_method = Column(String(100), nullable=True)
# Timestamps
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
# Metadata
    extra_metadata = Column("metadata", JSON, nullable=True)  # "metadata" is reserved on SQLAlchemy declarative models
# Indexes
__table_args__ = (
Index('idx_invoice_tenant', 'tenant_id', 'period_start'),
Index('idx_invoice_status', 'status'),
Index('idx_invoice_due', 'due_date'),
{'schema': 'aitbc'}
)
class TenantApiKey(Base):
"""API keys for tenant authentication"""
__tablename__ = "tenant_api_keys"
# Primary key
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Foreign key
tenant_id = Column(UUID(as_uuid=True), ForeignKey('aitbc.tenants.id'), nullable=False)
# Key details
key_id = Column(String(100), unique=True, nullable=False, index=True)
key_hash = Column(String(255), unique=True, nullable=False, index=True)
key_prefix = Column(String(20), nullable=False) # First few characters for identification
# Permissions and restrictions
permissions = Column(JSON, nullable=False, default=[])
rate_limit = Column(Integer, nullable=True) # Requests per minute
allowed_ips = Column(JSON, nullable=True) # IP whitelist
# Status
is_active = Column(Boolean, nullable=False, default=True)
expires_at = Column(DateTime(timezone=True), nullable=True)
last_used_at = Column(DateTime(timezone=True), nullable=True)
# Metadata
name = Column(String(255), nullable=False)
description = Column(Text, nullable=True)
created_by = Column(String(255), nullable=False)
# Timestamps
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
revoked_at = Column(DateTime(timezone=True), nullable=True)
# Indexes
__table_args__ = (
Index('idx_api_key_tenant', 'tenant_id', 'is_active'),
Index('idx_api_key_hash', 'key_hash'),
{'schema': 'aitbc'}
)
class TenantAuditLog(Base):
"""Audit logs for tenant activities"""
__tablename__ = "tenant_audit_logs"
# Primary key
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Foreign key
tenant_id = Column(UUID(as_uuid=True), ForeignKey('aitbc.tenants.id'), nullable=False)
# Event details
event_type = Column(String(100), nullable=False, index=True)
event_category = Column(String(50), nullable=False, index=True)
actor_id = Column(String(255), nullable=False) # User who performed action
actor_type = Column(String(50), nullable=False) # user, api_key, system
# Target information
resource_type = Column(String(100), nullable=False)
resource_id = Column(String(255), nullable=True)
# Event data
old_values = Column(JSON, nullable=True)
new_values = Column(JSON, nullable=True)
    extra_metadata = Column("metadata", JSON, nullable=True)  # "metadata" is reserved on SQLAlchemy declarative models
# Request context
ip_address = Column(String(45), nullable=True)
user_agent = Column(Text, nullable=True)
api_key_id = Column(String(100), nullable=True)
# Timestamp
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False, index=True)
# Indexes
__table_args__ = (
Index('idx_audit_tenant', 'tenant_id', 'created_at'),
Index('idx_audit_actor', 'actor_id', 'event_type'),
Index('idx_audit_resource', 'resource_type', 'resource_id'),
{'schema': 'aitbc'}
)
class TenantMetric(Base):
"""Tenant-specific metrics and monitoring data"""
__tablename__ = "tenant_metrics"
# Primary key
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Foreign key
tenant_id = Column(UUID(as_uuid=True), ForeignKey('aitbc.tenants.id'), nullable=False)
# Metric details
metric_name = Column(String(100), nullable=False, index=True)
metric_type = Column(String(50), nullable=False) # counter, gauge, histogram
# Value
value = Column(Numeric(20, 4), nullable=False)
unit = Column(String(50), nullable=True)
# Dimensions
dimensions = Column(JSON, nullable=False, default={})
# Time
timestamp = Column(DateTime(timezone=True), nullable=False, index=True)
# Indexes
__table_args__ = (
Index('idx_metric_tenant', 'tenant_id', 'metric_name', 'timestamp'),
Index('idx_metric_time', 'timestamp'),
{'schema': 'aitbc'}
)
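# Illustrative sketch (not part of this module): enforcing a TenantQuota before
# accepting new work. The session handling and the "deny when no quota row
# exists" policy are assumptions for the example.
from decimal import Decimal
from sqlalchemy import select
from sqlalchemy.orm import Session
def has_quota(db: Session, tenant_id, resource_type: str, requested: Decimal) -> bool:
    """Return True if the tenant's active quota can absorb `requested` units."""
    now = datetime.utcnow()
    stmt = select(TenantQuota).where(
        TenantQuota.tenant_id == tenant_id,
        TenantQuota.resource_type == resource_type,
        TenantQuota.is_active.is_(True),
        TenantQuota.period_start <= now,
        TenantQuota.period_end >= now,
    )
    quota = db.execute(stmt).scalar_one_or_none()
    if quota is None:
        return False  # no active quota configured for this resource: deny
    return quota.used_value + requested <= quota.limit_value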


@@ -0,0 +1,547 @@
"""
Dynamic service registry models for AITBC
"""
from typing import Dict, List, Any, Optional, Union
from datetime import datetime
from enum import Enum
from pydantic import BaseModel, Field, field_validator
class ServiceCategory(str, Enum):
"""Service categories"""
AI_ML = "ai_ml"
MEDIA_PROCESSING = "media_processing"
SCIENTIFIC_COMPUTING = "scientific_computing"
DATA_ANALYTICS = "data_analytics"
GAMING_ENTERTAINMENT = "gaming_entertainment"
DEVELOPMENT_TOOLS = "development_tools"
class ParameterType(str, Enum):
"""Parameter types"""
STRING = "string"
INTEGER = "integer"
FLOAT = "float"
BOOLEAN = "boolean"
ARRAY = "array"
OBJECT = "object"
FILE = "file"
ENUM = "enum"
class PricingModel(str, Enum):
"""Pricing models"""
PER_UNIT = "per_unit" # per image, per minute, per token
PER_HOUR = "per_hour"
PER_GB = "per_gb"
PER_FRAME = "per_frame"
FIXED = "fixed"
CUSTOM = "custom"
class ParameterDefinition(BaseModel):
"""Parameter definition schema"""
name: str = Field(..., description="Parameter name")
type: ParameterType = Field(..., description="Parameter type")
required: bool = Field(True, description="Whether parameter is required")
description: str = Field(..., description="Parameter description")
default: Optional[Any] = Field(None, description="Default value")
min_value: Optional[Union[int, float]] = Field(None, description="Minimum value")
max_value: Optional[Union[int, float]] = Field(None, description="Maximum value")
options: Optional[List[str]] = Field(None, description="Available options for enum type")
    validation: Optional[Dict[str, Any]] = Field(None, description="Custom validation rules")
    items: Optional[Dict[str, Any]] = Field(None, description="Item schema for array-type parameters")
class HardwareRequirement(BaseModel):
"""Hardware requirement definition"""
component: str = Field(..., description="Component type (gpu, cpu, ram, etc.)")
min_value: Union[str, int, float] = Field(..., description="Minimum requirement")
recommended: Optional[Union[str, int, float]] = Field(None, description="Recommended value")
unit: Optional[str] = Field(None, description="Unit (GB, MB, cores, etc.)")
class PricingTier(BaseModel):
"""Pricing tier definition"""
name: str = Field(..., description="Tier name")
model: PricingModel = Field(..., description="Pricing model")
unit_price: float = Field(..., ge=0, description="Price per unit")
min_charge: Optional[float] = Field(None, ge=0, description="Minimum charge")
currency: str = Field("AITBC", description="Currency code")
description: Optional[str] = Field(None, description="Tier description")
class ServiceDefinition(BaseModel):
"""Complete service definition"""
id: str = Field(..., description="Unique service identifier")
name: str = Field(..., description="Human-readable service name")
category: ServiceCategory = Field(..., description="Service category")
description: str = Field(..., description="Service description")
version: str = Field("1.0.0", description="Service version")
icon: Optional[str] = Field(None, description="Icon emoji or URL")
# Input/Output
input_parameters: List[ParameterDefinition] = Field(..., description="Input parameters")
output_schema: Dict[str, Any] = Field(..., description="Output schema")
# Hardware requirements
requirements: List[HardwareRequirement] = Field(..., description="Hardware requirements")
# Pricing
pricing: List[PricingTier] = Field(..., description="Available pricing tiers")
# Capabilities
capabilities: List[str] = Field(default_factory=list, description="Service capabilities")
tags: List[str] = Field(default_factory=list, description="Search tags")
# Limits
max_concurrent: int = Field(1, ge=1, le=100, description="Max concurrent jobs")
timeout_seconds: int = Field(3600, ge=60, description="Default timeout")
# Metadata
provider: Optional[str] = Field(None, description="Service provider")
documentation_url: Optional[str] = Field(None, description="Documentation URL")
example_usage: Optional[Dict[str, Any]] = Field(None, description="Example usage")
    @field_validator('id')
    @classmethod
    def validate_id(cls, v: str) -> str:
        if not v or not v.replace('_', '').replace('-', '').isalnum():
            raise ValueError('Service ID must contain only alphanumeric characters, hyphens, and underscores')
        return v.lower()
class ServiceRegistry(BaseModel):
"""Service registry containing all available services"""
version: str = Field("1.0.0", description="Registry version")
last_updated: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
services: Dict[str, ServiceDefinition] = Field(..., description="Service definitions by ID")
def get_service(self, service_id: str) -> Optional[ServiceDefinition]:
"""Get service by ID"""
return self.services.get(service_id)
def get_services_by_category(self, category: ServiceCategory) -> List[ServiceDefinition]:
"""Get all services in a category"""
return [s for s in self.services.values() if s.category == category]
def search_services(self, query: str) -> List[ServiceDefinition]:
"""Search services by name, description, or tags"""
query = query.lower()
results = []
for service in self.services.values():
if (query in service.name.lower() or
query in service.description.lower() or
any(query in tag.lower() for tag in service.tags)):
results.append(service)
return results
# Predefined service templates
AI_ML_SERVICES = {
"llm_inference": ServiceDefinition(
id="llm_inference",
name="LLM Inference",
category=ServiceCategory.AI_ML,
description="Run inference on large language models",
icon="🤖",
input_parameters=[
ParameterDefinition(
name="model",
type=ParameterType.ENUM,
required=True,
description="Model to use for inference",
options=["llama-7b", "llama-13b", "llama-70b", "mistral-7b", "mixtral-8x7b", "codellama-7b", "codellama-13b", "codellama-34b", "falcon-7b", "falcon-40b"]
),
ParameterDefinition(
name="prompt",
type=ParameterType.STRING,
required=True,
description="Input prompt text",
min_value=1,
max_value=10000
),
ParameterDefinition(
name="max_tokens",
type=ParameterType.INTEGER,
required=False,
description="Maximum tokens to generate",
default=256,
min_value=1,
max_value=4096
),
ParameterDefinition(
name="temperature",
type=ParameterType.FLOAT,
required=False,
description="Sampling temperature",
default=0.7,
min_value=0.0,
max_value=2.0
),
ParameterDefinition(
name="stream",
type=ParameterType.BOOLEAN,
required=False,
description="Stream response",
default=False
)
],
output_schema={
"type": "object",
"properties": {
"text": {"type": "string"},
"tokens_used": {"type": "integer"},
"finish_reason": {"type": "string"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-4090"),
HardwareRequirement(component="vram", min_value=8, recommended=24, unit="GB"),
HardwareRequirement(component="cuda", min_value="11.8")
],
pricing=[
PricingTier(name="basic", model=PricingModel.PER_UNIT, unit_price=0.001, min_charge=0.01),
PricingTier(name="premium", model=PricingModel.PER_UNIT, unit_price=0.002, min_charge=0.01)
],
capabilities=["generate", "stream", "chat", "completion"],
tags=["llm", "text", "generation", "ai", "nlp"],
max_concurrent=2,
timeout_seconds=300
),
"image_generation": ServiceDefinition(
id="image_generation",
name="Image Generation",
category=ServiceCategory.AI_ML,
description="Generate images from text prompts using diffusion models",
icon="🎨",
input_parameters=[
ParameterDefinition(
name="model",
type=ParameterType.ENUM,
required=True,
description="Image generation model",
options=["stable-diffusion-1.5", "stable-diffusion-2.1", "stable-diffusion-xl", "sdxl-turbo", "dall-e-2", "dall-e-3", "midjourney-v5"]
),
ParameterDefinition(
name="prompt",
type=ParameterType.STRING,
required=True,
description="Text prompt for image generation",
max_value=1000
),
ParameterDefinition(
name="negative_prompt",
type=ParameterType.STRING,
required=False,
description="Negative prompt",
max_value=1000
),
ParameterDefinition(
name="width",
type=ParameterType.INTEGER,
required=False,
description="Image width",
default=512,
options=[256, 512, 768, 1024, 1536, 2048]
),
ParameterDefinition(
name="height",
type=ParameterType.INTEGER,
required=False,
description="Image height",
default=512,
options=[256, 512, 768, 1024, 1536, 2048]
),
ParameterDefinition(
name="num_images",
type=ParameterType.INTEGER,
required=False,
description="Number of images to generate",
default=1,
min_value=1,
max_value=4
),
ParameterDefinition(
name="steps",
type=ParameterType.INTEGER,
required=False,
description="Number of inference steps",
default=20,
min_value=1,
max_value=100
)
],
output_schema={
"type": "object",
"properties": {
"images": {"type": "array", "items": {"type": "string"}},
"parameters": {"type": "object"},
"generation_time": {"type": "number"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-4090"),
HardwareRequirement(component="vram", min_value=4, recommended=16, unit="GB"),
HardwareRequirement(component="cuda", min_value="11.8")
],
pricing=[
PricingTier(name="standard", model=PricingModel.PER_UNIT, unit_price=0.01, min_charge=0.01),
PricingTier(name="hd", model=PricingModel.PER_UNIT, unit_price=0.02, min_charge=0.02),
PricingTier(name="4k", model=PricingModel.PER_UNIT, unit_price=0.05, min_charge=0.05)
],
capabilities=["txt2img", "img2img", "inpainting", "outpainting"],
tags=["image", "generation", "diffusion", "ai", "art"],
max_concurrent=1,
timeout_seconds=600
),
"video_generation": ServiceDefinition(
id="video_generation",
name="Video Generation",
category=ServiceCategory.AI_ML,
description="Generate videos from text or images",
icon="🎬",
input_parameters=[
ParameterDefinition(
name="model",
type=ParameterType.ENUM,
required=True,
description="Video generation model",
options=["sora", "runway-gen2", "pika-labs", "stable-video-diffusion", "make-a-video"]
),
ParameterDefinition(
name="prompt",
type=ParameterType.STRING,
required=True,
description="Text prompt for video generation",
max_value=500
),
ParameterDefinition(
name="duration_seconds",
type=ParameterType.INTEGER,
required=False,
description="Video duration in seconds",
default=4,
min_value=1,
max_value=30
),
ParameterDefinition(
name="fps",
type=ParameterType.INTEGER,
required=False,
description="Frames per second",
default=24,
options=[12, 24, 30]
),
ParameterDefinition(
name="resolution",
type=ParameterType.ENUM,
required=False,
description="Video resolution",
default="720p",
options=["480p", "720p", "1080p", "4k"]
)
],
output_schema={
"type": "object",
"properties": {
"video_url": {"type": "string"},
"thumbnail_url": {"type": "string"},
"duration": {"type": "number"},
"resolution": {"type": "string"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="nvidia", recommended="a100"),
HardwareRequirement(component="vram", min_value=16, recommended=40, unit="GB"),
HardwareRequirement(component="cuda", min_value="11.8")
],
pricing=[
PricingTier(name="short", model=PricingModel.PER_UNIT, unit_price=0.1, min_charge=0.1),
PricingTier(name="medium", model=PricingModel.PER_UNIT, unit_price=0.25, min_charge=0.25),
PricingTier(name="long", model=PricingModel.PER_UNIT, unit_price=0.5, min_charge=0.5)
],
capabilities=["txt2video", "img2video", "video-editing"],
tags=["video", "generation", "ai", "animation"],
max_concurrent=1,
timeout_seconds=1800
),
"speech_recognition": ServiceDefinition(
id="speech_recognition",
name="Speech Recognition",
category=ServiceCategory.AI_ML,
description="Transcribe audio to text using speech recognition models",
icon="🎙️",
input_parameters=[
ParameterDefinition(
name="model",
type=ParameterType.ENUM,
required=True,
description="Speech recognition model",
options=["whisper-tiny", "whisper-base", "whisper-small", "whisper-medium", "whisper-large", "whisper-large-v2", "whisper-large-v3"]
),
ParameterDefinition(
name="audio_file",
type=ParameterType.FILE,
required=True,
description="Audio file to transcribe"
),
ParameterDefinition(
name="language",
type=ParameterType.ENUM,
required=False,
description="Audio language",
default="auto",
options=["auto", "en", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh", "ar", "hi"]
),
ParameterDefinition(
name="task",
type=ParameterType.ENUM,
required=False,
description="Task type",
default="transcribe",
options=["transcribe", "translate"]
)
],
output_schema={
"type": "object",
"properties": {
"text": {"type": "string"},
"language": {"type": "string"},
"segments": {"type": "array"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3060"),
HardwareRequirement(component="vram", min_value=1, recommended=4, unit="GB")
],
pricing=[
PricingTier(name="per_minute", model=PricingModel.PER_UNIT, unit_price=0.001, min_charge=0.01)
],
capabilities=["transcribe", "translate", "timestamp", "speaker-diarization"],
tags=["speech", "audio", "transcription", "whisper"],
max_concurrent=2,
timeout_seconds=600
),
"computer_vision": ServiceDefinition(
id="computer_vision",
name="Computer Vision",
category=ServiceCategory.AI_ML,
description="Analyze images with computer vision models",
icon="👁️",
input_parameters=[
ParameterDefinition(
name="task",
type=ParameterType.ENUM,
required=True,
description="Vision task",
options=["object-detection", "classification", "face-recognition", "segmentation", "ocr"]
),
ParameterDefinition(
name="model",
type=ParameterType.ENUM,
required=True,
description="Vision model",
options=["yolo-v8", "resnet-50", "efficientnet", "vit", "face-net", "tesseract"]
),
ParameterDefinition(
name="image",
type=ParameterType.FILE,
required=True,
description="Input image"
),
ParameterDefinition(
name="confidence_threshold",
type=ParameterType.FLOAT,
required=False,
description="Confidence threshold",
default=0.5,
min_value=0.0,
max_value=1.0
)
],
output_schema={
"type": "object",
"properties": {
"detections": {"type": "array"},
"labels": {"type": "array"},
"confidence_scores": {"type": "array"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3060"),
HardwareRequirement(component="vram", min_value=2, recommended=8, unit="GB")
],
pricing=[
PricingTier(name="per_image", model=PricingModel.PER_UNIT, unit_price=0.005, min_charge=0.01)
],
capabilities=["detection", "classification", "recognition", "segmentation", "ocr"],
tags=["vision", "image", "analysis", "ai", "detection"],
max_concurrent=4,
timeout_seconds=120
),
"recommendation_system": ServiceDefinition(
id="recommendation_system",
name="Recommendation System",
category=ServiceCategory.AI_ML,
description="Generate personalized recommendations",
icon="🎯",
input_parameters=[
ParameterDefinition(
name="model_type",
type=ParameterType.ENUM,
required=True,
description="Recommendation model type",
options=["collaborative", "content-based", "hybrid", "deep-learning"]
),
ParameterDefinition(
name="user_id",
type=ParameterType.STRING,
required=True,
description="User identifier"
),
ParameterDefinition(
name="item_data",
type=ParameterType.ARRAY,
required=True,
description="Item catalog data"
),
ParameterDefinition(
name="num_recommendations",
type=ParameterType.INTEGER,
required=False,
description="Number of recommendations",
default=10,
min_value=1,
max_value=100
)
],
output_schema={
"type": "object",
"properties": {
"recommendations": {"type": "array"},
"scores": {"type": "array"},
"explanation": {"type": "string"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3080"),
HardwareRequirement(component="vram", min_value=4, recommended=12, unit="GB"),
HardwareRequirement(component="ram", min_value=16, recommended=32, unit="GB")
],
pricing=[
PricingTier(name="per_request", model=PricingModel.PER_UNIT, unit_price=0.01, min_charge=0.01),
PricingTier(name="bulk", model=PricingModel.PER_UNIT, unit_price=0.005, min_charge=0.1)
],
capabilities=["personalization", "real-time", "batch", "ab-testing"],
tags=["recommendation", "personalization", "ml", "ecommerce"],
max_concurrent=10,
timeout_seconds=60
)
}
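# Illustrative sketch (not part of this module): building a registry from the
# predefined templates and querying it. Merging further catalogs (e.g. data
# analytics or dev-tools services) into the same dict is an assumption.
def build_default_registry() -> ServiceRegistry:
    return ServiceRegistry(services=dict(AI_ML_SERVICES))
if __name__ == "__main__":
    registry = build_default_registry()
    print(registry.get_service("llm_inference").name)         # "LLM Inference"
    print([s.id for s in registry.search_services("image")])  # matches name/description/tags
    print(len(registry.get_services_by_category(ServiceCategory.AI_ML)))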


@@ -0,0 +1,286 @@
"""
Data analytics service definitions
"""
from typing import Dict, List, Any, Union
from .registry import (
ServiceDefinition,
ServiceCategory,
ParameterDefinition,
ParameterType,
HardwareRequirement,
PricingTier,
PricingModel
)
DATA_ANALYTICS_SERVICES = {
"big_data_processing": ServiceDefinition(
id="big_data_processing",
name="Big Data Processing",
category=ServiceCategory.DATA_ANALYTICS,
description="GPU-accelerated ETL and data processing with RAPIDS",
icon="📊",
input_parameters=[
ParameterDefinition(
name="operation",
type=ParameterType.ENUM,
required=True,
description="Processing operation",
options=["etl", "aggregate", "join", "filter", "transform", "clean"]
),
ParameterDefinition(
name="data_source",
type=ParameterType.STRING,
required=True,
description="Data source URL or connection string"
),
ParameterDefinition(
name="query",
type=ParameterType.STRING,
required=True,
description="SQL or data processing query"
),
ParameterDefinition(
name="output_format",
type=ParameterType.ENUM,
required=False,
description="Output format",
default="parquet",
options=["parquet", "csv", "json", "delta", "orc"]
),
ParameterDefinition(
name="partition_by",
type=ParameterType.ARRAY,
required=False,
description="Partition columns",
items={"type": "string"}
)
],
output_schema={
"type": "object",
"properties": {
"output_url": {"type": "string"},
"row_count": {"type": "integer"},
"columns": {"type": "array"},
"processing_stats": {"type": "object"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3080"),
HardwareRequirement(component="vram", min_value=8, recommended=16, unit="GB"),
HardwareRequirement(component="ram", min_value=32, recommended=128, unit="GB"),
HardwareRequirement(component="storage", min_value=100, recommended=1000, unit="GB")
],
pricing=[
PricingTier(name="per_gb", model=PricingModel.PER_GB, unit_price=0.01, min_charge=0.1),
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=1, min_charge=1),
PricingTier(name="enterprise", model=PricingModel.PER_UNIT, unit_price=0.005, min_charge=0.5)
],
capabilities=["gpu-sql", "etl", "streaming", "distributed"],
tags=["bigdata", "etl", "rapids", "spark", "sql"],
max_concurrent=5,
timeout_seconds=3600
),
"real_time_analytics": ServiceDefinition(
id="real_time_analytics",
name="Real-time Analytics",
category=ServiceCategory.DATA_ANALYTICS,
description="Stream processing and real-time analytics with GPU acceleration",
icon="",
input_parameters=[
ParameterDefinition(
name="stream_source",
type=ParameterType.STRING,
required=True,
description="Stream source (Kafka, Kinesis, etc.)"
),
ParameterDefinition(
name="query",
type=ParameterType.STRING,
required=True,
description="Stream processing query"
),
ParameterDefinition(
name="window_size",
type=ParameterType.STRING,
required=False,
description="Window size (e.g., 1m, 5m, 1h)",
default="5m"
),
ParameterDefinition(
name="aggregations",
type=ParameterType.ARRAY,
required=True,
description="Aggregation functions",
items={"type": "string"}
),
ParameterDefinition(
name="output_sink",
type=ParameterType.STRING,
required=True,
description="Output sink for results"
)
],
output_schema={
"type": "object",
"properties": {
"stream_id": {"type": "string"},
"throughput": {"type": "number"},
"latency_ms": {"type": "integer"},
"metrics": {"type": "object"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="nvidia", recommended="a100"),
HardwareRequirement(component="vram", min_value=16, recommended=40, unit="GB"),
HardwareRequirement(component="network", min_value="10Gbps", recommended="100Gbps"),
HardwareRequirement(component="ram", min_value=64, recommended=256, unit="GB")
],
pricing=[
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=2, min_charge=2),
PricingTier(name="per_million_events", model=PricingModel.PER_UNIT, unit_price=0.1, min_charge=1),
PricingTier(name="high_throughput", model=PricingModel.PER_HOUR, unit_price=5, min_charge=5)
],
capabilities=["streaming", "windowing", "aggregation", "cep"],
tags=["streaming", "real-time", "analytics", "kafka", "flink"],
max_concurrent=10,
timeout_seconds=86400 # 24 hours
),
"graph_analytics": ServiceDefinition(
id="graph_analytics",
name="Graph Analytics",
category=ServiceCategory.DATA_ANALYTICS,
description="Network analysis and graph algorithms on GPU",
icon="🕸️",
input_parameters=[
ParameterDefinition(
name="algorithm",
type=ParameterType.ENUM,
required=True,
description="Graph algorithm",
options=["pagerank", "community-detection", "shortest-path", "triangles", "clustering", "centrality"]
),
ParameterDefinition(
name="graph_data",
type=ParameterType.FILE,
required=True,
description="Graph data file (edges list, adjacency matrix, etc.)"
),
ParameterDefinition(
name="graph_format",
type=ParameterType.ENUM,
required=False,
description="Graph format",
default="edges",
options=["edges", "adjacency", "csr", "metis"]
),
ParameterDefinition(
name="parameters",
type=ParameterType.OBJECT,
required=False,
description="Algorithm-specific parameters"
),
ParameterDefinition(
name="num_vertices",
type=ParameterType.INTEGER,
required=False,
description="Number of vertices",
min_value=1
)
],
output_schema={
"type": "object",
"properties": {
"results": {"type": "array"},
"statistics": {"type": "object"},
"graph_metrics": {"type": "object"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3090"),
HardwareRequirement(component="vram", min_value=8, recommended=24, unit="GB"),
HardwareRequirement(component="ram", min_value=16, recommended=64, unit="GB")
],
pricing=[
PricingTier(name="per_million_edges", model=PricingModel.PER_UNIT, unit_price=0.01, min_charge=0.1),
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=1, min_charge=1),
PricingTier(name="large_graph", model=PricingModel.PER_UNIT, unit_price=0.005, min_charge=0.5)
],
capabilities=["gpu-graph", "algorithms", "network-analysis", "fraud-detection"],
tags=["graph", "network", "analytics", "pagerank", "fraud"],
max_concurrent=5,
timeout_seconds=3600
),
"time_series_analysis": ServiceDefinition(
id="time_series_analysis",
name="Time Series Analysis",
category=ServiceCategory.DATA_ANALYTICS,
description="Analyze time series data with GPU-accelerated algorithms",
icon="📈",
input_parameters=[
ParameterDefinition(
name="analysis_type",
type=ParameterType.ENUM,
required=True,
description="Analysis type",
options=["forecasting", "anomaly-detection", "decomposition", "seasonality", "trend"]
),
ParameterDefinition(
name="time_series_data",
type=ParameterType.FILE,
required=True,
description="Time series data file"
),
ParameterDefinition(
name="model",
type=ParameterType.ENUM,
required=True,
description="Analysis model",
options=["arima", "prophet", "lstm", "transformer", "holt-winters", "var"]
),
ParameterDefinition(
name="forecast_horizon",
type=ParameterType.INTEGER,
required=False,
description="Forecast horizon",
default=30,
min_value=1,
max_value=365
),
ParameterDefinition(
name="frequency",
type=ParameterType.STRING,
required=False,
description="Data frequency (D, H, M, S)",
default="D"
)
],
output_schema={
"type": "object",
"properties": {
"forecast": {"type": "array"},
"confidence_intervals": {"type": "array"},
"model_metrics": {"type": "object"},
"anomalies": {"type": "array"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3080"),
HardwareRequirement(component="vram", min_value=8, recommended=16, unit="GB"),
HardwareRequirement(component="ram", min_value=16, recommended=32, unit="GB")
],
pricing=[
PricingTier(name="per_1k_points", model=PricingModel.PER_UNIT, unit_price=0.001, min_charge=0.01),
PricingTier(name="per_forecast", model=PricingModel.PER_UNIT, unit_price=0.01, min_charge=0.1),
PricingTier(name="enterprise", model=PricingModel.PER_HOUR, unit_price=1, min_charge=1)
],
capabilities=["forecasting", "anomaly-detection", "decomposition", "seasonality"],
tags=["time-series", "forecasting", "anomaly", "arima", "lstm"],
max_concurrent=10,
timeout_seconds=1800
)
}


@@ -0,0 +1,408 @@
"""
Development tools service definitions
"""
from typing import Dict, List, Any, Union
from .registry import (
ServiceDefinition,
ServiceCategory,
ParameterDefinition,
ParameterType,
HardwareRequirement,
PricingTier,
PricingModel
)
DEVTOOLS_SERVICES = {
"gpu_compilation": ServiceDefinition(
id="gpu_compilation",
name="GPU-Accelerated Compilation",
category=ServiceCategory.DEVELOPMENT_TOOLS,
description="Compile code with GPU acceleration (CUDA, HIP, OpenCL)",
icon="⚙️",
input_parameters=[
ParameterDefinition(
name="language",
type=ParameterType.ENUM,
required=True,
description="Programming language",
options=["cpp", "cuda", "hip", "opencl", "metal", "sycl"]
),
ParameterDefinition(
name="source_files",
type=ParameterType.ARRAY,
required=True,
description="Source code files",
items={"type": "string"}
),
ParameterDefinition(
name="build_type",
type=ParameterType.ENUM,
required=False,
description="Build type",
default="release",
options=["debug", "release", "relwithdebinfo"]
),
ParameterDefinition(
name="target_arch",
type=ParameterType.ENUM,
required=False,
description="Target architecture",
default="sm_70",
options=["sm_60", "sm_70", "sm_80", "sm_86", "sm_89", "sm_90"]
),
ParameterDefinition(
name="optimization_level",
type=ParameterType.ENUM,
required=False,
description="Optimization level",
default="O2",
options=["O0", "O1", "O2", "O3", "Os"]
),
ParameterDefinition(
name="parallel_jobs",
type=ParameterType.INTEGER,
required=False,
description="Number of parallel compilation jobs",
default=4,
min_value=1,
max_value=64
)
],
output_schema={
"type": "object",
"properties": {
"binary_url": {"type": "string"},
"build_log": {"type": "string"},
"compilation_time": {"type": "number"},
"binary_size": {"type": "integer"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3080"),
HardwareRequirement(component="vram", min_value=4, recommended=8, unit="GB"),
HardwareRequirement(component="cpu", min_value=8, recommended=16, unit="cores"),
HardwareRequirement(component="ram", min_value=16, recommended=32, unit="GB"),
HardwareRequirement(component="cuda", min_value="11.8")
],
pricing=[
PricingTier(name="per_minute", model=PricingModel.PER_UNIT, unit_price=0.01, min_charge=0.1),
PricingTier(name="per_file", model=PricingModel.PER_UNIT, unit_price=0.001, min_charge=0.01),
PricingTier(name="enterprise", model=PricingModel.PER_HOUR, unit_price=1, min_charge=1)
],
capabilities=["cuda", "hip", "parallel-compilation", "incremental"],
tags=["compilation", "cuda", "gpu", "cpp", "build"],
max_concurrent=5,
timeout_seconds=1800
),
"model_training": ServiceDefinition(
id="model_training",
name="ML Model Training",
category=ServiceCategory.DEVELOPMENT_TOOLS,
description="Fine-tune or train machine learning models on client data",
icon="🧠",
input_parameters=[
ParameterDefinition(
name="model_type",
type=ParameterType.ENUM,
required=True,
description="Model type",
options=["transformer", "cnn", "rnn", "gan", "diffusion", "custom"]
),
ParameterDefinition(
name="base_model",
type=ParameterType.STRING,
required=False,
description="Base model to fine-tune"
),
ParameterDefinition(
name="training_data",
type=ParameterType.FILE,
required=True,
description="Training dataset"
),
ParameterDefinition(
name="validation_data",
type=ParameterType.FILE,
required=False,
description="Validation dataset"
),
ParameterDefinition(
name="epochs",
type=ParameterType.INTEGER,
required=False,
description="Number of training epochs",
default=10,
min_value=1,
max_value=1000
),
ParameterDefinition(
name="batch_size",
type=ParameterType.INTEGER,
required=False,
description="Batch size",
default=32,
min_value=1,
max_value=1024
),
ParameterDefinition(
name="learning_rate",
type=ParameterType.FLOAT,
required=False,
description="Learning rate",
default=0.001,
min_value=0.00001,
max_value=1
),
ParameterDefinition(
name="hyperparameters",
type=ParameterType.OBJECT,
required=False,
description="Additional hyperparameters"
)
],
output_schema={
"type": "object",
"properties": {
"model_url": {"type": "string"},
"training_metrics": {"type": "object"},
"loss_curves": {"type": "array"},
"validation_scores": {"type": "object"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="nvidia", recommended="a100"),
HardwareRequirement(component="vram", min_value=16, recommended=40, unit="GB"),
HardwareRequirement(component="cpu", min_value=16, recommended=32, unit="cores"),
HardwareRequirement(component="ram", min_value=32, recommended=128, unit="GB"),
HardwareRequirement(component="storage", min_value=100, recommended=1000, unit="GB")
],
pricing=[
PricingTier(name="per_epoch", model=PricingModel.PER_UNIT, unit_price=0.1, min_charge=1),
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=2, min_charge=2),
PricingTier(name="enterprise", model=PricingModel.PER_UNIT, unit_price=0.05, min_charge=0.5)
],
capabilities=["fine-tuning", "training", "hyperparameter-tuning", "distributed"],
tags=["ml", "training", "fine-tuning", "pytorch", "tensorflow"],
max_concurrent=2,
timeout_seconds=86400 # 24 hours
),
"data_processing": ServiceDefinition(
id="data_processing",
name="Large Dataset Processing",
category=ServiceCategory.DEVELOPMENT_TOOLS,
description="Preprocess and transform large datasets",
icon="📦",
input_parameters=[
ParameterDefinition(
name="operation",
type=ParameterType.ENUM,
required=True,
description="Processing operation",
options=["clean", "transform", "normalize", "augment", "split", "encode"]
),
ParameterDefinition(
name="input_data",
type=ParameterType.FILE,
required=True,
description="Input dataset"
),
ParameterDefinition(
name="output_format",
type=ParameterType.ENUM,
required=False,
description="Output format",
default="parquet",
options=["csv", "json", "parquet", "hdf5", "feather", "pickle"]
),
ParameterDefinition(
name="chunk_size",
type=ParameterType.INTEGER,
required=False,
description="Processing chunk size",
default=10000,
min_value=100,
max_value=1000000
),
ParameterDefinition(
name="parameters",
type=ParameterType.OBJECT,
required=False,
description="Operation-specific parameters"
)
],
output_schema={
"type": "object",
"properties": {
"output_url": {"type": "string"},
"processing_stats": {"type": "object"},
"data_quality": {"type": "object"},
"row_count": {"type": "integer"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="any", recommended="nvidia"),
HardwareRequirement(component="vram", min_value=4, recommended=16, unit="GB"),
HardwareRequirement(component="ram", min_value=16, recommended=64, unit="GB"),
HardwareRequirement(component="storage", min_value=100, recommended=1000, unit="GB")
],
pricing=[
PricingTier(name="per_gb", model=PricingModel.PER_GB, unit_price=0.01, min_charge=0.1),
PricingTier(name="per_million_rows", model=PricingModel.PER_UNIT, unit_price=0.01, min_charge=0.1),
PricingTier(name="enterprise", model=PricingModel.PER_HOUR, unit_price=1, min_charge=1)
],
capabilities=["gpu-processing", "parallel", "streaming", "validation"],
tags=["data", "preprocessing", "etl", "cleaning", "transformation"],
max_concurrent=5,
timeout_seconds=3600
),
"simulation_testing": ServiceDefinition(
id="simulation_testing",
name="Hardware-in-the-Loop Testing",
category=ServiceCategory.DEVELOPMENT_TOOLS,
description="Run hardware simulations and testing workflows",
icon="🔬",
input_parameters=[
ParameterDefinition(
name="test_type",
type=ParameterType.ENUM,
required=True,
description="Test type",
options=["hardware", "firmware", "software", "integration", "performance"]
),
ParameterDefinition(
name="test_suite",
type=ParameterType.FILE,
required=True,
description="Test suite configuration"
),
ParameterDefinition(
name="hardware_config",
type=ParameterType.OBJECT,
required=True,
description="Hardware configuration"
),
ParameterDefinition(
name="duration",
                type=ParameterType.FLOAT,
required=False,
description="Test duration in hours",
default=1,
min_value=0.1,
max_value=168 # 1 week
),
ParameterDefinition(
name="parallel_tests",
type=ParameterType.INTEGER,
required=False,
description="Number of parallel tests",
default=1,
min_value=1,
max_value=10
)
],
output_schema={
"type": "object",
"properties": {
"test_results": {"type": "array"},
"performance_metrics": {"type": "object"},
"failure_logs": {"type": "array"},
"coverage_report": {"type": "object"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="any", recommended="nvidia"),
HardwareRequirement(component="cpu", min_value=16, recommended=32, unit="cores"),
HardwareRequirement(component="ram", min_value=32, recommended=128, unit="GB"),
HardwareRequirement(component="storage", min_value=100, recommended=500, unit="GB")
],
pricing=[
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=2, min_charge=1),
PricingTier(name="per_test", model=PricingModel.PER_UNIT, unit_price=0.1, min_charge=0.5),
PricingTier(name="continuous", model=PricingModel.PER_HOUR, unit_price=5, min_charge=5)
],
capabilities=["hardware-simulation", "automated-testing", "performance", "debugging"],
tags=["testing", "simulation", "hardware", "hil", "verification"],
max_concurrent=3,
timeout_seconds=604800 # 1 week
),
"code_generation": ServiceDefinition(
id="code_generation",
name="AI Code Generation",
category=ServiceCategory.DEVELOPMENT_TOOLS,
description="Generate code from natural language descriptions",
icon="💻",
input_parameters=[
ParameterDefinition(
name="language",
type=ParameterType.ENUM,
required=True,
description="Target programming language",
options=["python", "javascript", "cpp", "java", "go", "rust", "typescript", "sql"]
),
ParameterDefinition(
name="description",
type=ParameterType.STRING,
required=True,
description="Natural language description of code to generate",
max_value=2000
),
ParameterDefinition(
name="framework",
type=ParameterType.STRING,
required=False,
description="Target framework or library"
),
ParameterDefinition(
name="code_style",
type=ParameterType.ENUM,
required=False,
description="Code style preferences",
default="standard",
options=["standard", "functional", "oop", "minimalist"]
),
ParameterDefinition(
name="include_comments",
type=ParameterType.BOOLEAN,
required=False,
description="Include explanatory comments",
default=True
),
ParameterDefinition(
name="include_tests",
type=ParameterType.BOOLEAN,
required=False,
description="Generate unit tests",
default=False
)
],
output_schema={
"type": "object",
"properties": {
"generated_code": {"type": "string"},
"explanation": {"type": "string"},
"usage_example": {"type": "string"},
"test_code": {"type": "string"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3080"),
HardwareRequirement(component="vram", min_value=8, recommended=16, unit="GB"),
HardwareRequirement(component="ram", min_value=8, recommended=16, unit="GB")
],
pricing=[
PricingTier(name="per_generation", model=PricingModel.PER_UNIT, unit_price=0.01, min_charge=0.01),
PricingTier(name="per_100_lines", model=PricingModel.PER_UNIT, unit_price=0.001, min_charge=0.01),
PricingTier(name="with_tests", model=PricingModel.PER_UNIT, unit_price=0.02, min_charge=0.02)
],
capabilities=["code-gen", "documentation", "test-gen", "refactoring"],
tags=["code", "generation", "ai", "copilot", "automation"],
max_concurrent=10,
timeout_seconds=120
)
}
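# --- Illustrative sketch (not part of the registry) ---------------------------
# How a marketplace endpoint could apply the definitions above to a raw job
# request: fill in defaults, reject missing required parameters, and enforce
# the fields used throughout (required, default, options, min_value, max_value).
# The validate_params helper itself is hypothetical.
def validate_params(definition, raw: dict) -> dict:
    resolved = {}
    for param in definition.input_parameters:
        if param.name not in raw:
            if param.required:
                raise ValueError(f"missing required parameter: {param.name}")
            resolved[param.name] = param.default
            continue
        value = raw[param.name]
        if getattr(param, "options", None) and value not in param.options:
            raise ValueError(f"{param.name} must be one of {param.options}")
        if getattr(param, "min_value", None) is not None and value < param.min_value:
            raise ValueError(f"{param.name} is below the minimum {param.min_value}")
        if getattr(param, "max_value", None) is not None and value > param.max_value:
            raise ValueError(f"{param.name} exceeds the maximum {param.max_value}")
        resolved[param.name] = value
    return resolved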

@@ -0,0 +1,307 @@
"""
Gaming & entertainment service definitions
"""
from typing import Dict, List, Any, Union
from .registry import (
ServiceDefinition,
ServiceCategory,
ParameterDefinition,
ParameterType,
HardwareRequirement,
PricingTier,
PricingModel
)
GAMING_SERVICES = {
"cloud_gaming": ServiceDefinition(
id="cloud_gaming",
name="Cloud Gaming Server",
category=ServiceCategory.GAMING_ENTERTAINMENT,
description="Host cloud gaming sessions with GPU streaming",
icon="🎮",
input_parameters=[
ParameterDefinition(
name="game",
type=ParameterType.STRING,
required=True,
description="Game title or executable"
),
ParameterDefinition(
name="resolution",
type=ParameterType.ENUM,
required=True,
description="Streaming resolution",
options=["720p", "1080p", "1440p", "4k"]
),
ParameterDefinition(
name="fps",
type=ParameterType.INTEGER,
required=False,
description="Target frame rate",
default=60,
options=[30, 60, 120, 144]
),
ParameterDefinition(
name="session_duration",
type=ParameterType.INTEGER,
required=True,
description="Session duration in minutes",
min_value=15,
max_value=480
),
ParameterDefinition(
name="codec",
type=ParameterType.ENUM,
required=False,
description="Streaming codec",
default="h264",
options=["h264", "h265", "av1", "vp9"]
),
ParameterDefinition(
name="region",
type=ParameterType.STRING,
required=False,
description="Preferred server region"
)
],
output_schema={
"type": "object",
"properties": {
"stream_url": {"type": "string"},
"session_id": {"type": "string"},
"latency_ms": {"type": "integer"},
"quality_metrics": {"type": "object"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3080"),
HardwareRequirement(component="vram", min_value=8, recommended=16, unit="GB"),
HardwareRequirement(component="network", min_value="100Mbps", recommended="1Gbps"),
HardwareRequirement(component="cpu", min_value=8, recommended=16, unit="cores"),
HardwareRequirement(component="ram", min_value=16, recommended=32, unit="GB")
],
pricing=[
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=1, min_charge=0.5),
PricingTier(name="1080p", model=PricingModel.PER_HOUR, unit_price=1.5, min_charge=0.75),
PricingTier(name="4k", model=PricingModel.PER_HOUR, unit_price=3, min_charge=1.5)
],
capabilities=["low-latency", "game-streaming", "multiplayer", "saves"],
tags=["gaming", "cloud", "streaming", "nvidia", "gamepass"],
max_concurrent=1,
timeout_seconds=28800 # 8 hours
),
"game_asset_baking": ServiceDefinition(
id="game_asset_baking",
name="Game Asset Baking",
category=ServiceCategory.GAMING_ENTERTAINMENT,
description="Optimize and bake game assets (textures, meshes, materials)",
icon="🎨",
input_parameters=[
ParameterDefinition(
name="asset_type",
type=ParameterType.ENUM,
required=True,
description="Asset type",
options=["texture", "mesh", "material", "animation", "terrain"]
),
ParameterDefinition(
name="input_assets",
type=ParameterType.ARRAY,
required=True,
description="Input asset files",
items={"type": "string"}
),
ParameterDefinition(
name="target_platform",
type=ParameterType.ENUM,
required=True,
description="Target platform",
options=["pc", "mobile", "console", "web", "vr"]
),
ParameterDefinition(
name="optimization_level",
type=ParameterType.ENUM,
required=False,
description="Optimization level",
default="balanced",
options=["fast", "balanced", "maximum"]
),
ParameterDefinition(
name="texture_formats",
type=ParameterType.ARRAY,
required=False,
description="Output texture formats",
default=["dds", "astc"],
items={"type": "string"}
)
],
output_schema={
"type": "object",
"properties": {
"baked_assets": {"type": "array"},
"compression_stats": {"type": "object"},
"optimization_report": {"type": "object"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3080"),
HardwareRequirement(component="vram", min_value=8, recommended=16, unit="GB"),
HardwareRequirement(component="ram", min_value=16, recommended=32, unit="GB"),
HardwareRequirement(component="storage", min_value=50, recommended=500, unit="GB")
],
pricing=[
PricingTier(name="per_asset", model=PricingModel.PER_UNIT, unit_price=0.01, min_charge=0.1),
PricingTier(name="per_texture", model=PricingModel.PER_UNIT, unit_price=0.005, min_charge=0.05),
PricingTier(name="per_mesh", model=PricingModel.PER_UNIT, unit_price=0.02, min_charge=0.1)
],
capabilities=["texture-compression", "mesh-optimization", "lod-generation", "platform-specific"],
tags=["gamedev", "assets", "optimization", "textures", "meshes"],
max_concurrent=5,
timeout_seconds=1800
),
"physics_simulation": ServiceDefinition(
id="physics_simulation",
name="Game Physics Simulation",
category=ServiceCategory.GAMING_ENTERTAINMENT,
description="Run physics simulations for game development",
icon="⚛️",
input_parameters=[
ParameterDefinition(
name="engine",
type=ParameterType.ENUM,
required=True,
description="Physics engine",
options=["physx", "havok", "bullet", "box2d", "chipmunk"]
),
ParameterDefinition(
name="simulation_type",
type=ParameterType.ENUM,
required=True,
description="Simulation type",
options=["rigid-body", "soft-body", "fluid", "cloth", "destruction"]
),
ParameterDefinition(
name="scene_file",
type=ParameterType.FILE,
required=False,
description="Scene or level file"
),
ParameterDefinition(
name="parameters",
type=ParameterType.OBJECT,
required=True,
description="Physics parameters"
),
ParameterDefinition(
name="simulation_time",
type=ParameterType.FLOAT,
required=True,
description="Simulation duration in seconds",
min_value=0.1
),
ParameterDefinition(
name="record_frames",
type=ParameterType.BOOLEAN,
required=False,
description="Record animation frames",
default=False
)
],
output_schema={
"type": "object",
"properties": {
"simulation_data": {"type": "array"},
"animation_url": {"type": "string"},
"physics_stats": {"type": "object"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3080"),
HardwareRequirement(component="vram", min_value=8, recommended=16, unit="GB"),
HardwareRequirement(component="cpu", min_value=8, recommended=16, unit="cores"),
HardwareRequirement(component="ram", min_value=16, recommended=32, unit="GB")
],
pricing=[
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=1, min_charge=0.5),
PricingTier(name="per_frame", model=PricingModel.PER_UNIT, unit_price=0.001, min_charge=0.1),
PricingTier(name="complex", model=PricingModel.PER_HOUR, unit_price=2, min_charge=1)
],
capabilities=["gpu-physics", "particle-systems", "destruction", "cloth"],
tags=["physics", "gamedev", "simulation", "physx", "havok"],
max_concurrent=3,
timeout_seconds=3600
),
"vr_ar_rendering": ServiceDefinition(
id="vr_ar_rendering",
name="VR/AR Rendering",
category=ServiceCategory.GAMING_ENTERTAINMENT,
description="Real-time 3D rendering for VR/AR applications",
icon="🥽",
input_parameters=[
ParameterDefinition(
name="platform",
type=ParameterType.ENUM,
required=True,
description="Target platform",
options=["oculus", "vive", "hololens", "magic-leap", "cardboard", "webxr"]
),
ParameterDefinition(
name="scene_file",
type=ParameterType.FILE,
required=True,
description="3D scene file"
),
ParameterDefinition(
name="render_quality",
type=ParameterType.ENUM,
required=False,
description="Render quality",
default="high",
options=["low", "medium", "high", "ultra"]
),
ParameterDefinition(
name="stereo_mode",
type=ParameterType.BOOLEAN,
required=False,
description="Stereo rendering",
default=True
),
ParameterDefinition(
name="target_fps",
type=ParameterType.INTEGER,
required=False,
description="Target frame rate",
default=90,
options=[60, 72, 90, 120, 144]
)
],
output_schema={
"type": "object",
"properties": {
"rendered_frames": {"type": "array"},
"performance_metrics": {"type": "object"},
"vr_package": {"type": "string"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3080"),
HardwareRequirement(component="vram", min_value=8, recommended=16, unit="GB"),
HardwareRequirement(component="cpu", min_value=8, recommended=16, unit="cores"),
HardwareRequirement(component="ram", min_value=16, recommended=32, unit="GB")
],
pricing=[
PricingTier(name="per_minute", model=PricingModel.PER_UNIT, unit_price=0.02, min_charge=0.5),
PricingTier(name="per_frame", model=PricingModel.PER_UNIT, unit_price=0.001, min_charge=0.1),
PricingTier(name="real-time", model=PricingModel.PER_HOUR, unit_price=5, min_charge=1)
],
capabilities=["stereo-rendering", "real-time", "low-latency", "tracking"],
tags=["vr", "ar", "rendering", "3d", "immersive"],
max_concurrent=2,
timeout_seconds=3600
)
}
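# --- Illustrative sketch (not part of the registry) ---------------------------
# A minimal reading of the PricingTier fields used above: a job's charge is
# unit_price multiplied by the metered units, floored at min_charge. How a tier
# is selected (cheapest applicable vs. explicitly named) is an assumption left
# to the marketplace layer.
def estimate_charge(tier, units: float) -> float:
    return max(tier.unit_price * units, tier.min_charge)

# Example: one hour of the cloud_gaming "per_hour" tier (unit_price=1,
# min_charge=0.5) -> estimate_charge(tier, 1.0) == 1.0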

@@ -0,0 +1,412 @@
"""
Media processing service definitions
"""
from typing import Dict, List, Any, Union
from .registry import (
ServiceDefinition,
ServiceCategory,
ParameterDefinition,
ParameterType,
HardwareRequirement,
PricingTier,
PricingModel
)
MEDIA_PROCESSING_SERVICES = {
"video_transcoding": ServiceDefinition(
id="video_transcoding",
name="Video Transcoding",
category=ServiceCategory.MEDIA_PROCESSING,
description="Transcode videos between formats using FFmpeg with GPU acceleration",
icon="🎬",
input_parameters=[
ParameterDefinition(
name="input_video",
type=ParameterType.FILE,
required=True,
description="Input video file"
),
ParameterDefinition(
name="output_format",
type=ParameterType.ENUM,
required=True,
description="Output video format",
options=["mp4", "webm", "avi", "mov", "mkv", "flv"]
),
ParameterDefinition(
name="codec",
type=ParameterType.ENUM,
required=False,
description="Video codec",
default="h264",
options=["h264", "h265", "vp9", "av1", "mpeg4"]
),
ParameterDefinition(
name="resolution",
type=ParameterType.STRING,
required=False,
description="Output resolution (e.g., 1920x1080)",
validation={"pattern": r"^\d+x\d+$"}
),
ParameterDefinition(
name="bitrate",
type=ParameterType.STRING,
required=False,
description="Target bitrate (e.g., 5M, 2500k)",
validation={"pattern": r"^\d+[kM]?$"}
),
ParameterDefinition(
name="fps",
type=ParameterType.INTEGER,
required=False,
description="Output frame rate",
min_value=1,
max_value=120
),
ParameterDefinition(
name="gpu_acceleration",
type=ParameterType.BOOLEAN,
required=False,
description="Use GPU acceleration",
default=True
)
],
output_schema={
"type": "object",
"properties": {
"output_url": {"type": "string"},
"metadata": {"type": "object"},
"duration": {"type": "number"},
"file_size": {"type": "integer"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="any", recommended="nvidia"),
HardwareRequirement(component="vram", min_value=2, recommended=8, unit="GB"),
HardwareRequirement(component="ram", min_value=8, recommended=16, unit="GB"),
HardwareRequirement(component="storage", min_value=50, unit="GB")
],
pricing=[
PricingTier(name="per_minute", model=PricingModel.PER_UNIT, unit_price=0.005, min_charge=0.01),
PricingTier(name="per_gb", model=PricingModel.PER_GB, unit_price=0.01, min_charge=0.01),
PricingTier(name="4k_premium", model=PricingModel.PER_UNIT, unit_price=0.02, min_charge=0.05)
],
capabilities=["transcode", "compress", "resize", "format-convert"],
tags=["video", "ffmpeg", "transcoding", "encoding", "gpu"],
max_concurrent=2,
timeout_seconds=3600
),
"video_streaming": ServiceDefinition(
id="video_streaming",
name="Live Video Streaming",
category=ServiceCategory.MEDIA_PROCESSING,
description="Real-time video transcoding for adaptive bitrate streaming",
icon="📡",
input_parameters=[
ParameterDefinition(
name="stream_url",
type=ParameterType.STRING,
required=True,
description="Input stream URL"
),
ParameterDefinition(
name="output_formats",
type=ParameterType.ARRAY,
required=True,
description="Output formats for adaptive streaming",
default=["720p", "1080p", "4k"]
),
ParameterDefinition(
name="duration_minutes",
type=ParameterType.INTEGER,
required=False,
description="Streaming duration in minutes",
default=60,
min_value=1,
max_value=480
),
ParameterDefinition(
name="protocol",
type=ParameterType.ENUM,
required=False,
description="Streaming protocol",
default="hls",
options=["hls", "dash", "rtmp", "webrtc"]
)
],
output_schema={
"type": "object",
"properties": {
"stream_url": {"type": "string"},
"playlist_url": {"type": "string"},
"bitrates": {"type": "array"},
"duration": {"type": "number"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3080"),
HardwareRequirement(component="vram", min_value=8, recommended=16, unit="GB"),
HardwareRequirement(component="network", min_value="1Gbps", recommended="10Gbps"),
HardwareRequirement(component="ram", min_value=16, recommended=32, unit="GB")
],
pricing=[
PricingTier(name="per_minute", model=PricingModel.PER_UNIT, unit_price=0.01, min_charge=0.5),
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=0.5, min_charge=0.5)
],
capabilities=["live-transcoding", "adaptive-bitrate", "multi-format", "low-latency"],
tags=["streaming", "live", "transcoding", "real-time"],
max_concurrent=5,
timeout_seconds=28800 # 8 hours
),
"3d_rendering": ServiceDefinition(
id="3d_rendering",
name="3D Rendering",
category=ServiceCategory.MEDIA_PROCESSING,
description="Render 3D scenes using Blender, Unreal Engine, or V-Ray",
icon="🎭",
input_parameters=[
ParameterDefinition(
name="engine",
type=ParameterType.ENUM,
required=True,
description="Rendering engine",
options=["blender-cycles", "blender-eevee", "unreal-engine", "v-ray", "octane"]
),
ParameterDefinition(
name="scene_file",
type=ParameterType.FILE,
required=True,
description="3D scene file (.blend, .ueproject, etc)"
),
ParameterDefinition(
name="resolution_x",
type=ParameterType.INTEGER,
required=False,
description="Output width",
default=1920,
min_value=1,
max_value=8192
),
ParameterDefinition(
name="resolution_y",
type=ParameterType.INTEGER,
required=False,
description="Output height",
default=1080,
min_value=1,
max_value=8192
),
ParameterDefinition(
name="samples",
type=ParameterType.INTEGER,
required=False,
description="Samples per pixel (path tracing)",
default=128,
min_value=1,
max_value=10000
),
ParameterDefinition(
name="frame_start",
type=ParameterType.INTEGER,
required=False,
description="Start frame for animation",
default=1,
min_value=1
),
ParameterDefinition(
name="frame_end",
type=ParameterType.INTEGER,
required=False,
description="End frame for animation",
default=1,
min_value=1
),
ParameterDefinition(
name="output_format",
type=ParameterType.ENUM,
required=False,
description="Output image format",
default="png",
options=["png", "jpg", "exr", "bmp", "tiff", "hdr"]
)
],
output_schema={
"type": "object",
"properties": {
"rendered_images": {"type": "array"},
"metadata": {"type": "object"},
"render_time": {"type": "number"},
"frame_count": {"type": "integer"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-4090"),
HardwareRequirement(component="vram", min_value=8, recommended=24, unit="GB"),
HardwareRequirement(component="ram", min_value=16, recommended=64, unit="GB"),
HardwareRequirement(component="cpu", min_value=8, recommended=16, unit="cores")
],
pricing=[
PricingTier(name="per_frame", model=PricingModel.PER_FRAME, unit_price=0.01, min_charge=0.1),
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=0.5, min_charge=0.5),
PricingTier(name="4k_premium", model=PricingModel.PER_FRAME, unit_price=0.05, min_charge=0.5)
],
capabilities=["path-tracing", "ray-tracing", "animation", "gpu-render"],
tags=["3d", "rendering", "blender", "unreal", "v-ray"],
max_concurrent=2,
timeout_seconds=7200
),
"image_processing": ServiceDefinition(
id="image_processing",
name="Batch Image Processing",
category=ServiceCategory.MEDIA_PROCESSING,
description="Process images in bulk with filters, effects, and format conversion",
icon="🖼️",
input_parameters=[
ParameterDefinition(
name="images",
type=ParameterType.ARRAY,
required=True,
description="Array of image files or URLs"
),
ParameterDefinition(
name="operations",
type=ParameterType.ARRAY,
required=True,
description="Processing operations to apply",
items={
"type": "object",
"properties": {
"type": {"type": "string"},
"params": {"type": "object"}
}
}
),
ParameterDefinition(
name="output_format",
type=ParameterType.ENUM,
required=False,
description="Output format",
default="jpg",
options=["jpg", "png", "webp", "avif", "tiff", "bmp"]
),
ParameterDefinition(
name="quality",
type=ParameterType.INTEGER,
required=False,
description="Output quality (1-100)",
default=90,
min_value=1,
max_value=100
),
ParameterDefinition(
name="resize",
type=ParameterType.STRING,
required=False,
description="Resize dimensions (e.g., 1920x1080, 50%)",
validation={"pattern": r"^\d+x\d+|^\d+%$"}
)
],
output_schema={
"type": "object",
"properties": {
"processed_images": {"type": "array"},
"count": {"type": "integer"},
"total_size": {"type": "integer"},
"processing_time": {"type": "number"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="any", recommended="nvidia"),
HardwareRequirement(component="vram", min_value=1, recommended=4, unit="GB"),
HardwareRequirement(component="ram", min_value=4, recommended=16, unit="GB")
],
pricing=[
PricingTier(name="per_image", model=PricingModel.PER_UNIT, unit_price=0.001, min_charge=0.01),
PricingTier(name="bulk_100", model=PricingModel.PER_UNIT, unit_price=0.0005, min_charge=0.05),
PricingTier(name="bulk_1000", model=PricingModel.PER_UNIT, unit_price=0.0002, min_charge=0.2)
],
capabilities=["resize", "filter", "format-convert", "batch", "watermark"],
tags=["image", "processing", "batch", "filter", "conversion"],
max_concurrent=10,
timeout_seconds=600
),
"audio_processing": ServiceDefinition(
id="audio_processing",
name="Audio Processing",
category=ServiceCategory.MEDIA_PROCESSING,
description="Process audio files with effects, noise reduction, and format conversion",
icon="🎵",
input_parameters=[
ParameterDefinition(
name="audio_file",
type=ParameterType.FILE,
required=True,
description="Input audio file"
),
ParameterDefinition(
name="operations",
type=ParameterType.ARRAY,
required=True,
description="Audio operations to apply",
items={
"type": "object",
"properties": {
"type": {"type": "string"},
"params": {"type": "object"}
}
}
),
ParameterDefinition(
name="output_format",
type=ParameterType.ENUM,
required=False,
description="Output format",
default="mp3",
options=["mp3", "wav", "flac", "aac", "ogg", "m4a"]
),
ParameterDefinition(
name="sample_rate",
type=ParameterType.INTEGER,
required=False,
description="Output sample rate",
default=44100,
options=[22050, 44100, 48000, 96000, 192000]
),
ParameterDefinition(
name="bitrate",
type=ParameterType.INTEGER,
required=False,
description="Output bitrate (kbps)",
default=320,
options=[128, 192, 256, 320, 512, 1024]
)
],
output_schema={
"type": "object",
"properties": {
"output_url": {"type": "string"},
"metadata": {"type": "object"},
"duration": {"type": "number"},
"file_size": {"type": "integer"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="any", recommended="nvidia"),
HardwareRequirement(component="ram", min_value=2, recommended=8, unit="GB")
],
pricing=[
PricingTier(name="per_minute", model=PricingModel.PER_UNIT, unit_price=0.002, min_charge=0.01),
PricingTier(name="per_effect", model=PricingModel.PER_UNIT, unit_price=0.005, min_charge=0.01)
],
capabilities=["noise-reduction", "effects", "format-convert", "enhancement"],
tags=["audio", "processing", "effects", "noise-reduction"],
max_concurrent=5,
timeout_seconds=300
)
}
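# --- Illustrative sketch (not part of the registry) ---------------------------
# The output_schema dicts above are plain JSON Schema fragments, so a worker's
# result payload could be checked against them before the job is settled. Using
# the jsonschema package here is an assumption; the service layer may validate
# results differently.
def check_result(service_definition, result: dict) -> bool:
    from jsonschema import ValidationError, validate  # assumed dependency
    try:
        validate(instance=result, schema=service_definition.output_schema)
        return True
    except ValidationError:
        return False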

@@ -0,0 +1,406 @@
"""
Scientific computing service definitions
"""
from typing import Dict, List, Any, Union
from .registry import (
ServiceDefinition,
ServiceCategory,
ParameterDefinition,
ParameterType,
HardwareRequirement,
PricingTier,
PricingModel
)
SCIENTIFIC_COMPUTING_SERVICES = {
"molecular_dynamics": ServiceDefinition(
id="molecular_dynamics",
name="Molecular Dynamics Simulation",
category=ServiceCategory.SCIENTIFIC_COMPUTING,
description="Run molecular dynamics simulations using GROMACS or NAMD",
icon="🧬",
input_parameters=[
ParameterDefinition(
name="software",
type=ParameterType.ENUM,
required=True,
description="MD software package",
options=["gromacs", "namd", "amber", "lammps", "desmond"]
),
ParameterDefinition(
name="structure_file",
type=ParameterType.FILE,
required=True,
description="Molecular structure file (PDB, MOL2, etc)"
),
ParameterDefinition(
name="topology_file",
type=ParameterType.FILE,
required=False,
description="Topology file"
),
ParameterDefinition(
name="force_field",
type=ParameterType.ENUM,
required=True,
description="Force field to use",
options=["AMBER", "CHARMM", "OPLS", "GROMOS", "DREIDING"]
),
ParameterDefinition(
name="simulation_time_ns",
type=ParameterType.FLOAT,
required=True,
description="Simulation time in nanoseconds",
min_value=0.1,
max_value=1000
),
ParameterDefinition(
name="temperature_k",
type=ParameterType.FLOAT,
required=False,
description="Temperature in Kelvin",
default=300,
min_value=0,
max_value=500
),
ParameterDefinition(
name="pressure_bar",
type=ParameterType.FLOAT,
required=False,
description="Pressure in bar",
default=1,
min_value=0,
max_value=1000
),
ParameterDefinition(
name="time_step_fs",
type=ParameterType.FLOAT,
required=False,
description="Time step in femtoseconds",
default=2,
min_value=0.5,
max_value=5
)
],
output_schema={
"type": "object",
"properties": {
"trajectory_url": {"type": "string"},
"log_url": {"type": "string"},
"energy_data": {"type": "array"},
"simulation_stats": {"type": "object"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="nvidia", recommended="a100"),
HardwareRequirement(component="vram", min_value=16, recommended=40, unit="GB"),
HardwareRequirement(component="cpu", min_value=16, recommended=64, unit="cores"),
HardwareRequirement(component="ram", min_value=32, recommended=256, unit="GB"),
HardwareRequirement(component="storage", min_value=100, recommended=1000, unit="GB")
],
pricing=[
PricingTier(name="per_ns", model=PricingModel.PER_UNIT, unit_price=0.1, min_charge=1),
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=2, min_charge=2),
PricingTier(name="bulk_100ns", model=PricingModel.PER_UNIT, unit_price=0.05, min_charge=5)
],
capabilities=["gpu-accelerated", "parallel", "ensemble", "free-energy"],
tags=["molecular", "dynamics", "simulation", "biophysics", "chemistry"],
max_concurrent=4,
timeout_seconds=86400 # 24 hours
),
"weather_modeling": ServiceDefinition(
id="weather_modeling",
name="Weather Modeling",
category=ServiceCategory.SCIENTIFIC_COMPUTING,
description="Run weather prediction and climate simulations",
icon="🌦️",
input_parameters=[
ParameterDefinition(
name="model",
type=ParameterType.ENUM,
required=True,
description="Weather model",
options=["WRF", "MM5", "IFS", "GFS", "ECMWF"]
),
ParameterDefinition(
name="region",
type=ParameterType.OBJECT,
required=True,
description="Geographic region bounds",
properties={
"lat_min": {"type": "number"},
"lat_max": {"type": "number"},
"lon_min": {"type": "number"},
"lon_max": {"type": "number"}
}
),
ParameterDefinition(
name="forecast_hours",
type=ParameterType.INTEGER,
required=True,
description="Forecast length in hours",
min_value=1,
max_value=384 # 16 days
),
ParameterDefinition(
name="resolution_km",
type=ParameterType.FLOAT,
required=False,
description="Spatial resolution in kilometers",
default=10,
options=[1, 3, 5, 10, 25, 50]
),
ParameterDefinition(
name="output_variables",
type=ParameterType.ARRAY,
required=False,
description="Variables to output",
default=["temperature", "precipitation", "wind", "pressure"],
items={"type": "string"}
)
],
output_schema={
"type": "object",
"properties": {
"forecast_data": {"type": "array"},
"visualization_urls": {"type": "array"},
"metadata": {"type": "object"}
}
},
requirements=[
HardwareRequirement(component="cpu", min_value=32, recommended=128, unit="cores"),
HardwareRequirement(component="ram", min_value=64, recommended=512, unit="GB"),
HardwareRequirement(component="storage", min_value=500, recommended=5000, unit="GB"),
HardwareRequirement(component="network", min_value="10Gbps", recommended="100Gbps")
],
pricing=[
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=5, min_charge=10),
PricingTier(name="per_day", model=PricingModel.PER_UNIT, unit_price=100, min_charge=100),
PricingTier(name="high_res", model=PricingModel.PER_HOUR, unit_price=10, min_charge=20)
],
capabilities=["forecast", "climate", "ensemble", "data-assimilation"],
tags=["weather", "climate", "forecast", "meteorology", "atmosphere"],
max_concurrent=2,
timeout_seconds=172800 # 48 hours
),
"financial_modeling": ServiceDefinition(
id="financial_modeling",
name="Financial Modeling",
category=ServiceCategory.SCIENTIFIC_COMPUTING,
description="Run Monte Carlo simulations and risk analysis for financial models",
icon="📊",
input_parameters=[
ParameterDefinition(
name="model_type",
type=ParameterType.ENUM,
required=True,
description="Financial model type",
options=["monte-carlo", "option-pricing", "risk-var", "portfolio-optimization", "credit-risk"]
),
ParameterDefinition(
name="parameters",
type=ParameterType.OBJECT,
required=True,
description="Model parameters"
),
ParameterDefinition(
name="num_simulations",
type=ParameterType.INTEGER,
required=True,
description="Number of Monte Carlo simulations",
default=10000,
min_value=1000,
max_value=10000000
),
ParameterDefinition(
name="time_steps",
type=ParameterType.INTEGER,
required=False,
description="Number of time steps",
default=252,
min_value=1,
max_value=10000
),
ParameterDefinition(
name="confidence_levels",
type=ParameterType.ARRAY,
required=False,
description="Confidence levels for VaR",
default=[0.95, 0.99],
items={"type": "number", "minimum": 0, "maximum": 1}
)
],
output_schema={
"type": "object",
"properties": {
"results": {"type": "array"},
"statistics": {"type": "object"},
"risk_metrics": {"type": "object"},
"confidence_intervals": {"type": "array"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3080"),
HardwareRequirement(component="vram", min_value=8, recommended=16, unit="GB"),
HardwareRequirement(component="cpu", min_value=8, recommended=32, unit="cores"),
HardwareRequirement(component="ram", min_value=16, recommended=64, unit="GB")
],
pricing=[
PricingTier(name="per_simulation", model=PricingModel.PER_UNIT, unit_price=0.00001, min_charge=0.1),
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=1, min_charge=1),
PricingTier(name="enterprise", model=PricingModel.PER_UNIT, unit_price=0.000005, min_charge=0.5)
],
capabilities=["monte-carlo", "var", "option-pricing", "portfolio", "risk-analysis"],
tags=["finance", "risk", "monte-carlo", "var", "options"],
max_concurrent=10,
timeout_seconds=3600
),
"physics_simulation": ServiceDefinition(
id="physics_simulation",
name="Physics Simulation",
category=ServiceCategory.SCIENTIFIC_COMPUTING,
description="Run particle physics and fluid dynamics simulations",
icon="⚛️",
input_parameters=[
ParameterDefinition(
name="simulation_type",
type=ParameterType.ENUM,
required=True,
description="Physics simulation type",
options=["particle-physics", "fluid-dynamics", "electromagnetics", "quantum", "astrophysics"]
),
ParameterDefinition(
name="solver",
type=ParameterType.ENUM,
required=True,
description="Simulation solver",
options=["geant4", "fluent", "comsol", "openfoam", "lammps", "gadget"]
),
ParameterDefinition(
name="geometry_file",
type=ParameterType.FILE,
required=False,
description="Geometry or mesh file"
),
ParameterDefinition(
name="initial_conditions",
type=ParameterType.OBJECT,
required=True,
description="Initial conditions and parameters"
),
ParameterDefinition(
name="simulation_time",
type=ParameterType.FLOAT,
required=True,
description="Simulation time",
min_value=0.001
),
ParameterDefinition(
name="particles",
type=ParameterType.INTEGER,
required=False,
description="Number of particles",
default=1000000,
min_value=1000,
max_value=100000000
)
],
output_schema={
"type": "object",
"properties": {
"results_url": {"type": "string"},
"data_arrays": {"type": "object"},
"visualizations": {"type": "array"},
"statistics": {"type": "object"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="nvidia", recommended="a100"),
HardwareRequirement(component="vram", min_value=16, recommended=40, unit="GB"),
HardwareRequirement(component="cpu", min_value=16, recommended=64, unit="cores"),
HardwareRequirement(component="ram", min_value=32, recommended=256, unit="GB"),
HardwareRequirement(component="storage", min_value=100, recommended=1000, unit="GB")
],
pricing=[
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=2, min_charge=2),
PricingTier(name="per_particle", model=PricingModel.PER_UNIT, unit_price=0.000001, min_charge=1),
PricingTier(name="hpc", model=PricingModel.PER_HOUR, unit_price=5, min_charge=5)
],
capabilities=["gpu-accelerated", "parallel", "mpi", "large-scale"],
tags=["physics", "simulation", "particle", "fluid", "cfd"],
max_concurrent=4,
timeout_seconds=86400
),
"bioinformatics": ServiceDefinition(
id="bioinformatics",
name="Bioinformatics Analysis",
category=ServiceCategory.SCIENTIFIC_COMPUTING,
description="DNA sequencing, protein folding, and genomic analysis",
icon="🧬",
input_parameters=[
ParameterDefinition(
name="analysis_type",
type=ParameterType.ENUM,
required=True,
description="Bioinformatics analysis type",
options=["dna-sequencing", "protein-folding", "alignment", "phylogeny", "variant-calling"]
),
ParameterDefinition(
name="sequence_file",
type=ParameterType.FILE,
required=True,
description="Input sequence file (FASTA, FASTQ, BAM, etc)"
),
ParameterDefinition(
name="reference_file",
type=ParameterType.FILE,
required=False,
description="Reference genome or protein structure"
),
ParameterDefinition(
name="algorithm",
type=ParameterType.ENUM,
required=True,
description="Analysis algorithm",
options=["blast", "bowtie", "bwa", "alphafold", "gatk", "clustal"]
),
ParameterDefinition(
name="parameters",
type=ParameterType.OBJECT,
required=False,
description="Algorithm-specific parameters"
)
],
output_schema={
"type": "object",
"properties": {
"results_file": {"type": "string"},
"alignment_file": {"type": "string"},
"annotations": {"type": "array"},
"statistics": {"type": "object"}
}
},
requirements=[
HardwareRequirement(component="gpu", min_value="nvidia", recommended="rtx-3090"),
HardwareRequirement(component="vram", min_value=8, recommended=24, unit="GB"),
HardwareRequirement(component="cpu", min_value=16, recommended=32, unit="cores"),
HardwareRequirement(component="ram", min_value=32, recommended=128, unit="GB"),
HardwareRequirement(component="storage", min_value=100, recommended=500, unit="GB")
],
pricing=[
PricingTier(name="per_mb", model=PricingModel.PER_UNIT, unit_price=0.001, min_charge=0.1),
PricingTier(name="per_hour", model=PricingModel.PER_HOUR, unit_price=1, min_charge=1),
PricingTier(name="protein_folding", model=PricingModel.PER_UNIT, unit_price=0.01, min_charge=0.5)
],
capabilities=["sequencing", "alignment", "folding", "annotation", "variant-calling"],
tags=["bioinformatics", "genomics", "proteomics", "dna", "sequencing"],
max_concurrent=5,
timeout_seconds=7200
)
}
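# --- Illustrative sketch (not part of the registry) ---------------------------
# HardwareRequirement mixes numeric minimums (vram, ram, cpu, storage) with
# string minimums ("nvidia", "100Mbps", "11.8"), so matching a node against a
# definition has to branch on the value type. The node_specs mapping and this
# helper are hypothetical; string requirements are treated as a naive exact
# match, with "any" always passing.
def meets_requirements(node_specs: dict, requirements: list) -> bool:
    for req in requirements:
        have = node_specs.get(req.component)
        if have is None:
            return False
        if isinstance(req.min_value, (int, float)):
            if have < req.min_value:
                return False
        elif req.min_value != "any" and str(have).lower() != str(req.min_value).lower():
            return False
    return True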

@@ -0,0 +1,380 @@
"""
Service schemas for common GPU workloads
"""
from typing import Any, Dict, List, Optional, Union
from enum import Enum
from pydantic import BaseModel, Field, validator
import re
class ServiceType(str, Enum):
"""Supported service types"""
WHISPER = "whisper"
STABLE_DIFFUSION = "stable_diffusion"
LLM_INFERENCE = "llm_inference"
FFMPEG = "ffmpeg"
BLENDER = "blender"
# Whisper Service Schemas
class WhisperModel(str, Enum):
"""Supported Whisper models"""
TINY = "tiny"
BASE = "base"
SMALL = "small"
MEDIUM = "medium"
LARGE = "large"
LARGE_V2 = "large-v2"
LARGE_V3 = "large-v3"
class WhisperLanguage(str, Enum):
"""Supported languages"""
AUTO = "auto"
EN = "en"
ES = "es"
FR = "fr"
DE = "de"
IT = "it"
PT = "pt"
RU = "ru"
JA = "ja"
KO = "ko"
ZH = "zh"
class WhisperTask(str, Enum):
"""Whisper task types"""
TRANSCRIBE = "transcribe"
TRANSLATE = "translate"
class WhisperRequest(BaseModel):
"""Whisper transcription request"""
audio_url: str = Field(..., description="URL of audio file to transcribe")
model: WhisperModel = Field(WhisperModel.BASE, description="Whisper model to use")
language: WhisperLanguage = Field(WhisperLanguage.AUTO, description="Source language")
task: WhisperTask = Field(WhisperTask.TRANSCRIBE, description="Task to perform")
temperature: float = Field(0.0, ge=0.0, le=1.0, description="Sampling temperature")
best_of: int = Field(5, ge=1, le=10, description="Number of candidates")
beam_size: int = Field(5, ge=1, le=10, description="Beam size for decoding")
patience: float = Field(1.0, ge=0.0, le=2.0, description="Beam search patience")
suppress_tokens: Optional[List[int]] = Field(None, description="Tokens to suppress")
initial_prompt: Optional[str] = Field(None, description="Initial prompt for context")
condition_on_previous_text: bool = Field(True, description="Condition on previous text")
fp16: bool = Field(True, description="Use FP16 for faster inference")
verbose: bool = Field(False, description="Include verbose output")
def get_constraints(self) -> Dict[str, Any]:
"""Get hardware constraints for this request"""
vram_requirements = {
WhisperModel.TINY: 1,
WhisperModel.BASE: 1,
WhisperModel.SMALL: 2,
WhisperModel.MEDIUM: 5,
WhisperModel.LARGE: 10,
WhisperModel.LARGE_V2: 10,
WhisperModel.LARGE_V3: 10,
}
return {
"models": ["whisper"],
"min_vram_gb": vram_requirements[self.model],
"gpu": "nvidia", # Whisper requires CUDA
}
# Stable Diffusion Service Schemas
class SDModel(str, Enum):
"""Supported Stable Diffusion models"""
SD_1_5 = "stable-diffusion-1.5"
SD_2_1 = "stable-diffusion-2.1"
SDXL = "stable-diffusion-xl"
SDXL_TURBO = "sdxl-turbo"
SDXL_REFINER = "sdxl-refiner"
class SDSize(str, Enum):
"""Standard image sizes"""
SQUARE_512 = "512x512"
PORTRAIT_512 = "512x768"
LANDSCAPE_512 = "768x512"
SQUARE_768 = "768x768"
PORTRAIT_768 = "768x1024"
LANDSCAPE_768 = "1024x768"
SQUARE_1024 = "1024x1024"
PORTRAIT_1024 = "1024x1536"
LANDSCAPE_1024 = "1536x1024"
class StableDiffusionRequest(BaseModel):
"""Stable Diffusion image generation request"""
prompt: str = Field(..., min_length=1, max_length=1000, description="Text prompt")
negative_prompt: Optional[str] = Field(None, max_length=1000, description="Negative prompt")
    model: SDModel = Field(SDModel.SD_1_5, description="Model to use")
size: SDSize = Field(SDSize.SQUARE_512, description="Image size")
num_images: int = Field(1, ge=1, le=4, description="Number of images to generate")
num_inference_steps: int = Field(20, ge=1, le=100, description="Number of inference steps")
guidance_scale: float = Field(7.5, ge=1.0, le=20.0, description="Guidance scale")
seed: Optional[Union[int, List[int]]] = Field(None, description="Random seed(s)")
scheduler: str = Field("DPMSolverMultistepScheduler", description="Scheduler to use")
enable_safety_checker: bool = Field(True, description="Enable safety checker")
lora: Optional[str] = Field(None, description="LoRA model to use")
lora_scale: float = Field(1.0, ge=0.0, le=2.0, description="LoRA strength")
@validator('seed')
def validate_seed(cls, v):
if v is not None and isinstance(v, list):
if len(v) > 4:
raise ValueError("Maximum 4 seeds allowed")
return v
def get_constraints(self) -> Dict[str, Any]:
"""Get hardware constraints for this request"""
vram_requirements = {
SDModel.SD_1_5: 4,
SDModel.SD_2_1: 4,
SDModel.SDXL: 8,
SDModel.SDXL_TURBO: 8,
SDModel.SDXL_REFINER: 8,
}
        size_map = {
            "512": 512,
            "768": 768,
            "1024": 1024,
            "1536": 1536,
        }
        # Use the larger dimension of the requested size, not the maximum across all sizes
        max_dim = max(size_map[d] for d in self.size.value.split('x'))
        min_vram = vram_requirements[self.model]
        if max_dim >= 1024:
            # Larger outputs need additional VRAM headroom
            min_vram = max(min_vram, 8)
        return {
            "models": ["stable-diffusion"],
            "min_vram_gb": min_vram,
            "gpu": "nvidia",  # SD requires CUDA
            "cuda": "11.8",  # Minimum CUDA version
        }
# LLM Inference Service Schemas
class LLMModel(str, Enum):
"""Supported LLM models"""
LLAMA_7B = "llama-7b"
LLAMA_13B = "llama-13b"
LLAMA_70B = "llama-70b"
MISTRAL_7B = "mistral-7b"
MIXTRAL_8X7B = "mixtral-8x7b"
CODELLAMA_7B = "codellama-7b"
CODELLAMA_13B = "codellama-13b"
CODELLAMA_34B = "codellama-34b"
class LLMRequest(BaseModel):
"""LLM inference request"""
model: LLMModel = Field(..., description="Model to use")
prompt: str = Field(..., min_length=1, max_length=10000, description="Input prompt")
max_tokens: int = Field(256, ge=1, le=4096, description="Maximum tokens to generate")
temperature: float = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
top_p: float = Field(0.9, ge=0.0, le=1.0, description="Top-p sampling")
top_k: int = Field(40, ge=0, le=100, description="Top-k sampling")
repetition_penalty: float = Field(1.1, ge=0.0, le=2.0, description="Repetition penalty")
stop_sequences: Optional[List[str]] = Field(None, description="Stop sequences")
stream: bool = Field(False, description="Stream response")
def get_constraints(self) -> Dict[str, Any]:
"""Get hardware constraints for this request"""
vram_requirements = {
LLMModel.LLAMA_7B: 8,
LLMModel.LLAMA_13B: 16,
LLMModel.LLAMA_70B: 64,
LLMModel.MISTRAL_7B: 8,
LLMModel.MIXTRAL_8X7B: 48,
LLMModel.CODELLAMA_7B: 8,
LLMModel.CODELLAMA_13B: 16,
LLMModel.CODELLAMA_34B: 32,
}
return {
"models": ["llm"],
"min_vram_gb": vram_requirements[self.model],
"gpu": "nvidia", # LLMs require CUDA
"cuda": "11.8",
}
# FFmpeg Service Schemas
class FFmpegCodec(str, Enum):
"""Supported video codecs"""
H264 = "h264"
H265 = "h265"
VP9 = "vp9"
AV1 = "av1"
class FFmpegPreset(str, Enum):
"""Encoding presets"""
ULTRAFAST = "ultrafast"
SUPERFAST = "superfast"
VERYFAST = "veryfast"
FASTER = "faster"
FAST = "fast"
MEDIUM = "medium"
SLOW = "slow"
SLOWER = "slower"
VERYSLOW = "veryslow"
class FFmpegRequest(BaseModel):
"""FFmpeg video processing request"""
input_url: str = Field(..., description="URL of input video")
output_format: str = Field("mp4", description="Output format")
codec: FFmpegCodec = Field(FFmpegCodec.H264, description="Video codec")
preset: FFmpegPreset = Field(FFmpegPreset.MEDIUM, description="Encoding preset")
crf: int = Field(23, ge=0, le=51, description="Constant rate factor")
resolution: Optional[str] = Field(None, regex=r"^\d+x\d+$", description="Output resolution (e.g., 1920x1080)")
bitrate: Optional[str] = Field(None, regex=r"^\d+[kM]?$", description="Target bitrate")
fps: Optional[int] = Field(None, ge=1, le=120, description="Output frame rate")
audio_codec: str = Field("aac", description="Audio codec")
audio_bitrate: str = Field("128k", description="Audio bitrate")
custom_args: Optional[List[str]] = Field(None, description="Custom FFmpeg arguments")
def get_constraints(self) -> Dict[str, Any]:
"""Get hardware constraints for this request"""
# NVENC support for H.264/H.265
if self.codec in [FFmpegCodec.H264, FFmpegCodec.H265]:
return {
"models": ["ffmpeg"],
"gpu": "nvidia", # NVENC requires NVIDIA
"min_vram_gb": 4,
}
else:
return {
"models": ["ffmpeg"],
"gpu": "any", # CPU encoding possible
}
# Blender Service Schemas
class BlenderEngine(str, Enum):
"""Blender render engines"""
CYCLES = "cycles"
EEVEE = "eevee"
EEVEE_NEXT = "eevee-next"
class BlenderFormat(str, Enum):
"""Output formats"""
PNG = "png"
JPG = "jpg"
EXR = "exr"
BMP = "bmp"
TIFF = "tiff"
class BlenderRequest(BaseModel):
"""Blender rendering request"""
blend_file_url: str = Field(..., description="URL of .blend file")
engine: BlenderEngine = Field(BlenderEngine.CYCLES, description="Render engine")
format: BlenderFormat = Field(BlenderFormat.PNG, description="Output format")
resolution_x: int = Field(1920, ge=1, le=65536, description="Image width")
resolution_y: int = Field(1080, ge=1, le=65536, description="Image height")
resolution_percentage: int = Field(100, ge=1, le=100, description="Resolution scale")
samples: int = Field(128, ge=1, le=10000, description="Samples (Cycles only)")
frame_start: int = Field(1, ge=1, description="Start frame")
frame_end: int = Field(1, ge=1, description="End frame")
frame_step: int = Field(1, ge=1, description="Frame step")
denoise: bool = Field(True, description="Enable denoising")
transparent: bool = Field(False, description="Transparent background")
custom_args: Optional[List[str]] = Field(None, description="Custom Blender arguments")
@validator('frame_end')
def validate_frame_range(cls, v, values):
if 'frame_start' in values and v < values['frame_start']:
raise ValueError("frame_end must be >= frame_start")
return v
def get_constraints(self) -> Dict[str, Any]:
"""Get hardware constraints for this request"""
# Calculate VRAM based on resolution and samples
pixel_count = self.resolution_x * self.resolution_y
samples_multiplier = 1 if self.engine == BlenderEngine.EEVEE else self.samples / 100
estimated_vram = int((pixel_count * samples_multiplier) / (1024 * 1024))
return {
"models": ["blender"],
"min_vram_gb": max(4, estimated_vram),
"gpu": "nvidia" if self.engine == BlenderEngine.CYCLES else "any",
}
# Unified Service Request
class ServiceRequest(BaseModel):
"""Unified service request wrapper"""
service_type: ServiceType = Field(..., description="Type of service")
request_data: Dict[str, Any] = Field(..., description="Service-specific request data")
def get_service_request(self) -> Union[
WhisperRequest,
StableDiffusionRequest,
LLMRequest,
FFmpegRequest,
BlenderRequest
]:
"""Parse and return typed service request"""
service_classes = {
ServiceType.WHISPER: WhisperRequest,
ServiceType.STABLE_DIFFUSION: StableDiffusionRequest,
ServiceType.LLM_INFERENCE: LLMRequest,
ServiceType.FFMPEG: FFmpegRequest,
ServiceType.BLENDER: BlenderRequest,
}
service_class = service_classes[self.service_type]
return service_class(**self.request_data)
# Service Response Schemas
class ServiceResponse(BaseModel):
"""Base service response"""
job_id: str = Field(..., description="Job ID")
service_type: ServiceType = Field(..., description="Service type")
status: str = Field(..., description="Job status")
estimated_completion: Optional[str] = Field(None, description="Estimated completion time")
class WhisperResponse(BaseModel):
"""Whisper transcription response"""
text: str = Field(..., description="Transcribed text")
language: str = Field(..., description="Detected language")
segments: Optional[List[Dict[str, Any]]] = Field(None, description="Transcription segments")
class StableDiffusionResponse(BaseModel):
"""Stable Diffusion image generation response"""
images: List[str] = Field(..., description="Generated image URLs")
parameters: Dict[str, Any] = Field(..., description="Generation parameters")
nsfw_content_detected: List[bool] = Field(..., description="NSFW detection results")
class LLMResponse(BaseModel):
"""LLM inference response"""
text: str = Field(..., description="Generated text")
finish_reason: str = Field(..., description="Reason for generation stop")
tokens_used: int = Field(..., description="Number of tokens used")
class FFmpegResponse(BaseModel):
"""FFmpeg processing response"""
output_url: str = Field(..., description="URL of processed video")
metadata: Dict[str, Any] = Field(..., description="Video metadata")
duration: float = Field(..., description="Video duration")
class BlenderResponse(BaseModel):
"""Blender rendering response"""
images: List[str] = Field(..., description="Rendered image URLs")
metadata: Dict[str, Any] = Field(..., description="Render metadata")
render_time: float = Field(..., description="Render time in seconds")
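# --- Illustrative usage (sketch) -----------------------------------------------
# The unified wrapper above is meant to be consumed in two steps: the caller
# submits a service_type plus raw request_data, get_service_request() re-validates
# it against the matching Pydantic model, and get_constraints() then yields the
# hardware constraints used for worker matching. The example values below are
# hypothetical.
if __name__ == "__main__":
    request = ServiceRequest(
        service_type=ServiceType.WHISPER,
        request_data={
            "audio_url": "https://example.com/audio.wav",
            "model": "small",
            "task": "transcribe",
        },
    )
    typed = request.get_service_request()  # -> WhisperRequest
    print(typed.get_constraints())  # {'models': ['whisper'], 'min_vram_gb': 2, 'gpu': 'nvidia'}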

@@ -0,0 +1,428 @@
"""
Repository layer for confidential transactions
"""
from typing import Optional, List, Dict, Any
from datetime import datetime
from uuid import UUID
import json
from base64 import b64encode, b64decode
from sqlalchemy import select, update, delete, and_, or_, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from ..models.confidential import (
ConfidentialTransactionDB,
ParticipantKeyDB,
ConfidentialAccessLogDB,
KeyRotationLogDB,
AuditAuthorizationDB
)
from ..models import (
ConfidentialTransaction,
KeyPair,
ConfidentialAccessLog,
KeyRotationLog,
AuditAuthorization
)
from ..database import get_async_session
class ConfidentialTransactionRepository:
"""Repository for confidential transaction operations"""
async def create(
self,
session: AsyncSession,
transaction: ConfidentialTransaction
) -> ConfidentialTransactionDB:
"""Create a new confidential transaction"""
db_transaction = ConfidentialTransactionDB(
transaction_id=transaction.transaction_id,
job_id=transaction.job_id,
status=transaction.status,
confidential=transaction.confidential,
algorithm=transaction.algorithm,
encrypted_data=b64decode(transaction.encrypted_data) if transaction.encrypted_data else None,
encrypted_keys=transaction.encrypted_keys,
participants=transaction.participants,
access_policies=transaction.access_policies,
created_by=transaction.participants[0] if transaction.participants else None
)
session.add(db_transaction)
await session.commit()
await session.refresh(db_transaction)
return db_transaction
async def get_by_id(
self,
session: AsyncSession,
transaction_id: str
) -> Optional[ConfidentialTransactionDB]:
"""Get transaction by ID"""
stmt = select(ConfidentialTransactionDB).where(
ConfidentialTransactionDB.transaction_id == transaction_id
)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def get_by_job_id(
self,
session: AsyncSession,
job_id: str
) -> Optional[ConfidentialTransactionDB]:
"""Get transaction by job ID"""
stmt = select(ConfidentialTransactionDB).where(
ConfidentialTransactionDB.job_id == job_id
)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def list_by_participant(
self,
session: AsyncSession,
participant_id: str,
limit: int = 100,
offset: int = 0
) -> List[ConfidentialTransactionDB]:
"""List transactions for a participant"""
stmt = select(ConfidentialTransactionDB).where(
ConfidentialTransactionDB.participants.contains([participant_id])
).offset(offset).limit(limit)
result = await session.execute(stmt)
return result.scalars().all()
async def update_status(
self,
session: AsyncSession,
transaction_id: str,
status: str
) -> bool:
"""Update transaction status"""
stmt = update(ConfidentialTransactionDB).where(
ConfidentialTransactionDB.transaction_id == transaction_id
).values(status=status)
result = await session.execute(stmt)
await session.commit()
return result.rowcount > 0
async def delete(
self,
session: AsyncSession,
transaction_id: str
) -> bool:
"""Delete a transaction"""
stmt = delete(ConfidentialTransactionDB).where(
ConfidentialTransactionDB.transaction_id == transaction_id
)
result = await session.execute(stmt)
await session.commit()
return result.rowcount > 0
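# --- Illustrative usage (sketch, not part of the repository layer) ------------
# Typical call pattern for the repository above, assuming the caller already
# holds an AsyncSession from the application's session factory. The job_id
# value and the "settled" status string are hypothetical.
async def _example_settle(session: AsyncSession) -> Optional[ConfidentialTransactionDB]:
    repo = ConfidentialTransactionRepository()
    tx = await repo.get_by_job_id(session, job_id="job-123")
    if tx is not None:
        await repo.update_status(session, tx.transaction_id, "settled")
    return tx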
class ParticipantKeyRepository:
"""Repository for participant key operations"""
async def create(
self,
session: AsyncSession,
key_pair: KeyPair
) -> ParticipantKeyDB:
"""Store a new key pair"""
# In production, private_key should be encrypted with master key
db_key = ParticipantKeyDB(
participant_id=key_pair.participant_id,
encrypted_private_key=key_pair.private_key,
public_key=key_pair.public_key,
algorithm=key_pair.algorithm,
version=key_pair.version,
active=True
)
session.add(db_key)
await session.commit()
await session.refresh(db_key)
return db_key
async def get_by_participant(
self,
session: AsyncSession,
participant_id: str,
active_only: bool = True
) -> Optional[ParticipantKeyDB]:
"""Get key pair for participant"""
stmt = select(ParticipantKeyDB).where(
ParticipantKeyDB.participant_id == participant_id
)
if active_only:
stmt = stmt.where(ParticipantKeyDB.active == True)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def update_active(
self,
session: AsyncSession,
participant_id: str,
active: bool,
reason: Optional[str] = None
) -> bool:
"""Update key active status"""
stmt = update(ParticipantKeyDB).where(
ParticipantKeyDB.participant_id == participant_id
).values(
active=active,
revoked_at=datetime.utcnow() if not active else None,
revoke_reason=reason
)
result = await session.execute(stmt)
await session.commit()
return result.rowcount > 0
async def rotate(
self,
session: AsyncSession,
participant_id: str,
new_key_pair: KeyPair
) -> ParticipantKeyDB:
"""Rotate to new key pair"""
# Deactivate old key
await self.update_active(session, participant_id, False, "rotation")
# Store new key
return await self.create(session, new_key_pair)
async def list_active(
self,
session: AsyncSession,
limit: int = 100,
offset: int = 0
) -> List[ParticipantKeyDB]:
"""List active keys"""
stmt = select(ParticipantKeyDB).where(
ParticipantKeyDB.active == True
).offset(offset).limit(limit)
result = await session.execute(stmt)
return result.scalars().all()
class AccessLogRepository:
"""Repository for access log operations"""
async def create(
self,
session: AsyncSession,
log: ConfidentialAccessLog
) -> ConfidentialAccessLogDB:
"""Create access log entry"""
db_log = ConfidentialAccessLogDB(
transaction_id=log.transaction_id,
participant_id=log.participant_id,
purpose=log.purpose,
action=log.action,
resource=log.resource,
outcome=log.outcome,
details=log.details,
data_accessed=log.data_accessed,
ip_address=log.ip_address,
user_agent=log.user_agent,
authorization_id=log.authorized_by,
signature=log.signature
)
session.add(db_log)
await session.commit()
await session.refresh(db_log)
return db_log
async def query(
self,
session: AsyncSession,
transaction_id: Optional[str] = None,
participant_id: Optional[str] = None,
purpose: Optional[str] = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
limit: int = 100,
offset: int = 0
) -> List[ConfidentialAccessLogDB]:
"""Query access logs"""
stmt = select(ConfidentialAccessLogDB)
# Build filters
filters = []
if transaction_id:
filters.append(ConfidentialAccessLogDB.transaction_id == transaction_id)
if participant_id:
filters.append(ConfidentialAccessLogDB.participant_id == participant_id)
if purpose:
filters.append(ConfidentialAccessLogDB.purpose == purpose)
if start_time:
filters.append(ConfidentialAccessLogDB.timestamp >= start_time)
if end_time:
filters.append(ConfidentialAccessLogDB.timestamp <= end_time)
if filters:
stmt = stmt.where(and_(*filters))
# Order by timestamp descending
stmt = stmt.order_by(ConfidentialAccessLogDB.timestamp.desc())
stmt = stmt.offset(offset).limit(limit)
result = await session.execute(stmt)
return result.scalars().all()
async def count(
self,
session: AsyncSession,
transaction_id: Optional[str] = None,
participant_id: Optional[str] = None,
purpose: Optional[str] = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None
) -> int:
"""Count access logs matching criteria"""
stmt = select(ConfidentialAccessLogDB)
# Build filters
filters = []
if transaction_id:
filters.append(ConfidentialAccessLogDB.transaction_id == transaction_id)
if participant_id:
filters.append(ConfidentialAccessLogDB.participant_id == participant_id)
if purpose:
filters.append(ConfidentialAccessLogDB.purpose == purpose)
if start_time:
filters.append(ConfidentialAccessLogDB.timestamp >= start_time)
if end_time:
filters.append(ConfidentialAccessLogDB.timestamp <= end_time)
if filters:
stmt = stmt.where(and_(*filters))
result = await session.execute(stmt)
return len(result.all())
class KeyRotationRepository:
"""Repository for key rotation logs"""
async def create(
self,
session: AsyncSession,
log: KeyRotationLog
) -> KeyRotationLogDB:
"""Create key rotation log"""
db_log = KeyRotationLogDB(
participant_id=log.participant_id,
old_version=log.old_version,
new_version=log.new_version,
rotated_at=log.rotated_at,
reason=log.reason
)
session.add(db_log)
await session.commit()
await session.refresh(db_log)
return db_log
async def list_by_participant(
self,
session: AsyncSession,
participant_id: str,
limit: int = 50
) -> List[KeyRotationLogDB]:
"""List rotation logs for participant"""
stmt = select(KeyRotationLogDB).where(
KeyRotationLogDB.participant_id == participant_id
).order_by(KeyRotationLogDB.rotated_at.desc()).limit(limit)
result = await session.execute(stmt)
return result.scalars().all()
class AuditAuthorizationRepository:
"""Repository for audit authorizations"""
async def create(
self,
session: AsyncSession,
auth: AuditAuthorization
) -> AuditAuthorizationDB:
"""Create audit authorization"""
db_auth = AuditAuthorizationDB(
issuer=auth.issuer,
subject=auth.subject,
purpose=auth.purpose,
created_at=auth.created_at,
expires_at=auth.expires_at,
signature=auth.signature,
metadata=auth.__dict__
)
session.add(db_auth)
await session.commit()
await session.refresh(db_auth)
return db_auth
async def get_valid(
self,
session: AsyncSession,
authorization_id: str
) -> Optional[AuditAuthorizationDB]:
"""Get valid authorization"""
stmt = select(AuditAuthorizationDB).where(
and_(
AuditAuthorizationDB.id == authorization_id,
AuditAuthorizationDB.active == True,
AuditAuthorizationDB.expires_at > datetime.utcnow()
)
)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def revoke(
self,
session: AsyncSession,
authorization_id: str
) -> bool:
"""Revoke authorization"""
stmt = update(AuditAuthorizationDB).where(
AuditAuthorizationDB.id == authorization_id
).values(active=False, revoked_at=datetime.utcnow())
result = await session.execute(stmt)
await session.commit()
return result.rowcount > 0
async def cleanup_expired(
self,
session: AsyncSession
) -> int:
"""Clean up expired authorizations"""
stmt = update(AuditAuthorizationDB).where(
AuditAuthorizationDB.expires_at < datetime.utcnow()
).values(active=False)
result = await session.execute(stmt)
await session.commit()
return result.rowcount
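
Note: the repositories above expect an externally managed `AsyncSession`. A minimal usage sketch, assuming an async engine bound to the project's database (the DSN and the importability of `AccessLogRepository` from this module are placeholders, not taken from this diff):

import asyncio
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker

# Placeholder DSN; the real connection settings live elsewhere in the project.
engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/aitbc")
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)

async def main() -> None:
    repo = AccessLogRepository()
    async with SessionLocal() as session:
        # Fetch the 20 most recent audit-purpose accesses for one transaction
        logs = await repo.query(session, transaction_id="ctx-123", purpose="audit", limit=20)
        total = await repo.count(session, transaction_id="ctx-123", purpose="audit")
        print(total, [entry.participant_id for entry in logs])

asyncio.run(main())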

View File

@ -5,5 +5,7 @@ from .miner import router as miner
from .admin import router as admin
from .marketplace import router as marketplace
from .explorer import router as explorer
from .services import router as services
from .registry import router as registry
__all__ = ["client", "miner", "admin", "marketplace", "explorer"]
__all__ = ["client", "miner", "admin", "marketplace", "explorer", "services", "registry"]

View File

@ -0,0 +1,423 @@
"""
API endpoints for confidential transactions
"""
from typing import Optional, List
from datetime import datetime
from fastapi import APIRouter, HTTPException, Depends, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import json
from slowapi import Limiter
from slowapi.util import get_remote_address
from ..models import (
ConfidentialTransaction,
ConfidentialTransactionCreate,
ConfidentialTransactionView,
ConfidentialAccessRequest,
ConfidentialAccessResponse,
KeyRegistrationRequest,
KeyRegistrationResponse,
AccessLogQuery,
AccessLogResponse
)
from ..services.encryption import EncryptionService, EncryptedData
from ..services.key_management import KeyManager, KeyManagementError
from ..services.access_control import AccessController
from ..auth import get_api_key
from ..logging import get_logger
logger = get_logger(__name__)
# Initialize router and security
router = APIRouter(prefix="/confidential", tags=["confidential"])
security = HTTPBearer()
limiter = Limiter(key_func=get_remote_address)
# Global instances (in production, inject via DI)
encryption_service: Optional[EncryptionService] = None
key_manager: Optional[KeyManager] = None
access_controller: Optional[AccessController] = None
def get_encryption_service() -> EncryptionService:
    """Get encryption service instance"""
    global encryption_service
    if encryption_service is None:
        # Reuse the shared key manager so encryption and key management operate
        # on the same key store (avoids creating a second, shadowed instance)
        encryption_service = EncryptionService(get_key_manager())
    return encryption_service
def get_key_manager() -> KeyManager:
"""Get key manager instance"""
global key_manager
if key_manager is None:
from ..services.key_management import FileKeyStorage
key_storage = FileKeyStorage("/tmp/aitbc_keys")
key_manager = KeyManager(key_storage)
return key_manager
def get_access_controller() -> AccessController:
"""Get access controller instance"""
global access_controller
if access_controller is None:
from ..services.access_control import PolicyStore
policy_store = PolicyStore()
access_controller = AccessController(policy_store)
return access_controller
@router.post("/transactions", response_model=ConfidentialTransactionView)
async def create_confidential_transaction(
request: ConfidentialTransactionCreate,
api_key: str = Depends(get_api_key)
):
"""Create a new confidential transaction with optional encryption"""
try:
# Generate transaction ID
transaction_id = f"ctx-{datetime.utcnow().timestamp()}"
# Create base transaction
transaction = ConfidentialTransaction(
transaction_id=transaction_id,
job_id=request.job_id,
timestamp=datetime.utcnow(),
status="created",
amount=request.amount,
pricing=request.pricing,
settlement_details=request.settlement_details,
confidential=request.confidential,
participants=request.participants,
access_policies=request.access_policies
)
# Encrypt sensitive data if requested
if request.confidential and request.participants:
# Prepare data for encryption
sensitive_data = {
"amount": request.amount,
"pricing": request.pricing,
"settlement_details": request.settlement_details
}
# Remove None values
sensitive_data = {k: v for k, v in sensitive_data.items() if v is not None}
if sensitive_data:
# Encrypt data
enc_service = get_encryption_service()
encrypted = enc_service.encrypt(
data=sensitive_data,
participants=request.participants,
include_audit=True
)
# Update transaction with encrypted data
transaction.encrypted_data = encrypted.to_dict()["ciphertext"]
transaction.encrypted_keys = encrypted.to_dict()["encrypted_keys"]
transaction.algorithm = encrypted.algorithm
# Clear plaintext fields
transaction.amount = None
transaction.pricing = None
transaction.settlement_details = None
# Store transaction (in production, save to database)
logger.info(f"Created confidential transaction: {transaction_id}")
# Return view
return ConfidentialTransactionView(
transaction_id=transaction.transaction_id,
job_id=transaction.job_id,
timestamp=transaction.timestamp,
status=transaction.status,
amount=transaction.amount, # Will be None if encrypted
pricing=transaction.pricing,
settlement_details=transaction.settlement_details,
confidential=transaction.confidential,
participants=transaction.participants,
has_encrypted_data=transaction.encrypted_data is not None
)
except Exception as e:
logger.error(f"Failed to create confidential transaction: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/transactions/{transaction_id}", response_model=ConfidentialTransactionView)
async def get_confidential_transaction(
transaction_id: str,
api_key: str = Depends(get_api_key)
):
"""Get confidential transaction metadata (without decrypting sensitive data)"""
try:
# Retrieve transaction (in production, query from database)
# For now, return error as we don't have storage
raise HTTPException(status_code=404, detail="Transaction not found")
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to get transaction {transaction_id}: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/transactions/{transaction_id}/access", response_model=ConfidentialAccessResponse)
@limiter.limit("10/minute") # Rate limit decryption requests
async def access_confidential_data(
    request: Request,  # slowapi's limiter requires the raw starlette Request in the signature
    access_request: ConfidentialAccessRequest,
    transaction_id: str,
    api_key: str = Depends(get_api_key)
):
"""Request access to decrypt confidential transaction data"""
try:
# Validate request
        if access_request.transaction_id != transaction_id:
raise HTTPException(status_code=400, detail="Transaction ID mismatch")
# Get transaction (in production, query from database)
# For now, create mock transaction
transaction = ConfidentialTransaction(
transaction_id=transaction_id,
job_id="test-job",
timestamp=datetime.utcnow(),
status="completed",
confidential=True,
participants=["client-456", "miner-789"]
)
if not transaction.confidential:
raise HTTPException(status_code=400, detail="Transaction is not confidential")
# Check access authorization
acc_controller = get_access_controller()
        if not acc_controller.verify_access(access_request):
raise HTTPException(status_code=403, detail="Access denied")
# Decrypt data
enc_service = get_encryption_service()
# Reconstruct encrypted data
if not transaction.encrypted_data or not transaction.encrypted_keys:
raise HTTPException(status_code=404, detail="No encrypted data found")
encrypted_data = EncryptedData.from_dict({
"ciphertext": transaction.encrypted_data,
"encrypted_keys": transaction.encrypted_keys,
"algorithm": transaction.algorithm or "AES-256-GCM+X25519"
})
# Decrypt for requester
try:
decrypted_data = enc_service.decrypt(
encrypted_data=encrypted_data,
                participant_id=access_request.requester,
                purpose=access_request.purpose
)
return ConfidentialAccessResponse(
success=True,
data=decrypted_data,
access_id=f"access-{datetime.utcnow().timestamp()}"
)
except Exception as e:
logger.error(f"Decryption failed: {e}")
return ConfidentialAccessResponse(
success=False,
error=str(e)
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to access confidential data: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/transactions/{transaction_id}/audit", response_model=ConfidentialAccessResponse)
async def audit_access_confidential_data(
transaction_id: str,
authorization: str,
purpose: str = "compliance",
api_key: str = Depends(get_api_key)
):
"""Audit access to confidential transaction data"""
try:
# Get transaction
transaction = ConfidentialTransaction(
transaction_id=transaction_id,
job_id="test-job",
timestamp=datetime.utcnow(),
status="completed",
confidential=True
)
if not transaction.confidential:
raise HTTPException(status_code=400, detail="Transaction is not confidential")
# Decrypt with audit key
enc_service = get_encryption_service()
if not transaction.encrypted_data or not transaction.encrypted_keys:
raise HTTPException(status_code=404, detail="No encrypted data found")
encrypted_data = EncryptedData.from_dict({
"ciphertext": transaction.encrypted_data,
"encrypted_keys": transaction.encrypted_keys,
"algorithm": transaction.algorithm or "AES-256-GCM+X25519"
})
# Decrypt for audit
try:
decrypted_data = enc_service.audit_decrypt(
encrypted_data=encrypted_data,
audit_authorization=authorization,
purpose=purpose
)
return ConfidentialAccessResponse(
success=True,
data=decrypted_data,
access_id=f"audit-{datetime.utcnow().timestamp()}"
)
except Exception as e:
logger.error(f"Audit decryption failed: {e}")
return ConfidentialAccessResponse(
success=False,
error=str(e)
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed audit access: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/keys/register", response_model=KeyRegistrationResponse)
async def register_encryption_key(
request: KeyRegistrationRequest,
api_key: str = Depends(get_api_key)
):
"""Register public key for confidential transactions"""
try:
# Get key manager
km = get_key_manager()
# Check if participant already has keys
try:
existing_key = km.get_public_key(request.participant_id)
if existing_key:
# Key exists, return version
return KeyRegistrationResponse(
success=True,
participant_id=request.participant_id,
key_version=1, # Would get from storage
registered_at=datetime.utcnow(),
error=None
)
        except Exception:
            # Key doesn't exist yet; fall through and generate a new pair
            pass
# Generate new key pair
key_pair = await km.generate_key_pair(request.participant_id)
return KeyRegistrationResponse(
success=True,
participant_id=request.participant_id,
key_version=key_pair.version,
registered_at=key_pair.created_at,
error=None
)
except KeyManagementError as e:
logger.error(f"Key registration failed: {e}")
return KeyRegistrationResponse(
success=False,
participant_id=request.participant_id,
key_version=0,
registered_at=datetime.utcnow(),
error=str(e)
)
except Exception as e:
logger.error(f"Failed to register key: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/keys/rotate")
async def rotate_encryption_key(
participant_id: str,
api_key: str = Depends(get_api_key)
):
"""Rotate encryption keys for participant"""
try:
km = get_key_manager()
# Rotate keys
new_key_pair = await km.rotate_keys(participant_id)
return {
"success": True,
"participant_id": participant_id,
"new_version": new_key_pair.version,
"rotated_at": new_key_pair.created_at
}
except KeyManagementError as e:
logger.error(f"Key rotation failed: {e}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Failed to rotate keys: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/access/logs", response_model=AccessLogResponse)
async def get_access_logs(
query: AccessLogQuery = Depends(),
api_key: str = Depends(get_api_key)
):
"""Get access logs for confidential transactions"""
try:
# Query logs (in production, query from database)
# For now, return empty response
return AccessLogResponse(
logs=[],
total_count=0,
has_more=False
)
except Exception as e:
logger.error(f"Failed to get access logs: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/status")
async def get_confidential_status(
api_key: str = Depends(get_api_key)
):
"""Get status of confidential transaction system"""
try:
km = get_key_manager()
enc_service = get_encryption_service()
# Get system status
participants = await km.list_participants()
return {
"enabled": True,
"algorithm": "AES-256-GCM+X25519",
"participants_count": len(participants),
"transactions_count": 0, # Would query from database
"audit_enabled": True
}
except Exception as e:
logger.error(f"Failed to get status: {e}")
raise HTTPException(status_code=500, detail=str(e))
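
A hedged end-to-end sketch of exercising these endpoints with `httpx`; the base URL, the API-key header name, and any request fields beyond those referenced above are assumptions:

import httpx

BASE = "http://localhost:8000/v1/confidential"   # assumed mount point
HEADERS = {"X-API-Key": "dev-key"}                # header name is an assumption

with httpx.Client(base_url=BASE, headers=HEADERS) as client:
    # 1. Register encryption keys for both participants
    for pid in ("client-456", "miner-789"):
        client.post("/keys/register", json={"participant_id": pid}).raise_for_status()

    # 2. Create a confidential transaction; amount is encrypted server-side
    tx = client.post("/transactions", json={
        "job_id": "job-001",
        "amount": 12.5,
        "confidential": True,
        "participants": ["client-456", "miner-789"],
    }).json()

    # 3. Request decryption as a participant (rate limited to 10/minute)
    resp = client.post(f"/transactions/{tx['transaction_id']}/access", json={
        "transaction_id": tx["transaction_id"],
        "requester": "client-456",
        "purpose": "settlement",
    })
    print(resp.json())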

View File

@ -6,6 +6,7 @@ from fastapi import status as http_status
from ..models import MarketplaceBidRequest, MarketplaceOfferView, MarketplaceStatsView
from ..services import MarketplaceService
from ..storage import SessionDep
from ..metrics import marketplace_requests_total, marketplace_errors_total
router = APIRouter(tags=["marketplace"])
@ -26,11 +27,16 @@ async def list_marketplace_offers(
limit: int = Query(default=100, ge=1, le=500),
offset: int = Query(default=0, ge=0),
) -> list[MarketplaceOfferView]:
marketplace_requests_total.labels(endpoint="/marketplace/offers", method="GET").inc()
service = _get_service(session)
try:
return service.list_offers(status=status_filter, limit=limit, offset=offset)
except ValueError:
marketplace_errors_total.labels(endpoint="/marketplace/offers", method="GET", error_type="invalid_request").inc()
raise HTTPException(status_code=http_status.HTTP_400_BAD_REQUEST, detail="invalid status filter") from None
except Exception:
marketplace_errors_total.labels(endpoint="/marketplace/offers", method="GET", error_type="internal").inc()
raise
@router.get(
@ -39,8 +45,13 @@ async def list_marketplace_offers(
summary="Get marketplace summary statistics",
)
async def get_marketplace_stats(*, session: SessionDep) -> MarketplaceStatsView:
marketplace_requests_total.labels(endpoint="/marketplace/stats", method="GET").inc()
service = _get_service(session)
return service.get_stats()
try:
return service.get_stats()
except Exception:
marketplace_errors_total.labels(endpoint="/marketplace/stats", method="GET", error_type="internal").inc()
raise
@router.post(
@ -52,6 +63,14 @@ async def submit_marketplace_bid(
payload: MarketplaceBidRequest,
session: SessionDep,
) -> dict[str, str]:
marketplace_requests_total.labels(endpoint="/marketplace/bids", method="POST").inc()
service = _get_service(session)
bid = service.create_bid(payload)
return {"id": bid.id}
try:
bid = service.create_bid(payload)
return {"id": bid.id}
except ValueError:
marketplace_errors_total.labels(endpoint="/marketplace/bids", method="POST", error_type="invalid_request").inc()
raise HTTPException(status_code=http_status.HTTP_400_BAD_REQUEST, detail="invalid bid data") from None
except Exception:
marketplace_errors_total.labels(endpoint="/marketplace/bids", method="POST", error_type="internal").inc()
raise
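
The `marketplace_requests_total` and `marketplace_errors_total` counters are imported from `..metrics`, which is not part of this hunk; one plausible definition with `prometheus_client`, matching the label sets used above:

from prometheus_client import Counter

marketplace_requests_total = Counter(
    "marketplace_requests_total",
    "Marketplace API requests",
    ["endpoint", "method"],
)

marketplace_errors_total = Counter(
    "marketplace_errors_total",
    "Marketplace API errors",
    ["endpoint", "method", "error_type"],
)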

View File

@ -0,0 +1,303 @@
"""
Service registry router for dynamic service management
"""
from typing import Dict, List, Any, Optional
from fastapi import APIRouter, HTTPException, status
from ..models.registry import (
ServiceRegistry,
ServiceDefinition,
ServiceCategory
)
from ..models.registry_media import MEDIA_PROCESSING_SERVICES
from ..models.registry_scientific import SCIENTIFIC_COMPUTING_SERVICES
from ..models.registry_data import DATA_ANALYTICS_SERVICES
from ..models.registry_gaming import GAMING_SERVICES
from ..models.registry_devtools import DEVTOOLS_SERVICES
from ..models.registry import AI_ML_SERVICES
router = APIRouter(prefix="/registry", tags=["service-registry"])
# Initialize service registry with all services
def create_service_registry() -> ServiceRegistry:
"""Create and populate the service registry"""
all_services = {}
# Add all service categories
all_services.update(AI_ML_SERVICES)
all_services.update(MEDIA_PROCESSING_SERVICES)
all_services.update(SCIENTIFIC_COMPUTING_SERVICES)
all_services.update(DATA_ANALYTICS_SERVICES)
all_services.update(GAMING_SERVICES)
all_services.update(DEVTOOLS_SERVICES)
return ServiceRegistry(
version="1.0.0",
services=all_services
)
# Global registry instance
service_registry = create_service_registry()
@router.get("/", response_model=ServiceRegistry)
async def get_registry() -> ServiceRegistry:
"""Get the complete service registry"""
return service_registry
@router.get("/services", response_model=List[ServiceDefinition])
async def list_services(
category: Optional[ServiceCategory] = None,
search: Optional[str] = None
) -> List[ServiceDefinition]:
"""List all available services with optional filtering"""
services = list(service_registry.services.values())
# Filter by category
if category:
services = [s for s in services if s.category == category]
# Search by name, description, or tags
if search:
search = search.lower()
services = [
s for s in services
if (search in s.name.lower() or
search in s.description.lower() or
any(search in tag.lower() for tag in s.tags))
]
return services
@router.get("/services/{service_id}", response_model=ServiceDefinition)
async def get_service(service_id: str) -> ServiceDefinition:
"""Get a specific service definition"""
service = service_registry.get_service(service_id)
if not service:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Service {service_id} not found"
)
return service
@router.get("/categories", response_model=List[Dict[str, Any]])
async def list_categories() -> List[Dict[str, Any]]:
"""List all service categories with counts"""
category_counts = {}
for service in service_registry.services.values():
category = service.category.value
if category not in category_counts:
category_counts[category] = 0
category_counts[category] += 1
return [
{"category": cat, "count": count}
for cat, count in category_counts.items()
]
@router.get("/categories/{category}", response_model=List[ServiceDefinition])
async def get_services_by_category(category: ServiceCategory) -> List[ServiceDefinition]:
"""Get all services in a specific category"""
return service_registry.get_services_by_category(category)
@router.get("/services/{service_id}/schema")
async def get_service_schema(service_id: str) -> Dict[str, Any]:
"""Get JSON schema for a service's input parameters"""
service = service_registry.get_service(service_id)
if not service:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Service {service_id} not found"
)
# Convert input parameters to JSON schema
properties = {}
required = []
for param in service.input_parameters:
prop = {
"type": param.type.value,
"description": param.description
}
if param.default is not None:
prop["default"] = param.default
if param.min_value is not None:
prop["minimum"] = param.min_value
if param.max_value is not None:
prop["maximum"] = param.max_value
if param.options:
prop["enum"] = param.options
if param.validation:
prop.update(param.validation)
properties[param.name] = prop
if param.required:
required.append(param.name)
return {
"type": "object",
"properties": properties,
"required": required
}
@router.get("/services/{service_id}/requirements")
async def get_service_requirements(service_id: str) -> Dict[str, Any]:
"""Get hardware requirements for a service"""
service = service_registry.get_service(service_id)
if not service:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Service {service_id} not found"
)
return {
"requirements": [
{
"component": req.component,
"minimum": req.min_value,
"recommended": req.recommended,
"unit": req.unit
}
for req in service.requirements
]
}
@router.get("/services/{service_id}/pricing")
async def get_service_pricing(service_id: str) -> Dict[str, Any]:
"""Get pricing information for a service"""
service = service_registry.get_service(service_id)
if not service:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Service {service_id} not found"
)
return {
"pricing": [
{
"tier": tier.name,
"model": tier.model.value,
"unit_price": tier.unit_price,
"min_charge": tier.min_charge,
"currency": tier.currency,
"description": tier.description
}
for tier in service.pricing
]
}
@router.post("/services/validate")
async def validate_service_request(
service_id: str,
request_data: Dict[str, Any]
) -> Dict[str, Any]:
"""Validate a service request against the service schema"""
service = service_registry.get_service(service_id)
if not service:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Service {service_id} not found"
)
# Validate request data
validation_result = {
"valid": True,
"errors": [],
"warnings": []
}
# Check required parameters
provided_params = set(request_data.keys())
required_params = {p.name for p in service.input_parameters if p.required}
missing_params = required_params - provided_params
if missing_params:
validation_result["valid"] = False
validation_result["errors"].extend([
f"Missing required parameter: {param}"
for param in missing_params
])
# Validate parameter types and constraints
for param in service.input_parameters:
if param.name in request_data:
value = request_data[param.name]
            # Type validation (simplified); compare the enum's value so the check
            # works whether or not the parameter type enum subclasses str
            param_type = param.type.value
            if param_type == "integer" and not isinstance(value, int):
                validation_result["valid"] = False
                validation_result["errors"].append(
                    f"Parameter {param.name} must be an integer"
                )
            elif param_type == "float" and not isinstance(value, (int, float)):
                validation_result["valid"] = False
                validation_result["errors"].append(
                    f"Parameter {param.name} must be a number"
                )
            elif param_type == "boolean" and not isinstance(value, bool):
                validation_result["valid"] = False
                validation_result["errors"].append(
                    f"Parameter {param.name} must be a boolean"
                )
            elif param_type == "array" and not isinstance(value, list):
                validation_result["valid"] = False
                validation_result["errors"].append(
                    f"Parameter {param.name} must be an array"
                )
            # Value constraints (numeric bounds only apply to numeric values)
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                if param.min_value is not None and value < param.min_value:
                    validation_result["valid"] = False
                    validation_result["errors"].append(
                        f"Parameter {param.name} must be >= {param.min_value}"
                    )
                if param.max_value is not None and value > param.max_value:
                    validation_result["valid"] = False
                    validation_result["errors"].append(
                        f"Parameter {param.name} must be <= {param.max_value}"
                    )
# Enum options
if param.options and value not in param.options:
validation_result["valid"] = False
validation_result["errors"].append(
f"Parameter {param.name} must be one of: {', '.join(param.options)}"
)
return validation_result
@router.get("/stats")
async def get_registry_stats() -> Dict[str, Any]:
"""Get registry statistics"""
total_services = len(service_registry.services)
category_counts = {}
for service in service_registry.services.values():
category = service.category.value
if category not in category_counts:
category_counts[category] = 0
category_counts[category] += 1
# Count unique pricing models
pricing_models = set()
for service in service_registry.services.values():
for tier in service.pricing:
pricing_models.add(tier.model.value)
return {
"total_services": total_services,
"categories": category_counts,
"pricing_models": list(pricing_models),
"last_updated": service_registry.last_updated.isoformat()
}
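
For illustration, the read-only registry endpoints can be explored with a few `httpx` calls; the host, port, `/v1` prefix, the "whisper" service id, and the example payload field are assumptions:

import httpx

BASE = "http://localhost:8000/v1/registry"   # assumed mount point

with httpx.Client(base_url=BASE) as client:
    print("categories:", client.get("/categories").json())

    # Free-text search across name, description, and tags
    matches = client.get("/services", params={"search": "whisper"}).json()
    print("matches:", len(matches))

    # Fetch the JSON schema for a hypothetical service id and validate a payload
    schema = client.get("/services/whisper/schema").json()
    print("required params:", sorted(schema.get("required", [])))

    result = client.post(
        "/services/validate",
        params={"service_id": "whisper"},
        json={"model": "base"},   # hypothetical parameter
    ).json()
    print(result["valid"], result["errors"])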

View File

@ -0,0 +1,612 @@
"""
Services router for specific GPU workloads
"""
from typing import Any, Dict, Union
from fastapi import APIRouter, Depends, HTTPException, status, Header, Response
from fastapi.responses import StreamingResponse
from ..deps import require_client_key
from ..models import JobCreate, JobView, JobResult
from ..models.services import (
ServiceType,
ServiceRequest,
ServiceResponse,
WhisperRequest,
StableDiffusionRequest,
LLMRequest,
FFmpegRequest,
BlenderRequest,
)
from ..models.registry import ServiceRegistry, service_registry
from ..services import JobService
from ..storage import SessionDep
router = APIRouter(tags=["services"])
@router.post(
"/services/{service_type}",
response_model=ServiceResponse,
status_code=status.HTTP_201_CREATED,
summary="Submit a service-specific job",
deprecated=True
)
async def submit_service_job(
    service_type: ServiceType,
    request_data: Dict[str, Any],
    session: SessionDep,
    response: Response,
    client_id: str = Depends(require_client_key()),
    user_agent: str = Header(None),
) -> ServiceResponse:
"""Submit a job for a specific service type
DEPRECATED: Use /v1/registry/services/{service_id} endpoint instead.
This endpoint will be removed in version 2.0.
"""
    # Add deprecation warning headers to the response FastAPI will actually send
    # (a locally constructed Response object would be discarded)
    response.headers["X-Deprecated"] = "true"
    response.headers["X-Deprecation-Message"] = "Use /v1/registry/services/{service_id} instead"
# Check if service exists in registry
service = service_registry.get_service(service_type.value)
if not service:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Service {service_type} not found"
)
# Validate request against service schema
validation_result = await validate_service_request(service_type.value, request_data)
if not validation_result["valid"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid request: {', '.join(validation_result['errors'])}"
)
# Create service request wrapper
service_request = ServiceRequest(
service_type=service_type,
request_data=request_data
)
# Validate and parse service-specific request
try:
typed_request = service_request.get_service_request()
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid request for {service_type}: {str(e)}"
)
# Get constraints from service request
constraints = typed_request.get_constraints()
# Create job with service-specific payload
job_payload = {
"service_type": service_type.value,
"service_request": request_data,
}
job_create = JobCreate(
payload=job_payload,
constraints=constraints,
ttl_seconds=900 # Default 15 minutes
)
# Submit job
service = JobService(session)
job = service.create_job(client_id, job_create)
return ServiceResponse(
job_id=job.job_id,
service_type=service_type,
status=job.state.value,
estimated_completion=job.expires_at.isoformat()
)
# Whisper endpoints
@router.post(
"/services/whisper/transcribe",
response_model=ServiceResponse,
status_code=status.HTTP_201_CREATED,
summary="Transcribe audio using Whisper"
)
async def whisper_transcribe(
request: WhisperRequest,
session: SessionDep,
client_id: str = Depends(require_client_key()),
) -> ServiceResponse:
"""Transcribe audio file using Whisper"""
job_payload = {
"service_type": ServiceType.WHISPER.value,
"service_request": request.dict(),
}
job_create = JobCreate(
payload=job_payload,
constraints=request.get_constraints(),
ttl_seconds=900
)
service = JobService(session)
job = service.create_job(client_id, job_create)
return ServiceResponse(
job_id=job.job_id,
service_type=ServiceType.WHISPER,
status=job.state.value,
estimated_completion=job.expires_at.isoformat()
)
@router.post(
"/services/whisper/translate",
response_model=ServiceResponse,
status_code=status.HTTP_201_CREATED,
summary="Translate audio using Whisper"
)
async def whisper_translate(
request: WhisperRequest,
session: SessionDep,
client_id: str = Depends(require_client_key()),
) -> ServiceResponse:
"""Translate audio file using Whisper"""
# Force task to be translate
request.task = "translate"
job_payload = {
"service_type": ServiceType.WHISPER.value,
"service_request": request.dict(),
}
job_create = JobCreate(
payload=job_payload,
constraints=request.get_constraints(),
ttl_seconds=900
)
service = JobService(session)
job = service.create_job(client_id, job_create)
return ServiceResponse(
job_id=job.job_id,
service_type=ServiceType.WHISPER,
status=job.state.value,
estimated_completion=job.expires_at.isoformat()
)
# Stable Diffusion endpoints
@router.post(
"/services/stable-diffusion/generate",
response_model=ServiceResponse,
status_code=status.HTTP_201_CREATED,
summary="Generate images using Stable Diffusion"
)
async def stable_diffusion_generate(
request: StableDiffusionRequest,
session: SessionDep,
client_id: str = Depends(require_client_key()),
) -> ServiceResponse:
"""Generate images using Stable Diffusion"""
job_payload = {
"service_type": ServiceType.STABLE_DIFFUSION.value,
"service_request": request.dict(),
}
job_create = JobCreate(
payload=job_payload,
constraints=request.get_constraints(),
ttl_seconds=600 # 10 minutes for image generation
)
service = JobService(session)
job = service.create_job(client_id, job_create)
return ServiceResponse(
job_id=job.job_id,
service_type=ServiceType.STABLE_DIFFUSION,
status=job.state.value,
estimated_completion=job.expires_at.isoformat()
)
@router.post(
"/services/stable-diffusion/img2img",
response_model=ServiceResponse,
status_code=status.HTTP_201_CREATED,
summary="Image-to-image generation"
)
async def stable_diffusion_img2img(
request: StableDiffusionRequest,
session: SessionDep,
client_id: str = Depends(require_client_key()),
) -> ServiceResponse:
"""Image-to-image generation using Stable Diffusion"""
# Add img2img specific parameters
request_data = request.dict()
request_data["mode"] = "img2img"
job_payload = {
"service_type": ServiceType.STABLE_DIFFUSION.value,
"service_request": request_data,
}
job_create = JobCreate(
payload=job_payload,
constraints=request.get_constraints(),
ttl_seconds=600
)
service = JobService(session)
job = service.create_job(client_id, job_create)
return ServiceResponse(
job_id=job.job_id,
service_type=ServiceType.STABLE_DIFFUSION,
status=job.state.value,
estimated_completion=job.expires_at.isoformat()
)
# LLM Inference endpoints
@router.post(
"/services/llm/inference",
response_model=ServiceResponse,
status_code=status.HTTP_201_CREATED,
summary="Run LLM inference"
)
async def llm_inference(
request: LLMRequest,
session: SessionDep,
client_id: str = Depends(require_client_key()),
) -> ServiceResponse:
"""Run inference on a language model"""
job_payload = {
"service_type": ServiceType.LLM_INFERENCE.value,
"service_request": request.dict(),
}
job_create = JobCreate(
payload=job_payload,
constraints=request.get_constraints(),
ttl_seconds=300 # 5 minutes for text generation
)
service = JobService(session)
job = service.create_job(client_id, job_create)
return ServiceResponse(
job_id=job.job_id,
service_type=ServiceType.LLM_INFERENCE,
status=job.state.value,
estimated_completion=job.expires_at.isoformat()
)
@router.post(
"/services/llm/stream",
summary="Stream LLM inference"
)
async def llm_stream(
request: LLMRequest,
session: SessionDep,
client_id: str = Depends(require_client_key()),
):
"""Stream LLM inference response"""
# Force streaming mode
request.stream = True
job_payload = {
"service_type": ServiceType.LLM_INFERENCE.value,
"service_request": request.dict(),
}
job_create = JobCreate(
payload=job_payload,
constraints=request.get_constraints(),
ttl_seconds=300
)
service = JobService(session)
job = service.create_job(client_id, job_create)
# Return streaming response
# This would implement WebSocket or Server-Sent Events
return ServiceResponse(
job_id=job.job_id,
service_type=ServiceType.LLM_INFERENCE,
status=job.state.value,
estimated_completion=job.expires_at.isoformat()
)
# FFmpeg endpoints
@router.post(
"/services/ffmpeg/transcode",
response_model=ServiceResponse,
status_code=status.HTTP_201_CREATED,
summary="Transcode video using FFmpeg"
)
async def ffmpeg_transcode(
request: FFmpegRequest,
session: SessionDep,
client_id: str = Depends(require_client_key()),
) -> ServiceResponse:
"""Transcode video using FFmpeg"""
job_payload = {
"service_type": ServiceType.FFMPEG.value,
"service_request": request.dict(),
}
# Adjust TTL based on video length (would need to probe video)
job_create = JobCreate(
payload=job_payload,
constraints=request.get_constraints(),
ttl_seconds=1800 # 30 minutes for video transcoding
)
service = JobService(session)
job = service.create_job(client_id, job_create)
return ServiceResponse(
job_id=job.job_id,
service_type=ServiceType.FFMPEG,
status=job.state.value,
estimated_completion=job.expires_at.isoformat()
)
# Blender endpoints
@router.post(
"/services/blender/render",
response_model=ServiceResponse,
status_code=status.HTTP_201_CREATED,
summary="Render using Blender"
)
async def blender_render(
request: BlenderRequest,
session: SessionDep,
client_id: str = Depends(require_client_key()),
) -> ServiceResponse:
"""Render scene using Blender"""
job_payload = {
"service_type": ServiceType.BLENDER.value,
"service_request": request.dict(),
}
# Adjust TTL based on frame count
frame_count = request.frame_end - request.frame_start + 1
estimated_time = frame_count * 30 # 30 seconds per frame estimate
ttl_seconds = max(600, estimated_time) # Minimum 10 minutes
job_create = JobCreate(
payload=job_payload,
constraints=request.get_constraints(),
ttl_seconds=ttl_seconds
)
service = JobService(session)
job = service.create_job(client_id, job_create)
return ServiceResponse(
job_id=job.job_id,
service_type=ServiceType.BLENDER,
status=job.state.value,
estimated_completion=job.expires_at.isoformat()
)
# Utility endpoints
@router.get(
"/services",
summary="List available services"
)
async def list_services() -> Dict[str, Any]:
"""List all available service types and their capabilities"""
return {
"services": [
{
"type": ServiceType.WHISPER.value,
"name": "Whisper Speech Recognition",
"description": "Transcribe and translate audio files",
"models": [m.value for m in WhisperModel],
"constraints": {
"gpu": "nvidia",
"min_vram_gb": 1,
}
},
{
"type": ServiceType.STABLE_DIFFUSION.value,
"name": "Stable Diffusion",
"description": "Generate images from text prompts",
"models": [m.value for m in SDModel],
"constraints": {
"gpu": "nvidia",
"min_vram_gb": 4,
}
},
{
"type": ServiceType.LLM_INFERENCE.value,
"name": "LLM Inference",
"description": "Run inference on large language models",
"models": [m.value for m in LLMModel],
"constraints": {
"gpu": "nvidia",
"min_vram_gb": 8,
}
},
{
"type": ServiceType.FFMPEG.value,
"name": "FFmpeg Video Processing",
"description": "Transcode and process video files",
"codecs": [c.value for c in FFmpegCodec],
"constraints": {
"gpu": "any",
"min_vram_gb": 0,
}
},
{
"type": ServiceType.BLENDER.value,
"name": "Blender Rendering",
"description": "Render 3D scenes using Blender",
"engines": [e.value for e in BlenderEngine],
"constraints": {
"gpu": "any",
"min_vram_gb": 4,
}
},
]
}
@router.get(
"/services/{service_type}/schema",
summary="Get service request schema",
deprecated=True
)
async def get_service_schema(service_type: ServiceType) -> Dict[str, Any]:
"""Get the JSON schema for a specific service type
DEPRECATED: Use /v1/registry/services/{service_id}/schema instead.
This endpoint will be removed in version 2.0.
"""
# Get service from registry
service = service_registry.get_service(service_type.value)
if not service:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Service {service_type} not found"
)
# Build schema from service definition
properties = {}
required = []
for param in service.input_parameters:
prop = {
"type": param.type.value,
"description": param.description
}
if param.default is not None:
prop["default"] = param.default
if param.min_value is not None:
prop["minimum"] = param.min_value
if param.max_value is not None:
prop["maximum"] = param.max_value
if param.options:
prop["enum"] = param.options
if param.validation:
prop.update(param.validation)
properties[param.name] = prop
if param.required:
required.append(param.name)
schema = {
"type": "object",
"properties": properties,
"required": required
}
return {
"service_type": service_type.value,
"schema": schema
}
async def validate_service_request(service_id: str, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""Validate a service request against the service schema"""
service = service_registry.get_service(service_id)
if not service:
return {"valid": False, "errors": [f"Service {service_id} not found"]}
validation_result = {
"valid": True,
"errors": [],
"warnings": []
}
# Check required parameters
provided_params = set(request_data.keys())
required_params = {p.name for p in service.input_parameters if p.required}
missing_params = required_params - provided_params
if missing_params:
validation_result["valid"] = False
validation_result["errors"].extend([
f"Missing required parameter: {param}"
for param in missing_params
])
# Validate parameter types and constraints
for param in service.input_parameters:
if param.name in request_data:
value = request_data[param.name]
            # Type validation (simplified); compare the enum's value so the check
            # works whether or not the parameter type enum subclasses str
            param_type = param.type.value
            if param_type == "integer" and not isinstance(value, int):
                validation_result["valid"] = False
                validation_result["errors"].append(
                    f"Parameter {param.name} must be an integer"
                )
            elif param_type == "float" and not isinstance(value, (int, float)):
                validation_result["valid"] = False
                validation_result["errors"].append(
                    f"Parameter {param.name} must be a number"
                )
            elif param_type == "boolean" and not isinstance(value, bool):
                validation_result["valid"] = False
                validation_result["errors"].append(
                    f"Parameter {param.name} must be a boolean"
                )
            elif param_type == "array" and not isinstance(value, list):
                validation_result["valid"] = False
                validation_result["errors"].append(
                    f"Parameter {param.name} must be an array"
                )
            # Value constraints (numeric bounds only apply to numeric values)
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                if param.min_value is not None and value < param.min_value:
                    validation_result["valid"] = False
                    validation_result["errors"].append(
                        f"Parameter {param.name} must be >= {param.min_value}"
                    )
                if param.max_value is not None and value > param.max_value:
                    validation_result["valid"] = False
                    validation_result["errors"].append(
                        f"Parameter {param.name} must be <= {param.max_value}"
                    )
# Enum options
if param.options and value not in param.options:
validation_result["valid"] = False
validation_result["errors"].append(
f"Parameter {param.name} must be one of: {', '.join(param.options)}"
)
return validation_result
# Model enums referenced by list_services above (resolved at module import time)
from ..models.services import (
WhisperModel,
SDModel,
LLMModel,
FFmpegCodec,
FFmpegPreset,
BlenderEngine,
BlenderFormat,
)
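
Every endpoint above wraps the request into a job payload of the form {"service_type": ..., "service_request": ...}, so a worker that picks these jobs up can branch on that field. A minimal dispatcher sketch; the handler functions and the "whisper" key (assumed to equal ServiceType.WHISPER.value) are hypothetical:

from typing import Any, Dict

def handle_whisper(req: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical handler; a real worker would invoke the Whisper runtime here.
    return {"text": "..."}

HANDLERS = {
    "whisper": handle_whisper,
    # "stable_diffusion": handle_stable_diffusion, etc.
}

def dispatch(job_payload: Dict[str, Any]) -> Dict[str, Any]:
    service_type = job_payload["service_type"]
    handler = HANDLERS.get(service_type)
    if handler is None:
        raise ValueError(f"unsupported service_type: {service_type}")
    return handler(job_payload["service_request"])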

View File

@ -0,0 +1,362 @@
"""
Access control service for confidential transactions
"""
from typing import Dict, List, Optional, Set, Any
from datetime import datetime, timedelta
from enum import Enum
import json
import re
from ..models import ConfidentialAccessRequest, ConfidentialAccessLog
from ..settings import settings
from ..logging import get_logger
logger = get_logger(__name__)
class AccessPurpose(str, Enum):
"""Standard access purposes"""
SETTLEMENT = "settlement"
AUDIT = "audit"
COMPLIANCE = "compliance"
DISPUTE = "dispute"
SUPPORT = "support"
REPORTING = "reporting"
class AccessLevel(str, Enum):
"""Access levels for confidential data"""
READ = "read"
WRITE = "write"
ADMIN = "admin"
class ParticipantRole(str, Enum):
"""Roles for transaction participants"""
CLIENT = "client"
MINER = "miner"
COORDINATOR = "coordinator"
AUDITOR = "auditor"
REGULATOR = "regulator"
class PolicyStore:
"""Storage for access control policies"""
def __init__(self):
self._policies: Dict[str, Dict] = {}
self._role_permissions: Dict[ParticipantRole, Set[str]] = {
ParticipantRole.CLIENT: {"read_own", "settlement_own"},
ParticipantRole.MINER: {"read_assigned", "settlement_assigned"},
ParticipantRole.COORDINATOR: {"read_all", "admin_all"},
ParticipantRole.AUDITOR: {"read_all", "audit_all"},
ParticipantRole.REGULATOR: {"read_all", "compliance_all"}
}
self._load_default_policies()
def _load_default_policies(self):
"""Load default access policies"""
# Client can access their own transactions
self._policies["client_own_data"] = {
"participants": ["client"],
"conditions": {
"transaction_client_id": "{requester}",
"purpose": ["settlement", "dispute", "support"]
},
"access_level": AccessLevel.READ,
"time_restrictions": None
}
# Miner can access assigned transactions
self._policies["miner_assigned_data"] = {
"participants": ["miner"],
"conditions": {
"transaction_miner_id": "{requester}",
"purpose": ["settlement"]
},
"access_level": AccessLevel.READ,
"time_restrictions": None
}
# Coordinator has full access
self._policies["coordinator_full"] = {
"participants": ["coordinator"],
"conditions": {},
"access_level": AccessLevel.ADMIN,
"time_restrictions": None
}
# Auditor access for compliance
self._policies["auditor_compliance"] = {
"participants": ["auditor", "regulator"],
"conditions": {
"purpose": ["audit", "compliance"]
},
"access_level": AccessLevel.READ,
"time_restrictions": {
"business_hours_only": True,
"retention_days": 2555 # 7 years
}
}
def get_policy(self, policy_id: str) -> Optional[Dict]:
"""Get access policy by ID"""
return self._policies.get(policy_id)
def list_policies(self) -> List[str]:
"""List all policy IDs"""
return list(self._policies.keys())
def add_policy(self, policy_id: str, policy: Dict):
"""Add new access policy"""
self._policies[policy_id] = policy
def get_role_permissions(self, role: ParticipantRole) -> Set[str]:
"""Get permissions for a role"""
return self._role_permissions.get(role, set())
class AccessController:
"""Controls access to confidential transaction data"""
def __init__(self, policy_store: PolicyStore):
self.policy_store = policy_store
self._access_cache: Dict[str, Dict] = {}
self._cache_ttl = timedelta(minutes=5)
def verify_access(self, request: ConfidentialAccessRequest) -> bool:
"""Verify if requester has access rights"""
try:
# Check cache first
cache_key = self._get_cache_key(request)
cached_result = self._get_cached_result(cache_key)
if cached_result is not None:
return cached_result["allowed"]
# Get participant info
participant_info = self._get_participant_info(request.requester)
if not participant_info:
logger.warning(f"Unknown participant: {request.requester}")
return False
# Check role-based permissions
role = participant_info.get("role")
if not self._check_role_permissions(role, request):
return False
# Check transaction-specific policies
transaction = self._get_transaction(request.transaction_id)
if not transaction:
logger.warning(f"Transaction not found: {request.transaction_id}")
return False
# Apply access policies
allowed = self._apply_policies(request, participant_info, transaction)
# Cache result
self._cache_result(cache_key, allowed)
return allowed
except Exception as e:
logger.error(f"Access verification failed: {e}")
return False
def _check_role_permissions(self, role: str, request: ConfidentialAccessRequest) -> bool:
"""Check if role grants access for this purpose"""
try:
participant_role = ParticipantRole(role.lower())
permissions = self.policy_store.get_role_permissions(participant_role)
# Check purpose-based permissions
if request.purpose == "settlement":
return "settlement" in permissions or "settlement_own" in permissions
elif request.purpose == "audit":
return "audit" in permissions or "audit_all" in permissions
elif request.purpose == "compliance":
return "compliance" in permissions or "compliance_all" in permissions
elif request.purpose == "dispute":
return "dispute" in permissions or "read_own" in permissions
elif request.purpose == "support":
return "support" in permissions or "read_all" in permissions
else:
return "read" in permissions or "read_all" in permissions
except ValueError:
logger.warning(f"Invalid role: {role}")
return False
def _apply_policies(
self,
request: ConfidentialAccessRequest,
participant_info: Dict,
transaction: Dict
) -> bool:
"""Apply access policies to request"""
# Check if participant is in transaction participants list
if request.requester not in transaction.get("participants", []):
# Only coordinators, auditors, and regulators can access non-participant data
role = participant_info.get("role", "").lower()
if role not in ["coordinator", "auditor", "regulator"]:
return False
# Check time-based restrictions
if not self._check_time_restrictions(request.purpose, participant_info.get("role")):
return False
# Check business hours for auditors
if participant_info.get("role") == "auditor" and not self._is_business_hours():
return False
# Check retention periods
if not self._check_retention_period(transaction, participant_info.get("role")):
return False
return True
def _check_time_restrictions(self, purpose: str, role: Optional[str]) -> bool:
"""Check time-based access restrictions"""
# No restrictions for settlement and dispute
if purpose in ["settlement", "dispute"]:
return True
# Audit and compliance only during business hours for non-coordinators
if purpose in ["audit", "compliance"] and role not in ["coordinator"]:
return self._is_business_hours()
return True
def _is_business_hours(self) -> bool:
"""Check if current time is within business hours"""
now = datetime.utcnow()
# Monday-Friday, 9 AM - 5 PM UTC
if now.weekday() >= 5: # Weekend
return False
if 9 <= now.hour < 17:
return True
return False
def _check_retention_period(self, transaction: Dict, role: Optional[str]) -> bool:
"""Check if data is within retention period for role"""
transaction_date = transaction.get("timestamp", datetime.utcnow())
# Different retention periods for different roles
if role == "regulator":
retention_days = 2555 # 7 years
elif role == "auditor":
retention_days = 1825 # 5 years
elif role == "coordinator":
retention_days = 3650 # 10 years
else:
retention_days = 365 # 1 year
expiry_date = transaction_date + timedelta(days=retention_days)
return datetime.utcnow() <= expiry_date
def _get_participant_info(self, participant_id: str) -> Optional[Dict]:
"""Get participant information"""
# In production, query from database
# For now, return mock data
if participant_id.startswith("client-"):
return {"id": participant_id, "role": "client", "active": True}
elif participant_id.startswith("miner-"):
return {"id": participant_id, "role": "miner", "active": True}
elif participant_id.startswith("coordinator-"):
return {"id": participant_id, "role": "coordinator", "active": True}
elif participant_id.startswith("auditor-"):
return {"id": participant_id, "role": "auditor", "active": True}
elif participant_id.startswith("regulator-"):
return {"id": participant_id, "role": "regulator", "active": True}
else:
return None
def _get_transaction(self, transaction_id: str) -> Optional[Dict]:
"""Get transaction information"""
# In production, query from database
# For now, return mock data
return {
"transaction_id": transaction_id,
"participants": ["client-456", "miner-789"],
"timestamp": datetime.utcnow(),
"status": "completed"
}
def _get_cache_key(self, request: ConfidentialAccessRequest) -> str:
"""Generate cache key for access request"""
return f"{request.requester}:{request.transaction_id}:{request.purpose}"
def _get_cached_result(self, cache_key: str) -> Optional[Dict]:
"""Get cached access result"""
if cache_key in self._access_cache:
cached = self._access_cache[cache_key]
if datetime.utcnow() - cached["timestamp"] < self._cache_ttl:
return cached
else:
del self._access_cache[cache_key]
return None
def _cache_result(self, cache_key: str, allowed: bool):
"""Cache access result"""
self._access_cache[cache_key] = {
"allowed": allowed,
"timestamp": datetime.utcnow()
}
def create_access_policy(
self,
name: str,
participants: List[str],
conditions: Dict[str, Any],
access_level: AccessLevel
) -> str:
"""Create a new access policy"""
policy_id = f"policy_{datetime.utcnow().timestamp()}"
policy = {
"participants": participants,
"conditions": conditions,
"access_level": access_level,
"time_restrictions": conditions.get("time_restrictions"),
"created_at": datetime.utcnow().isoformat()
}
self.policy_store.add_policy(policy_id, policy)
logger.info(f"Created access policy: {policy_id}")
return policy_id
def revoke_access(self, participant_id: str, transaction_id: Optional[str] = None):
"""Revoke access for participant"""
# In production, update database
# For now, clear cache
keys_to_remove = []
for key in self._access_cache:
if key.startswith(f"{participant_id}:"):
if transaction_id is None or key.split(":")[1] == transaction_id:
keys_to_remove.append(key)
for key in keys_to_remove:
del self._access_cache[key]
logger.info(f"Revoked access for participant: {participant_id}")
def get_access_summary(self, participant_id: str) -> Dict:
"""Get summary of participant's access rights"""
participant_info = self._get_participant_info(participant_id)
if not participant_info:
return {"error": "Participant not found"}
role = participant_info.get("role")
permissions = self.policy_store.get_role_permissions(ParticipantRole(role))
return {
"participant_id": participant_id,
"role": role,
"permissions": list(permissions),
"active": participant_info.get("active", False)
}
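
A quick sketch of exercising the access controller directly; the three constructor arguments shown are the fields this module reads from `ConfidentialAccessRequest`, and the model may require additional fields not listed here:

# Assumes the imports at the top of this module.
store = PolicyStore()
controller = AccessController(store)

req = ConfidentialAccessRequest(
    transaction_id="ctx-123",
    requester="client-456",
    purpose="settlement",
)
# True with the mock participant/transaction data used by this module
print(controller.verify_access(req))
print(controller.get_access_summary("client-456"))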

View File

@ -0,0 +1,532 @@
"""
Audit logging service for privacy compliance
"""
import os
import json
import hashlib
import gzip
import asyncio
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta
from pathlib import Path
from dataclasses import dataclass, asdict
from ..models import ConfidentialAccessLog
from ..settings import settings
from ..logging import get_logger
logger = get_logger(__name__)
@dataclass
class AuditEvent:
"""Structured audit event"""
event_id: str
timestamp: datetime
event_type: str
participant_id: str
transaction_id: Optional[str]
action: str
resource: str
outcome: str
details: Dict[str, Any]
ip_address: Optional[str]
user_agent: Optional[str]
authorization: Optional[str]
signature: Optional[str]
class AuditLogger:
"""Tamper-evident audit logging for privacy compliance"""
def __init__(self, log_dir: str = "/var/log/aitbc/audit"):
self.log_dir = Path(log_dir)
self.log_dir.mkdir(parents=True, exist_ok=True)
# Current log file
self.current_file = None
self.current_hash = None
# Async writer task
self.write_queue = asyncio.Queue(maxsize=10000)
self.writer_task = None
# Chain of hashes for integrity
self.chain_hash = self._load_chain_hash()
async def start(self):
"""Start the background writer task"""
if self.writer_task is None:
self.writer_task = asyncio.create_task(self._background_writer())
async def stop(self):
"""Stop the background writer task"""
if self.writer_task:
self.writer_task.cancel()
try:
await self.writer_task
except asyncio.CancelledError:
pass
self.writer_task = None
async def log_access(
self,
participant_id: str,
transaction_id: Optional[str],
action: str,
outcome: str,
details: Optional[Dict[str, Any]] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
authorization: Optional[str] = None
):
"""Log access to confidential data"""
event = AuditEvent(
event_id=self._generate_event_id(),
timestamp=datetime.utcnow(),
event_type="access",
participant_id=participant_id,
transaction_id=transaction_id,
action=action,
resource="confidential_transaction",
outcome=outcome,
details=details or {},
ip_address=ip_address,
user_agent=user_agent,
authorization=authorization,
signature=None
)
# Add signature for tamper-evidence
event.signature = self._sign_event(event)
# Queue for writing
await self.write_queue.put(event)
async def log_key_operation(
self,
participant_id: str,
operation: str,
key_version: int,
outcome: str,
details: Optional[Dict[str, Any]] = None
):
"""Log key management operations"""
event = AuditEvent(
event_id=self._generate_event_id(),
timestamp=datetime.utcnow(),
event_type="key_operation",
participant_id=participant_id,
transaction_id=None,
action=operation,
resource="encryption_key",
outcome=outcome,
details={**(details or {}), "key_version": key_version},
ip_address=None,
user_agent=None,
authorization=None,
signature=None
)
event.signature = self._sign_event(event)
await self.write_queue.put(event)
async def log_policy_change(
self,
participant_id: str,
policy_id: str,
change_type: str,
outcome: str,
details: Optional[Dict[str, Any]] = None
):
"""Log access policy changes"""
event = AuditEvent(
event_id=self._generate_event_id(),
timestamp=datetime.utcnow(),
event_type="policy_change",
participant_id=participant_id,
transaction_id=None,
action=change_type,
resource="access_policy",
outcome=outcome,
details={**(details or {}), "policy_id": policy_id},
ip_address=None,
user_agent=None,
authorization=None,
signature=None
)
event.signature = self._sign_event(event)
await self.write_queue.put(event)
def query_logs(
self,
participant_id: Optional[str] = None,
transaction_id: Optional[str] = None,
event_type: Optional[str] = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
limit: int = 100
) -> List[AuditEvent]:
"""Query audit logs"""
results = []
# Get list of log files to search
log_files = self._get_log_files(start_time, end_time)
for log_file in log_files:
try:
# Read and decompress if needed
if log_file.suffix == ".gz":
with gzip.open(log_file, "rt") as f:
for line in f:
event = self._parse_log_line(line.strip())
if self._matches_query(event, participant_id, transaction_id, event_type, start_time, end_time):
results.append(event)
if len(results) >= limit:
return results
else:
with open(log_file, "r") as f:
for line in f:
event = self._parse_log_line(line.strip())
if self._matches_query(event, participant_id, transaction_id, event_type, start_time, end_time):
results.append(event)
if len(results) >= limit:
return results
except Exception as e:
logger.error(f"Failed to read log file {log_file}: {e}")
continue
# Sort by timestamp (newest first)
results.sort(key=lambda x: x.timestamp, reverse=True)
return results[:limit]
def verify_integrity(self, start_date: Optional[datetime] = None) -> Dict[str, Any]:
"""Verify integrity of audit logs"""
if start_date is None:
start_date = datetime.utcnow() - timedelta(days=30)
results = {
"verified_files": 0,
"total_files": 0,
"integrity_violations": [],
"chain_valid": True
}
log_files = self._get_log_files(start_date)
for log_file in log_files:
results["total_files"] += 1
try:
# Verify file hash
file_hash = self._calculate_file_hash(log_file)
stored_hash = self._get_stored_hash(log_file)
if file_hash != stored_hash:
results["integrity_violations"].append({
"file": str(log_file),
"expected": stored_hash,
"actual": file_hash
})
results["chain_valid"] = False
else:
results["verified_files"] += 1
except Exception as e:
logger.error(f"Failed to verify {log_file}: {e}")
results["integrity_violations"].append({
"file": str(log_file),
"error": str(e)
})
results["chain_valid"] = False
return results
def export_logs(
self,
start_time: datetime,
end_time: datetime,
format: str = "json",
include_signatures: bool = True
) -> str:
"""Export audit logs for compliance reporting"""
events = self.query_logs(
start_time=start_time,
end_time=end_time,
limit=10000
)
if format == "json":
export_data = {
"export_metadata": {
"start_time": start_time.isoformat(),
"end_time": end_time.isoformat(),
"event_count": len(events),
"exported_at": datetime.utcnow().isoformat(),
"include_signatures": include_signatures
},
"events": []
}
for event in events:
event_dict = asdict(event)
event_dict["timestamp"] = event.timestamp.isoformat()
if not include_signatures:
event_dict.pop("signature", None)
export_data["events"].append(event_dict)
return json.dumps(export_data, indent=2)
elif format == "csv":
import csv
import io
output = io.StringIO()
writer = csv.writer(output)
# Header
header = [
"event_id", "timestamp", "event_type", "participant_id",
"transaction_id", "action", "resource", "outcome",
"ip_address", "user_agent"
]
if include_signatures:
header.append("signature")
writer.writerow(header)
# Events
for event in events:
row = [
event.event_id,
event.timestamp.isoformat(),
event.event_type,
event.participant_id,
event.transaction_id,
event.action,
event.resource,
event.outcome,
event.ip_address,
event.user_agent
]
if include_signatures:
row.append(event.signature)
writer.writerow(row)
return output.getvalue()
else:
raise ValueError(f"Unsupported export format: {format}")
async def _background_writer(self):
"""Background task for writing audit events"""
while True:
try:
# Get batch of events
events = []
while len(events) < 100:
try:
# Use asyncio.wait_for for timeout
event = await asyncio.wait_for(
self.write_queue.get(),
timeout=1.0
)
events.append(event)
except asyncio.TimeoutError:
if events:
break
continue
# Write events
if events:
self._write_events(events)
except Exception as e:
logger.error(f"Background writer error: {e}")
# Brief pause to avoid error loops
await asyncio.sleep(1)
def _write_events(self, events: List[AuditEvent]):
"""Write events to current log file"""
try:
self._rotate_if_needed()
with open(self.current_file, "a") as f:
for event in events:
# Convert to JSON line
event_dict = asdict(event)
event_dict["timestamp"] = event.timestamp.isoformat()
# Write with signature
line = json.dumps(event_dict, separators=(",", ":")) + "\n"
f.write(line)
f.flush()
# Update chain hash
self._update_chain_hash(events[-1])
except Exception as e:
logger.error(f"Failed to write audit events: {e}")
def _rotate_if_needed(self):
"""Rotate log file if needed"""
now = datetime.utcnow()
today = now.date()
# Check if we need a new file
if self.current_file is None:
self._new_log_file(today)
else:
file_date = datetime.fromisoformat(
self.current_file.stem.split("_")[1]
).date()
if file_date != today:
self._new_log_file(today)
def _new_log_file(self, date):
"""Create new log file for date"""
filename = f"audit_{date.isoformat()}.log"
self.current_file = self.log_dir / filename
# Write header with metadata
if not self.current_file.exists():
header = {
"created_at": datetime.utcnow().isoformat(),
"version": "1.0",
"format": "jsonl",
"previous_hash": self.chain_hash
}
with open(self.current_file, "w") as f:
f.write(f"# {json.dumps(header)}\n")
def _generate_event_id(self) -> str:
"""Generate unique event ID"""
return f"evt_{datetime.utcnow().timestamp()}_{os.urandom(4).hex()}"
def _sign_event(self, event: AuditEvent) -> str:
"""Sign event for tamper-evidence"""
# Create canonical representation
event_data = {
"event_id": event.event_id,
"timestamp": event.timestamp.isoformat(),
"participant_id": event.participant_id,
"action": event.action,
"outcome": event.outcome
}
# Hash with previous chain hash
data = json.dumps(event_data, separators=(",", ":"), sort_keys=True)
combined = f"{self.chain_hash}:{data}".encode()
return hashlib.sha256(combined).hexdigest()
def _update_chain_hash(self, last_event: AuditEvent):
"""Update chain hash with new event"""
self.chain_hash = last_event.signature or self.chain_hash
# Store chain hash for integrity checking
chain_file = self.log_dir / "chain.hash"
with open(chain_file, "w") as f:
f.write(self.chain_hash)
def _load_chain_hash(self) -> str:
"""Load previous chain hash"""
chain_file = self.log_dir / "chain.hash"
if chain_file.exists():
with open(chain_file, "r") as f:
return f.read().strip()
return "0" * 64 # Initial hash
def _get_log_files(self, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None) -> List[Path]:
"""Get list of log files to search"""
files = []
for file in self.log_dir.glob("audit_*.log*"):
try:
# Extract date from filename
date_str = file.stem.split("_")[1]
file_date = datetime.fromisoformat(date_str).date()
# Check if file is in range
file_start = datetime.combine(file_date, datetime.min.time())
file_end = file_start + timedelta(days=1)
if (not start_time or file_end >= start_time) and \
(not end_time or file_start <= end_time):
files.append(file)
except Exception:
continue
return sorted(files)
def _parse_log_line(self, line: str) -> Optional[AuditEvent]:
"""Parse log line into event"""
if line.startswith("#"):
return None # Skip header
try:
data = json.loads(line)
data["timestamp"] = datetime.fromisoformat(data["timestamp"])
return AuditEvent(**data)
except Exception as e:
logger.error(f"Failed to parse log line: {e}")
return None
def _matches_query(
self,
event: Optional[AuditEvent],
participant_id: Optional[str],
transaction_id: Optional[str],
event_type: Optional[str],
start_time: Optional[datetime],
end_time: Optional[datetime]
) -> bool:
"""Check if event matches query criteria"""
if not event:
return False
if participant_id and event.participant_id != participant_id:
return False
if transaction_id and event.transaction_id != transaction_id:
return False
if event_type and event.event_type != event_type:
return False
if start_time and event.timestamp < start_time:
return False
if end_time and event.timestamp > end_time:
return False
return True
def _calculate_file_hash(self, file_path: Path) -> str:
"""Calculate SHA-256 hash of file"""
hash_sha256 = hashlib.sha256()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_sha256.update(chunk)
return hash_sha256.hexdigest()
def _get_stored_hash(self, file_path: Path) -> str:
"""Get stored hash for file"""
hash_file = file_path.with_suffix(".hash")
if hash_file.exists():
with open(hash_file, "r") as f:
return f.read().strip()
return ""
# Global audit logger instance
audit_logger = AuditLogger()
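For orientation, a minimal usage sketch of the audit logger defined above (identifiers are illustrative, and it assumes the logger's background writer task has already been started during application startup):
from datetime import datetime, timedelta
async def demo_audit_logging():
    # Enqueue a key-rotation event; the background writer persists it to the daily log file.
    await audit_logger.log_key_operation(
        participant_id="provider-42",
        operation="rotate",
        key_version=2,
        outcome="success",
        details={"reason": "scheduled_rotation"},
    )
    # Query the last day of key operations for the same participant.
    return audit_logger.query_logs(
        participant_id="provider-42",
        event_type="key_operation",
        start_time=datetime.utcnow() - timedelta(days=1),
        limit=50,
    )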

View File

@ -0,0 +1,349 @@
"""
Encryption service for confidential transactions
"""
import os
import json
import base64
from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime, timedelta
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat, PrivateFormat, NoEncryption
from ..models import ConfidentialTransaction, AccessLog
from ..settings import settings
from ..logging import get_logger
logger = get_logger(__name__)
class EncryptedData:
"""Container for encrypted data and keys"""
def __init__(
self,
ciphertext: bytes,
encrypted_keys: Dict[str, bytes],
algorithm: str = "AES-256-GCM+X25519",
nonce: Optional[bytes] = None,
tag: Optional[bytes] = None
):
self.ciphertext = ciphertext
self.encrypted_keys = encrypted_keys
self.algorithm = algorithm
self.nonce = nonce
self.tag = tag
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for storage"""
return {
"ciphertext": base64.b64encode(self.ciphertext).decode(),
"encrypted_keys": {
participant: base64.b64encode(key).decode()
for participant, key in self.encrypted_keys.items()
},
"algorithm": self.algorithm,
"nonce": base64.b64encode(self.nonce).decode() if self.nonce else None,
"tag": base64.b64encode(self.tag).decode() if self.tag else None
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "EncryptedData":
"""Create from dictionary"""
return cls(
ciphertext=base64.b64decode(data["ciphertext"]),
encrypted_keys={
participant: base64.b64decode(key)
for participant, key in data["encrypted_keys"].items()
},
algorithm=data["algorithm"],
nonce=base64.b64decode(data["nonce"]) if data.get("nonce") else None,
tag=base64.b64decode(data["tag"]) if data.get("tag") else None
)
class EncryptionService:
"""Service for encrypting/decrypting confidential transaction data"""
def __init__(self, key_manager: "KeyManager"):
self.key_manager = key_manager
self.backend = default_backend()
self.algorithm = "AES-256-GCM+X25519"
def encrypt(
self,
data: Dict[str, Any],
participants: List[str],
include_audit: bool = True
) -> EncryptedData:
"""Encrypt data for multiple participants
Args:
data: Data to encrypt
participants: List of participant IDs who can decrypt
include_audit: Whether to include audit escrow key
Returns:
EncryptedData container with ciphertext and encrypted keys
"""
try:
# Generate random DEK (Data Encryption Key)
dek = os.urandom(32) # 256-bit key for AES-256
nonce = os.urandom(12) # 96-bit nonce for GCM
# Serialize and encrypt data
plaintext = json.dumps(data, separators=(",", ":")).encode()
aesgcm = AESGCM(dek)
ciphertext = aesgcm.encrypt(nonce, plaintext, None)
# Extract tag (included in ciphertext for GCM)
tag = ciphertext[-16:]
actual_ciphertext = ciphertext[:-16]
# Encrypt DEK for each participant
encrypted_keys = {}
for participant in participants:
try:
public_key = self.key_manager.get_public_key(participant)
encrypted_dek = self._encrypt_dek(dek, public_key)
encrypted_keys[participant] = encrypted_dek
except Exception as e:
logger.error(f"Failed to encrypt DEK for participant {participant}: {e}")
continue
# Add audit escrow if requested
if include_audit:
try:
audit_public_key = self.key_manager.get_audit_key()
encrypted_dek = self._encrypt_dek(dek, audit_public_key)
encrypted_keys["audit"] = encrypted_dek
except Exception as e:
logger.error(f"Failed to encrypt DEK for audit: {e}")
return EncryptedData(
ciphertext=actual_ciphertext,
encrypted_keys=encrypted_keys,
algorithm=self.algorithm,
nonce=nonce,
tag=tag
)
except Exception as e:
logger.error(f"Encryption failed: {e}")
raise EncryptionError(f"Failed to encrypt data: {e}")
def decrypt(
self,
encrypted_data: EncryptedData,
participant_id: str,
purpose: str = "access"
) -> Dict[str, Any]:
"""Decrypt data for a specific participant
Args:
encrypted_data: The encrypted data container
participant_id: ID of the participant requesting decryption
purpose: Purpose of decryption for audit logging
Returns:
Decrypted data as dictionary
"""
try:
# Get participant's private key
private_key = self.key_manager.get_private_key(participant_id)
# Get encrypted DEK for participant
if participant_id not in encrypted_data.encrypted_keys:
raise AccessDeniedError(f"Participant {participant_id} not authorized")
encrypted_dek = encrypted_data.encrypted_keys[participant_id]
# Decrypt DEK
dek = self._decrypt_dek(encrypted_dek, private_key)
# Reconstruct ciphertext with tag
full_ciphertext = encrypted_data.ciphertext + encrypted_data.tag
# Decrypt data
aesgcm = AESGCM(dek)
plaintext = aesgcm.decrypt(encrypted_data.nonce, full_ciphertext, None)
data = json.loads(plaintext.decode())
# Log access
self._log_access(
transaction_id=None, # Will be set by caller
participant_id=participant_id,
purpose=purpose,
success=True
)
return data
except Exception as e:
logger.error(f"Decryption failed for participant {participant_id}: {e}")
self._log_access(
transaction_id=None,
participant_id=participant_id,
purpose=purpose,
success=False,
error=str(e)
)
raise DecryptionError(f"Failed to decrypt data: {e}")
def audit_decrypt(
self,
encrypted_data: EncryptedData,
audit_authorization: str,
purpose: str = "audit"
) -> Dict[str, Any]:
"""Decrypt data for audit purposes
Args:
encrypted_data: The encrypted data container
audit_authorization: Authorization token for audit access
purpose: Purpose of decryption
Returns:
Decrypted data as dictionary
"""
try:
# Verify audit authorization
if not self.key_manager.verify_audit_authorization(audit_authorization):
raise AccessDeniedError("Invalid audit authorization")
# Get audit private key
audit_private_key = self.key_manager.get_audit_private_key(audit_authorization)
# Decrypt using audit key
if "audit" not in encrypted_data.encrypted_keys:
raise AccessDeniedError("Audit escrow not available")
encrypted_dek = encrypted_data.encrypted_keys["audit"]
dek = self._decrypt_dek(encrypted_dek, audit_private_key)
# Decrypt data
full_ciphertext = encrypted_data.ciphertext + encrypted_data.tag
aesgcm = AESGCM(dek)
plaintext = aesgcm.decrypt(encrypted_data.nonce, full_ciphertext, None)
data = json.loads(plaintext.decode())
# Log audit access
self._log_access(
transaction_id=None,
participant_id="audit",
purpose=f"audit:{purpose}",
success=True,
authorization=audit_authorization
)
return data
except Exception as e:
logger.error(f"Audit decryption failed: {e}")
raise DecryptionError(f"Failed to decrypt for audit: {e}")
def _encrypt_dek(self, dek: bytes, public_key: X25519PublicKey) -> bytes:
"""Encrypt DEK using ECIES with X25519"""
# Generate ephemeral key pair
ephemeral_private = X25519PrivateKey.generate()
ephemeral_public = ephemeral_private.public_key()
# Perform ECDH
shared_key = ephemeral_private.exchange(public_key)
# Derive encryption key from shared secret
derived_key = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=None,
info=b"AITBC-DEK-Encryption",
backend=self.backend
).derive(shared_key)
# Encrypt DEK with AES-GCM
aesgcm = AESGCM(derived_key)
nonce = os.urandom(12)
encrypted_dek = aesgcm.encrypt(nonce, dek, None)
# Return ephemeral public key + nonce + encrypted DEK
return (
ephemeral_public.public_bytes(Encoding.Raw, PublicFormat.Raw) +
nonce +
encrypted_dek
)
def _decrypt_dek(self, encrypted_dek: bytes, private_key: X25519PrivateKey) -> bytes:
"""Decrypt DEK using ECIES with X25519"""
# Extract components
ephemeral_public_bytes = encrypted_dek[:32]
nonce = encrypted_dek[32:44]
dek_ciphertext = encrypted_dek[44:]
# Reconstruct ephemeral public key
ephemeral_public = X25519PublicKey.from_public_bytes(ephemeral_public_bytes)
# Perform ECDH
shared_key = private_key.exchange(ephemeral_public)
# Derive decryption key
derived_key = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=None,
info=b"AITBC-DEK-Encryption",
backend=self.backend
).derive(shared_key)
# Decrypt DEK
aesgcm = AESGCM(derived_key)
dek = aesgcm.decrypt(nonce, dek_ciphertext, None)
return dek
def _log_access(
self,
transaction_id: Optional[str],
participant_id: str,
purpose: str,
success: bool,
error: Optional[str] = None,
authorization: Optional[str] = None
):
"""Log access to confidential data"""
try:
log_entry = {
"transaction_id": transaction_id,
"participant_id": participant_id,
"purpose": purpose,
"timestamp": datetime.utcnow().isoformat(),
"success": success,
"error": error,
"authorization": authorization
}
# In production, this would go to secure audit log
logger.info(f"Confidential data access: {json.dumps(log_entry)}")
except Exception as e:
logger.error(f"Failed to log access: {e}")
class EncryptionError(Exception):
"""Base exception for encryption errors"""
pass
class DecryptionError(EncryptionError):
"""Exception for decryption errors"""
pass
class AccessDeniedError(EncryptionError):
"""Exception for access denied errors"""
pass
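A rough end-to-end sketch of the encrypt/decrypt flow above. The key_manager argument and participant IDs are placeholders; it assumes key pairs were generated beforehand and that the key manager exposes the synchronous get_public_key / get_private_key / get_audit_key calls used by EncryptionService:
def demo_confidential_payload(key_manager):
    service = EncryptionService(key_manager)
    payload = {"amount": "12.5", "asset": "gpu_hours", "memo": "job-123 settlement"}
    # Encrypt for two participants plus the audit escrow key.
    encrypted = service.encrypt(payload, participants=["buyer-1", "seller-1"])
    stored = encrypted.to_dict()              # JSON-safe form for persistence
    restored = EncryptedData.from_dict(stored)
    # Only listed participants (or the audit escrow) can recover the plaintext.
    return service.decrypt(restored, participant_id="buyer-1", purpose="settlement")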

View File

@ -0,0 +1,435 @@
"""
HSM-backed key management for production use
"""
import os
import json
from typing import Dict, List, Optional, Tuple
from datetime import datetime, timedelta
from abc import ABC, abstractmethod
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
from cryptography.hazmat.backends import default_backend
from ..models import KeyPair, KeyRotationLog, AuditAuthorization
from ..repositories.confidential import (
ParticipantKeyRepository,
KeyRotationRepository
)
from ..settings import settings
from ..logging import get_logger
logger = get_logger(__name__)
class HSMProvider(ABC):
"""Abstract base class for HSM providers"""
@abstractmethod
async def generate_key(self, key_id: str) -> Tuple[bytes, bytes]:
"""Generate key pair in HSM, return (public_key, key_handle)"""
pass
@abstractmethod
async def sign_with_key(self, key_handle: bytes, data: bytes) -> bytes:
"""Sign data with HSM-stored private key"""
pass
@abstractmethod
async def derive_shared_secret(self, key_handle: bytes, public_key: bytes) -> bytes:
"""Derive shared secret using ECDH"""
pass
@abstractmethod
async def delete_key(self, key_handle: bytes) -> bool:
"""Delete key from HSM"""
pass
@abstractmethod
async def list_keys(self) -> List[str]:
"""List all key IDs in HSM"""
pass
class SoftwareHSMProvider(HSMProvider):
"""Software-based HSM provider for development/testing"""
def __init__(self):
self._keys: Dict[str, X25519PrivateKey] = {}
self._backend = default_backend()
async def generate_key(self, key_id: str) -> Tuple[bytes, bytes]:
"""Generate key pair in memory"""
private_key = X25519PrivateKey.generate()
public_key = private_key.public_key()
# Store private key (in production, this would be in secure hardware)
self._keys[key_id] = private_key
return (
public_key.public_bytes(Encoding.Raw, PublicFormat.Raw),
key_id.encode() # Use key_id as handle
)
async def sign_with_key(self, key_handle: bytes, data: bytes) -> bytes:
"""Sign with stored private key"""
key_id = key_handle.decode()
private_key = self._keys.get(key_id)
if not private_key:
raise ValueError(f"Key not found: {key_id}")
# For X25519, we don't sign - we exchange
# This is a placeholder for actual HSM operations
return b"signature_placeholder"
async def derive_shared_secret(self, key_handle: bytes, public_key: bytes) -> bytes:
"""Derive shared secret"""
key_id = key_handle.decode()
private_key = self._keys.get(key_id)
if not private_key:
raise ValueError(f"Key not found: {key_id}")
peer_public = X25519PublicKey.from_public_bytes(public_key)
return private_key.exchange(peer_public)
async def delete_key(self, key_handle: bytes) -> bool:
"""Delete key from memory"""
key_id = key_handle.decode()
if key_id in self._keys:
del self._keys[key_id]
return True
return False
async def list_keys(self) -> List[str]:
"""List all keys"""
return list(self._keys.keys())
class AzureKeyVaultProvider(HSMProvider):
"""Azure Key Vault HSM provider for production"""
def __init__(self, vault_url: str, credential=None):
from azure.keyvault.keys.crypto import CryptographyClient
from azure.keyvault.keys import KeyClient
from azure.identity import DefaultAzureCredential
self.vault_url = vault_url
self.credential = credential or DefaultAzureCredential()
self.key_client = KeyClient(vault_url, self.credential)
self.crypto_client = None
async def generate_key(self, key_id: str) -> Tuple[bytes, bytes]:
"""Generate key in Azure Key Vault"""
# Create EC-HSM key
key = await self.key_client.create_ec_key(
key_id,
curve="P-256" # Azure doesn't support X25519 directly
)
# Get public key
public_key = key.key.cryptography_client.public_key()
public_bytes = public_key.public_bytes(
Encoding.Raw,
PublicFormat.Raw
)
return public_bytes, key.id.encode()
async def sign_with_key(self, key_handle: bytes, data: bytes) -> bytes:
"""Sign with Azure Key Vault"""
key_id = key_handle.decode()
crypto_client = self.key_client.get_cryptography_client(key_id)
sign_result = await crypto_client.sign("ES256", data)
return sign_result.signature
async def derive_shared_secret(self, key_handle: bytes, public_key: bytes) -> bytes:
"""Derive shared secret (not directly supported in Azure)"""
# Would need to use a different approach
raise NotImplementedError("ECDH not supported in Azure Key Vault")
async def delete_key(self, key_handle: bytes) -> bool:
"""Delete key from Azure Key Vault"""
key_name = key_handle.decode().split("/")[-1]
await self.key_client.begin_delete_key(key_name)
return True
async def list_keys(self) -> List[str]:
"""List keys in Azure Key Vault"""
keys = []
async for key in self.key_client.list_properties_of_keys():
keys.append(key.name)
return keys
class AWSKMSProvider(HSMProvider):
"""AWS KMS HSM provider for production"""
def __init__(self, region_name: str):
import boto3
self.kms = boto3.client('kms', region_name=region_name)
async def generate_key(self, key_id: str) -> Tuple[bytes, bytes]:
"""Generate key pair in AWS KMS"""
# Create CMK
response = self.kms.create_key(
Description=f"AITBC confidential transaction key for {key_id}",
KeyUsage='ENCRYPT_DECRYPT',
KeySpec='ECC_NIST_P256'
)
# Get public key
public_key = self.kms.get_public_key(KeyId=response['KeyMetadata']['KeyId'])
return public_key['PublicKey'], response['KeyMetadata']['KeyId'].encode()
async def sign_with_key(self, key_handle: bytes, data: bytes) -> bytes:
"""Sign with AWS KMS"""
response = self.kms.sign(
KeyId=key_handle.decode(),
Message=data,
MessageType='RAW',
SigningAlgorithm='ECDSA_SHA_256'
)
return response['Signature']
async def derive_shared_secret(self, key_handle: bytes, public_key: bytes) -> bytes:
"""Derive shared secret (not directly supported in KMS)"""
raise NotImplementedError("ECDH not supported in AWS KMS")
async def delete_key(self, key_handle: bytes) -> bool:
"""Schedule key deletion in AWS KMS"""
self.kms.schedule_key_deletion(KeyId=key_handle.decode())
return True
async def list_keys(self) -> List[str]:
"""List keys in AWS KMS"""
keys = []
paginator = self.kms.get_paginator('list_keys')
for page in paginator.paginate():
for key in page['Keys']:
keys.append(key['KeyId'])
return keys
class HSMKeyManager:
"""HSM-backed key manager for production"""
def __init__(self, hsm_provider: HSMProvider, key_repository: ParticipantKeyRepository):
self.hsm = hsm_provider
self.key_repo = key_repository
self._master_key = None
self._init_master_key()
def _init_master_key(self):
"""Initialize master key for encrypting stored data"""
# In production, this would come from HSM or KMS
self._master_key = os.urandom(32)
async def generate_key_pair(self, participant_id: str) -> KeyPair:
"""Generate key pair in HSM"""
try:
# Generate key in HSM
hsm_key_id = f"aitbc-{participant_id}-{datetime.utcnow().timestamp()}"
public_key_bytes, key_handle = await self.hsm.generate_key(hsm_key_id)
# Create key pair record
key_pair = KeyPair(
participant_id=participant_id,
private_key=key_handle, # Store HSM handle, not actual private key
public_key=public_key_bytes,
algorithm="X25519",
created_at=datetime.utcnow(),
version=1
)
# Store metadata in database
await self.key_repo.create(
await self._get_session(),
key_pair
)
logger.info(f"Generated HSM key pair for participant: {participant_id}")
return key_pair
except Exception as e:
logger.error(f"Failed to generate HSM key pair for {participant_id}: {e}")
raise
async def rotate_keys(self, participant_id: str) -> KeyPair:
"""Rotate keys in HSM"""
# Get current key
current_key = await self.key_repo.get_by_participant(
await self._get_session(),
participant_id
)
if not current_key:
raise ValueError(f"No existing keys for {participant_id}")
# Generate new key
new_key_pair = await self.generate_key_pair(participant_id)
# Log rotation
rotation_log = KeyRotationLog(
participant_id=participant_id,
old_version=current_key.version,
new_version=new_key_pair.version,
rotated_at=datetime.utcnow(),
reason="scheduled_rotation"
)
await self.key_repo.rotate(
await self._get_session(),
participant_id,
new_key_pair
)
# Delete old key from HSM
await self.hsm.delete_key(current_key.private_key)
return new_key_pair
def get_public_key(self, participant_id: str) -> X25519PublicKey:
"""Get public key for participant"""
key = self.key_repo.get_by_participant_sync(participant_id)
if not key:
raise ValueError(f"No keys found for {participant_id}")
return X25519PublicKey.from_public_bytes(key.public_key)
async def get_private_key_handle(self, participant_id: str) -> bytes:
"""Get HSM key handle for participant"""
key = await self.key_repo.get_by_participant(
await self._get_session(),
participant_id
)
if not key:
raise ValueError(f"No keys found for {participant_id}")
return key.private_key # This is the HSM handle
async def derive_shared_secret(
self,
participant_id: str,
peer_public_key: bytes
) -> bytes:
"""Derive shared secret using HSM"""
key_handle = await self.get_private_key_handle(participant_id)
return await self.hsm.derive_shared_secret(key_handle, peer_public_key)
async def sign_with_key(
self,
participant_id: str,
data: bytes
) -> bytes:
"""Sign data using HSM-stored key"""
key_handle = await self.get_private_key_handle(participant_id)
return await self.hsm.sign_with_key(key_handle, data)
async def revoke_keys(self, participant_id: str, reason: str) -> bool:
"""Revoke participant's keys"""
# Get current key
current_key = await self.key_repo.get_by_participant(
await self._get_session(),
participant_id
)
if not current_key:
return False
# Delete from HSM
await self.hsm.delete_key(current_key.private_key)
# Mark as revoked in database
return await self.key_repo.update_active(
await self._get_session(),
participant_id,
False,
reason
)
async def create_audit_authorization(
self,
issuer: str,
purpose: str,
expires_in_hours: int = 24
) -> str:
"""Create audit authorization signed with HSM"""
# Create authorization payload
payload = {
"issuer": issuer,
"subject": "audit_access",
"purpose": purpose,
"created_at": datetime.utcnow().isoformat(),
"expires_at": (datetime.utcnow() + timedelta(hours=expires_in_hours)).isoformat()
}
# Sign with audit key
audit_key_handle = await self.get_private_key_handle("audit")
signature = await self.hsm.sign_with_key(
audit_key_handle,
json.dumps(payload).encode()
)
payload["signature"] = signature.hex()
# Encode for transport
import base64
return base64.b64encode(json.dumps(payload).encode()).decode()
async def verify_audit_authorization(self, authorization: str) -> bool:
"""Verify audit authorization"""
try:
# Decode authorization
import base64
auth_data = base64.b64decode(authorization).decode()
auth_json = json.loads(auth_data)
# Check expiration
expires_at = datetime.fromisoformat(auth_json["expires_at"])
if datetime.utcnow() > expires_at:
return False
# Verify signature with audit public key
audit_public_key = self.get_public_key("audit")
# In production, verify with proper cryptographic library
return True
except Exception as e:
logger.error(f"Failed to verify audit authorization: {e}")
return False
async def _get_session(self):
"""Get database session"""
# In production, inject via dependency injection; get_async_session is
# assumed to be importable from the application's database module.
async for session in get_async_session():
return session
def create_hsm_key_manager() -> HSMKeyManager:
"""Create HSM key manager based on configuration"""
from ..repositories.confidential import ParticipantKeyRepository
# Get HSM provider from settings
hsm_type = getattr(settings, 'HSM_PROVIDER', 'software')
if hsm_type == 'software':
hsm = SoftwareHSMProvider()
elif hsm_type == 'azure':
vault_url = getattr(settings, 'AZURE_KEY_VAULT_URL')
hsm = AzureKeyVaultProvider(vault_url)
elif hsm_type == 'aws':
region = getattr(settings, 'AWS_REGION', 'us-east-1')
hsm = AWSKMSProvider(region)
else:
raise ValueError(f"Unknown HSM provider: {hsm_type}")
key_repo = ParticipantKeyRepository()
return HSMKeyManager(hsm, key_repo)
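A small sketch exercising the software provider in isolation, e.g. in tests (key IDs are illustrative). It relies only on the methods defined above:
async def demo_software_hsm():
    hsm = SoftwareHSMProvider()
    alice_pub, alice_handle = await hsm.generate_key("alice")
    bob_pub, bob_handle = await hsm.generate_key("bob")
    # Both sides derive the same X25519 shared secret via ECDH.
    secret_a = await hsm.derive_shared_secret(alice_handle, bob_pub)
    secret_b = await hsm.derive_shared_secret(bob_handle, alice_pub)
    assert secret_a == secret_b
    return secret_a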

View File

@ -0,0 +1,466 @@
"""
Key management service for confidential transactions
"""
import os
import json
import base64
from typing import Dict, Optional, List, Tuple
from datetime import datetime, timedelta
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat, PrivateFormat, NoEncryption
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from ..models import KeyPair, KeyRotationLog, AuditAuthorization
from ..settings import settings
from ..logging import get_logger
logger = get_logger(__name__)
class KeyManager:
"""Manages encryption keys for confidential transactions"""
def __init__(self, storage_backend: "KeyStorageBackend"):
self.storage = storage_backend
self.backend = default_backend()
self._key_cache = {}
self._audit_key = None
self._audit_key_rotation = timedelta(days=30)
async def generate_key_pair(self, participant_id: str) -> KeyPair:
"""Generate X25519 key pair for participant"""
try:
# Generate new key pair
private_key = X25519PrivateKey.generate()
public_key = private_key.public_key()
# Create key pair object
key_pair = KeyPair(
participant_id=participant_id,
private_key=private_key.private_bytes_raw(),
public_key=public_key.public_bytes_raw(),
algorithm="X25519",
created_at=datetime.utcnow(),
version=1
)
# Store securely
await self.storage.store_key_pair(key_pair)
# Cache public key
self._key_cache[participant_id] = {
"public_key": public_key,
"version": key_pair.version
}
logger.info(f"Generated key pair for participant: {participant_id}")
return key_pair
except Exception as e:
logger.error(f"Failed to generate key pair for {participant_id}: {e}")
raise KeyManagementError(f"Key generation failed: {e}")
async def rotate_keys(self, participant_id: str) -> KeyPair:
"""Rotate encryption keys for participant"""
try:
# Get current key pair
current_key = await self.storage.get_key_pair(participant_id)
if not current_key:
raise KeyNotFoundError(f"No existing keys for {participant_id}")
# Generate new key pair
new_key_pair = await self.generate_key_pair(participant_id)
# Log rotation
rotation_log = KeyRotationLog(
participant_id=participant_id,
old_version=current_key.version,
new_version=new_key_pair.version,
rotated_at=datetime.utcnow(),
reason="scheduled_rotation"
)
await self.storage.log_rotation(rotation_log)
# Re-encrypt active transactions (in production)
await self._reencrypt_transactions(participant_id, current_key, new_key_pair)
logger.info(f"Rotated keys for participant: {participant_id}")
return new_key_pair
except Exception as e:
logger.error(f"Failed to rotate keys for {participant_id}: {e}")
raise KeyManagementError(f"Key rotation failed: {e}")
def get_public_key(self, participant_id: str) -> X25519PublicKey:
"""Get public key for participant"""
# Check cache first
if participant_id in self._key_cache:
return self._key_cache[participant_id]["public_key"]
# Load from storage
key_pair = self.storage.get_key_pair_sync(participant_id)
if not key_pair:
raise KeyNotFoundError(f"No keys found for participant: {participant_id}")
# Reconstruct public key
public_key = X25519PublicKey.from_public_bytes(key_pair.public_key)
# Cache it
self._key_cache[participant_id] = {
"public_key": public_key,
"version": key_pair.version
}
return public_key
def get_private_key(self, participant_id: str) -> X25519PrivateKey:
"""Get private key for participant (from secure storage)"""
key_pair = self.storage.get_key_pair_sync(participant_id)
if not key_pair:
raise KeyNotFoundError(f"No keys found for participant: {participant_id}")
# Reconstruct private key
private_key = X25519PrivateKey.from_private_bytes(key_pair.private_key)
return private_key
async def get_audit_key(self) -> X25519PublicKey:
"""Get public audit key for escrow"""
if not self._audit_key or self._should_rotate_audit_key():
await self._rotate_audit_key()
return self._audit_key
async def get_audit_private_key(self, authorization: str) -> X25519PrivateKey:
"""Get private audit key with authorization"""
# Verify authorization
if not await self.verify_audit_authorization(authorization):
raise AccessDeniedError("Invalid audit authorization")
# Load audit key from secure storage
audit_key_data = await self.storage.get_audit_key()
if not audit_key_data:
raise KeyNotFoundError("Audit key not found")
return X25519PrivateKey.from_private_bytes(audit_key_data.private_key)
async def verify_audit_authorization(self, authorization: str) -> bool:
"""Verify audit authorization token"""
try:
# Decode authorization
auth_data = base64.b64decode(authorization).decode()
auth_json = json.loads(auth_data)
# Check expiration
expires_at = datetime.fromisoformat(auth_json["expires_at"])
if datetime.utcnow() > expires_at:
return False
# Verify signature (in production, use proper signature verification)
# For now, just check format
required_fields = ["issuer", "subject", "expires_at", "signature"]
return all(field in auth_json for field in required_fields)
except Exception as e:
logger.error(f"Failed to verify audit authorization: {e}")
return False
async def create_audit_authorization(
self,
issuer: str,
purpose: str,
expires_in_hours: int = 24
) -> str:
"""Create audit authorization token"""
try:
# Create authorization payload
payload = {
"issuer": issuer,
"subject": "audit_access",
"purpose": purpose,
"created_at": datetime.utcnow().isoformat(),
"expires_at": (datetime.utcnow() + timedelta(hours=expires_in_hours)).isoformat(),
"signature": "placeholder" # In production, sign with issuer key
}
# Encode and return
auth_json = json.dumps(payload)
return base64.b64encode(auth_json.encode()).decode()
except Exception as e:
logger.error(f"Failed to create audit authorization: {e}")
raise KeyManagementError(f"Authorization creation failed: {e}")
async def list_participants(self) -> List[str]:
"""List all participants with keys"""
return await self.storage.list_participants()
async def revoke_keys(self, participant_id: str, reason: str) -> bool:
"""Revoke participant's keys"""
try:
# Mark keys as revoked
success = await self.storage.revoke_keys(participant_id, reason)
if success:
# Clear cache
if participant_id in self._key_cache:
del self._key_cache[participant_id]
logger.info(f"Revoked keys for participant: {participant_id}")
return success
except Exception as e:
logger.error(f"Failed to revoke keys for {participant_id}: {e}")
return False
async def _rotate_audit_key(self):
"""Rotate the audit escrow key"""
try:
# Generate new audit key pair
audit_private = X25519PrivateKey.generate()
audit_public = audit_private.public_key()
# Store securely
audit_key_pair = KeyPair(
participant_id="audit",
private_key=audit_private.private_bytes_raw(),
public_key=audit_public.public_bytes_raw(),
algorithm="X25519",
created_at=datetime.utcnow(),
version=1
)
await self.storage.store_audit_key(audit_key_pair)
self._audit_key = audit_public
logger.info("Rotated audit escrow key")
except Exception as e:
logger.error(f"Failed to rotate audit key: {e}")
raise KeyManagementError(f"Audit key rotation failed: {e}")
def _should_rotate_audit_key(self) -> bool:
"""Check if audit key needs rotation"""
# In production, check last rotation time
return self._audit_key is None
async def _reencrypt_transactions(
self,
participant_id: str,
old_key_pair: KeyPair,
new_key_pair: KeyPair
):
"""Re-encrypt active transactions with new key"""
# This would be implemented in production
# For now, just log the action
logger.info(f"Would re-encrypt transactions for {participant_id}")
pass
class KeyStorageBackend:
"""Abstract base for key storage backends"""
async def store_key_pair(self, key_pair: KeyPair) -> bool:
"""Store key pair securely"""
raise NotImplementedError
async def get_key_pair(self, participant_id: str) -> Optional[KeyPair]:
"""Get key pair for participant"""
raise NotImplementedError
def get_key_pair_sync(self, participant_id: str) -> Optional[KeyPair]:
"""Synchronous get key pair"""
raise NotImplementedError
async def store_audit_key(self, key_pair: KeyPair) -> bool:
"""Store audit key pair"""
raise NotImplementedError
async def get_audit_key(self) -> Optional[KeyPair]:
"""Get audit key pair"""
raise NotImplementedError
async def list_participants(self) -> List[str]:
"""List all participants"""
raise NotImplementedError
async def revoke_keys(self, participant_id: str, reason: str) -> bool:
"""Revoke keys for participant"""
raise NotImplementedError
async def log_rotation(self, rotation_log: KeyRotationLog) -> bool:
"""Log key rotation"""
raise NotImplementedError
class FileKeyStorage(KeyStorageBackend):
"""File-based key storage for development"""
def __init__(self, storage_path: str):
self.storage_path = storage_path
os.makedirs(storage_path, exist_ok=True)
async def store_key_pair(self, key_pair: KeyPair) -> bool:
"""Store key pair to file"""
try:
file_path = os.path.join(self.storage_path, f"{key_pair.participant_id}.json")
# Store private key in separate encrypted file
private_path = os.path.join(self.storage_path, f"{key_pair.participant_id}.priv")
# In production, encrypt private key with master key
with open(private_path, "wb") as f:
f.write(key_pair.private_key)
# Store public metadata
metadata = {
"participant_id": key_pair.participant_id,
"public_key": base64.b64encode(key_pair.public_key).decode(),
"algorithm": key_pair.algorithm,
"created_at": key_pair.created_at.isoformat(),
"version": key_pair.version
}
with open(file_path, "w") as f:
json.dump(metadata, f)
return True
except Exception as e:
logger.error(f"Failed to store key pair: {e}")
return False
async def get_key_pair(self, participant_id: str) -> Optional[KeyPair]:
"""Get key pair from file"""
return self.get_key_pair_sync(participant_id)
def get_key_pair_sync(self, participant_id: str) -> Optional[KeyPair]:
"""Synchronous get key pair"""
try:
file_path = os.path.join(self.storage_path, f"{participant_id}.json")
private_path = os.path.join(self.storage_path, f"{participant_id}.priv")
if not os.path.exists(file_path) or not os.path.exists(private_path):
return None
# Load metadata
with open(file_path, "r") as f:
metadata = json.load(f)
# Load private key
with open(private_path, "rb") as f:
private_key = f.read()
return KeyPair(
participant_id=metadata["participant_id"],
private_key=private_key,
public_key=base64.b64decode(metadata["public_key"]),
algorithm=metadata["algorithm"],
created_at=datetime.fromisoformat(metadata["created_at"]),
version=metadata["version"]
)
except Exception as e:
logger.error(f"Failed to get key pair: {e}")
return None
async def store_audit_key(self, key_pair: KeyPair) -> bool:
"""Store audit key"""
audit_path = os.path.join(self.storage_path, "audit.json")
audit_priv_path = os.path.join(self.storage_path, "audit.priv")
try:
# Store private key
with open(audit_priv_path, "wb") as f:
f.write(key_pair.private_key)
# Store metadata
metadata = {
"participant_id": "audit",
"public_key": base64.b64encode(key_pair.public_key).decode(),
"algorithm": key_pair.algorithm,
"created_at": key_pair.created_at.isoformat(),
"version": key_pair.version
}
with open(audit_path, "w") as f:
json.dump(metadata, f)
return True
except Exception as e:
logger.error(f"Failed to store audit key: {e}")
return False
async def get_audit_key(self) -> Optional[KeyPair]:
"""Get audit key"""
return self.get_key_pair_sync("audit")
async def list_participants(self) -> List[str]:
"""List all participants"""
participants = []
for file in os.listdir(self.storage_path):
if file.endswith(".json") and file != "audit.json":
participant_id = file[:-5] # Remove .json
participants.append(participant_id)
return participants
async def revoke_keys(self, participant_id: str, reason: str) -> bool:
"""Revoke keys by deleting files"""
try:
file_path = os.path.join(self.storage_path, f"{participant_id}.json")
private_path = os.path.join(self.storage_path, f"{participant_id}.priv")
# Move to revoked folder instead of deleting
revoked_path = os.path.join(self.storage_path, "revoked")
os.makedirs(revoked_path, exist_ok=True)
if os.path.exists(file_path):
os.rename(file_path, os.path.join(revoked_path, f"{participant_id}.json"))
if os.path.exists(private_path):
os.rename(private_path, os.path.join(revoked_path, f"{participant_id}.priv"))
return True
except Exception as e:
logger.error(f"Failed to revoke keys: {e}")
return False
async def log_rotation(self, rotation_log: KeyRotationLog) -> bool:
"""Log key rotation"""
log_path = os.path.join(self.storage_path, "rotations.log")
try:
with open(log_path, "a") as f:
f.write(json.dumps({
"participant_id": rotation_log.participant_id,
"old_version": rotation_log.old_version,
"new_version": rotation_log.new_version,
"rotated_at": rotation_log.rotated_at.isoformat(),
"reason": rotation_log.reason
}) + "\n")
return True
except Exception as e:
logger.error(f"Failed to log rotation: {e}")
return False
class KeyManagementError(Exception):
"""Base exception for key management errors"""
pass
class KeyNotFoundError(KeyManagementError):
"""Raised when key is not found"""
pass
class AccessDeniedError(KeyManagementError):
"""Raised when access is denied"""
pass
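A hedged sketch of the key manager with the file-backed storage above (the directory path and participant ID are placeholders; suitable only for development, as the class docstrings note):
async def demo_key_manager(storage_dir="/tmp/aitbc-keys"):
    manager = KeyManager(FileKeyStorage(storage_dir))
    await manager.generate_key_pair("provider-7")
    # Public key lookups are served from the in-memory cache after generation.
    public_key = manager.get_public_key("provider-7")
    # Rotation generates a fresh pair and appends an entry to rotations.log.
    rotated = await manager.rotate_keys("provider-7")
    return public_key, rotated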

View File

@ -0,0 +1,526 @@
"""
Resource quota enforcement service for multi-tenant AITBC coordinator
"""
from datetime import datetime, timedelta
from typing import Dict, Any, Optional, List
from sqlalchemy.orm import Session
from sqlalchemy import select, update, and_, func
from contextlib import asynccontextmanager
import logging
import redis
import json
from ..models.multitenant import TenantQuota, UsageRecord, Tenant
from ..exceptions import QuotaExceededError, TenantError
from ..middleware.tenant_context import get_current_tenant_id
class QuotaEnforcementService:
"""Service for enforcing tenant resource quotas"""
def __init__(self, db: Session, redis_client: Optional[redis.Redis] = None):
self.db = db
self.redis = redis_client
self.logger = logging.getLogger(f"aitbc.{self.__class__.__name__}")
# Cache for quota lookups
self._quota_cache = {}
self._cache_ttl = 300 # 5 minutes
async def check_quota(
self,
resource_type: str,
quantity: float,
tenant_id: Optional[str] = None
) -> bool:
"""Check if tenant has sufficient quota for a resource"""
tenant_id = tenant_id or get_current_tenant_id()
if not tenant_id:
raise TenantError("No tenant context found")
# Get current quota and usage
quota = await self._get_current_quota(tenant_id, resource_type)
if not quota:
# No quota set, check if unlimited plan
tenant = await self._get_tenant(tenant_id)
if tenant and tenant.plan in ["enterprise", "unlimited"]:
return True
raise QuotaExceededError(f"No quota configured for {resource_type}")
# Check if adding quantity would exceed limit
current_usage = await self._get_current_usage(tenant_id, resource_type)
if current_usage + quantity > quota.limit_value:
# Log quota exceeded
self.logger.warning(
f"Quota exceeded for tenant {tenant_id}: "
f"{resource_type} {current_usage + quantity}/{quota.limit_value}"
)
raise QuotaExceededError(
f"Quota exceeded for {resource_type}: "
f"{current_usage + quantity}/{quota.limit_value}"
)
return True
async def consume_quota(
self,
resource_type: str,
quantity: float,
resource_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
tenant_id: Optional[str] = None
) -> UsageRecord:
"""Consume quota and record usage"""
tenant_id = tenant_id or get_current_tenant_id()
if not tenant_id:
raise TenantError("No tenant context found")
# Check quota first
await self.check_quota(resource_type, quantity, tenant_id)
# Create usage record
usage_record = UsageRecord(
tenant_id=tenant_id,
resource_type=resource_type,
resource_id=resource_id,
quantity=quantity,
unit=self._get_unit_for_resource(resource_type),
unit_price=await self._get_unit_price(resource_type),
total_cost=await self._calculate_cost(resource_type, quantity),
currency="USD",
usage_start=datetime.utcnow(),
usage_end=datetime.utcnow(),
metadata=metadata or {}
)
self.db.add(usage_record)
# Update quota usage
await self._update_quota_usage(tenant_id, resource_type, quantity)
# Update cache
cache_key = f"quota_usage:{tenant_id}:{resource_type}"
if self.redis:
current = self.redis.get(cache_key)
if current:
self.redis.incrbyfloat(cache_key, quantity)
self.redis.expire(cache_key, self._cache_ttl)
self.db.commit()
self.logger.info(
f"Consumed quota: tenant={tenant_id}, "
f"resource={resource_type}, quantity={quantity}"
)
return usage_record
async def release_quota(
self,
resource_type: str,
quantity: float,
usage_record_id: str,
tenant_id: Optional[str] = None
):
"""Release quota (e.g., when job completes early)"""
tenant_id = tenant_id or get_current_tenant_id()
if not tenant_id:
raise TenantError("No tenant context found")
# Update usage record
stmt = update(UsageRecord).where(
and_(
UsageRecord.id == usage_record_id,
UsageRecord.tenant_id == tenant_id
)
).values(
quantity=UsageRecord.quantity - quantity,
total_cost=UsageRecord.total_cost - await self._calculate_cost(resource_type, quantity)
)
result = self.db.execute(stmt)
if result.rowcount > 0:
# Update quota usage
await self._update_quota_usage(tenant_id, resource_type, -quantity)
# Update cache
cache_key = f"quota_usage:{tenant_id}:{resource_type}"
if self.redis:
current = self.redis.get(cache_key)
if current:
self.redis.incrbyfloat(cache_key, -quantity)
self.redis.expire(cache_key, self._cache_ttl)
self.db.commit()
self.logger.info(
f"Released quota: tenant={tenant_id}, "
f"resource={resource_type}, quantity={quantity}"
)
async def get_quota_status(
self,
resource_type: Optional[str] = None,
tenant_id: Optional[str] = None
) -> Dict[str, Any]:
"""Get current quota status for a tenant"""
tenant_id = tenant_id or get_current_tenant_id()
if not tenant_id:
raise TenantError("No tenant context found")
# Get all quotas for tenant
stmt = select(TenantQuota).where(
and_(
TenantQuota.tenant_id == tenant_id,
TenantQuota.is_active == True
)
)
if resource_type:
stmt = stmt.where(TenantQuota.resource_type == resource_type)
quotas = self.db.execute(stmt).scalars().all()
status = {
"tenant_id": tenant_id,
"quotas": {},
"summary": {
"total_resources": len(quotas),
"over_limit": 0,
"near_limit": 0
}
}
for quota in quotas:
current_usage = await self._get_current_usage(tenant_id, quota.resource_type)
usage_percent = (current_usage / quota.limit_value) * 100 if quota.limit_value > 0 else 0
quota_status = {
"limit": float(quota.limit_value),
"used": float(current_usage),
"remaining": float(quota.limit_value - current_usage),
"usage_percent": round(usage_percent, 2),
"period": quota.period_type,
"period_start": quota.period_start.isoformat(),
"period_end": quota.period_end.isoformat()
}
status["quotas"][quota.resource_type] = quota_status
# Update summary
if usage_percent >= 100:
status["summary"]["over_limit"] += 1
elif usage_percent >= 80:
status["summary"]["near_limit"] += 1
return status
@asynccontextmanager
async def quota_reservation(
self,
resource_type: str,
quantity: float,
timeout: int = 300, # 5 minutes
tenant_id: Optional[str] = None
):
"""Context manager for temporary quota reservation"""
tenant_id = tenant_id or get_current_tenant_id()
reservation_id = f"reserve:{tenant_id}:{resource_type}:{datetime.utcnow().timestamp()}"
try:
# Reserve quota
await self.check_quota(resource_type, quantity, tenant_id)
# Store reservation in Redis
if self.redis:
reservation_data = {
"tenant_id": tenant_id,
"resource_type": resource_type,
"quantity": quantity,
"created_at": datetime.utcnow().isoformat()
}
self.redis.setex(
f"reservation:{reservation_id}",
timeout,
json.dumps(reservation_data)
)
yield reservation_id
finally:
# Clean up reservation
if self.redis:
self.redis.delete(f"reservation:{reservation_id}")
async def reset_quota_period(self, tenant_id: str, resource_type: str):
"""Reset quota for a new period"""
# Get current quota
stmt = select(TenantQuota).where(
and_(
TenantQuota.tenant_id == tenant_id,
TenantQuota.resource_type == resource_type,
TenantQuota.is_active == True
)
)
quota = self.db.execute(stmt).scalar_one_or_none()
if not quota:
return
# Calculate new period
now = datetime.utcnow()
if quota.period_type == "monthly":
period_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
period_end = (period_start + timedelta(days=32)).replace(day=1) - timedelta(days=1)
elif quota.period_type == "weekly":
days_since_monday = now.weekday()
period_start = (now - timedelta(days=days_since_monday)).replace(
hour=0, minute=0, second=0, microsecond=0
)
period_end = period_start + timedelta(days=6)
else: # daily
period_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
period_end = period_start + timedelta(days=1)
# Update quota
quota.period_start = period_start
quota.period_end = period_end
quota.used_value = 0
self.db.commit()
# Clear cache
cache_key = f"quota_usage:{tenant_id}:{resource_type}"
if self.redis:
self.redis.delete(cache_key)
self.logger.info(
f"Reset quota period: tenant={tenant_id}, "
f"resource={resource_type}, period={quota.period_type}"
)
async def get_quota_alerts(self, tenant_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""Get quota alerts for tenants approaching or exceeding limits"""
tenant_id = tenant_id or get_current_tenant_id()
if not tenant_id:
raise TenantError("No tenant context found")
alerts = []
status = await self.get_quota_status(tenant_id=tenant_id)
for resource_type, quota_status in status["quotas"].items():
usage_percent = quota_status["usage_percent"]
if usage_percent >= 100:
alerts.append({
"severity": "critical",
"resource_type": resource_type,
"message": f"Quota exceeded for {resource_type}",
"usage_percent": usage_percent,
"used": quota_status["used"],
"limit": quota_status["limit"]
})
elif usage_percent >= 90:
alerts.append({
"severity": "warning",
"resource_type": resource_type,
"message": f"Quota almost exceeded for {resource_type}",
"usage_percent": usage_percent,
"used": quota_status["used"],
"limit": quota_status["limit"]
})
elif usage_percent >= 80:
alerts.append({
"severity": "info",
"resource_type": resource_type,
"message": f"Quota usage high for {resource_type}",
"usage_percent": usage_percent,
"used": quota_status["used"],
"limit": quota_status["limit"]
})
return alerts
# Private methods
async def _get_current_quota(self, tenant_id: str, resource_type: str) -> Optional[TenantQuota]:
"""Get current quota for tenant and resource type"""
cache_key = f"quota:{tenant_id}:{resource_type}"
# Check cache first
if self.redis:
cached = self.redis.get(cache_key)
if cached:
quota_data = json.loads(cached)
# Restore datetime fields before rebuilding the model and comparing
quota_data["period_start"] = datetime.fromisoformat(quota_data["period_start"])
quota_data["period_end"] = datetime.fromisoformat(quota_data["period_end"])
quota = TenantQuota(**quota_data)
# Check if the cached quota still covers the current period
if quota.period_end >= datetime.utcnow():
return quota
# Query database
stmt = select(TenantQuota).where(
and_(
TenantQuota.tenant_id == tenant_id,
TenantQuota.resource_type == resource_type,
TenantQuota.is_active == True,
TenantQuota.period_start <= datetime.utcnow(),
TenantQuota.period_end >= datetime.utcnow()
)
)
quota = self.db.execute(stmt).scalar_one_or_none()
# Cache result
if quota and self.redis:
quota_data = {
"id": str(quota.id),
"tenant_id": str(quota.tenant_id),
"resource_type": quota.resource_type,
"limit_value": float(quota.limit_value),
"used_value": float(quota.used_value),
"period_start": quota.period_start.isoformat(),
"period_end": quota.period_end.isoformat()
}
self.redis.setex(
cache_key,
self._cache_ttl,
json.dumps(quota_data)
)
return quota
async def _get_current_usage(self, tenant_id: str, resource_type: str) -> float:
"""Get current usage for tenant and resource type"""
cache_key = f"quota_usage:{tenant_id}:{resource_type}"
# Check cache first
if self.redis:
cached = self.redis.get(cache_key)
if cached:
return float(cached)
# Query database
stmt = select(func.sum(UsageRecord.quantity)).where(
and_(
UsageRecord.tenant_id == tenant_id,
UsageRecord.resource_type == resource_type,
UsageRecord.usage_start >= func.date_trunc('month', func.current_date())
)
)
result = self.db.execute(stmt).scalar()
usage = float(result) if result else 0.0
# Cache result
if self.redis:
self.redis.setex(cache_key, self._cache_ttl, str(usage))
return usage
async def _update_quota_usage(self, tenant_id: str, resource_type: str, quantity: float):
"""Update quota usage in database"""
stmt = update(TenantQuota).where(
and_(
TenantQuota.tenant_id == tenant_id,
TenantQuota.resource_type == resource_type,
TenantQuota.is_active == True
)
).values(
used_value=TenantQuota.used_value + quantity
)
self.db.execute(stmt)
async def _get_tenant(self, tenant_id: str) -> Optional[Tenant]:
"""Get tenant by ID"""
stmt = select(Tenant).where(Tenant.id == tenant_id)
return self.db.execute(stmt).scalar_one_or_none()
def _get_unit_for_resource(self, resource_type: str) -> str:
"""Get unit for resource type"""
unit_map = {
"gpu_hours": "hours",
"storage_gb": "gb",
"api_calls": "calls",
"bandwidth_gb": "gb",
"compute_hours": "hours"
}
return unit_map.get(resource_type, "units")
async def _get_unit_price(self, resource_type: str) -> float:
"""Get unit price for resource type"""
# In a real implementation, this would come from a pricing table
price_map = {
"gpu_hours": 0.50, # $0.50 per hour
"storage_gb": 0.02, # $0.02 per GB per month
"api_calls": 0.0001, # $0.0001 per call
"bandwidth_gb": 0.01, # $0.01 per GB
"compute_hours": 0.30 # $0.30 per hour
}
return price_map.get(resource_type, 0.0)
async def _calculate_cost(self, resource_type: str, quantity: float) -> float:
"""Calculate cost for resource usage"""
unit_price = await self._get_unit_price(resource_type)
return unit_price * quantity
class QuotaMiddleware:
"""Middleware to enforce quotas on API endpoints"""
def __init__(self, quota_service: QuotaEnforcementService):
self.quota_service = quota_service
self.logger = logging.getLogger(f"aitbc.{self.__class__.__name__}")
# Resource costs per endpoint
self.endpoint_costs = {
"/api/v1/jobs": {"resource": "compute_hours", "cost": 0.1},
"/api/v1/models": {"resource": "storage_gb", "cost": 0.1},
"/api/v1/data": {"resource": "storage_gb", "cost": 0.05},
"/api/v1/analytics": {"resource": "api_calls", "cost": 1}
}
async def check_endpoint_quota(self, endpoint: str, estimated_cost: float = 0):
"""Check if endpoint call is within quota"""
resource_config = self.endpoint_costs.get(endpoint)
if not resource_config:
return # No quota check for this endpoint
try:
await self.quota_service.check_quota(
resource_config["resource"],
resource_config["cost"] + estimated_cost
)
except QuotaExceededError as e:
self.logger.warning(f"Quota exceeded for endpoint {endpoint}: {e}")
raise
async def consume_endpoint_quota(self, endpoint: str, actual_cost: float = 0):
"""Consume quota after endpoint execution"""
resource_config = self.endpoint_costs.get(endpoint)
if not resource_config:
return
try:
await self.quota_service.consume_quota(
resource_config["resource"],
resource_config["cost"] + actual_cost
)
except Exception as e:
self.logger.error(f"Failed to consume quota for {endpoint}: {e}")
# Don't fail the request, just log the error
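A compact sketch of the intended call pattern around a job submission. The session, Redis client, and tenant ID are placeholders wired up by the application, and it assumes a gpu_hours quota has been configured for the tenant:
async def demo_quota_flow(db_session, redis_client):
    quota_service = QuotaEnforcementService(db_session, redis_client)
    # Hold a temporary reservation while scheduling, then record actual usage.
    async with quota_service.quota_reservation("gpu_hours", 2.0, tenant_id="tenant-a"):
        record = await quota_service.consume_quota(
            "gpu_hours",
            2.0,
            resource_id="job-123",
            metadata={"queue": "default"},
            tenant_id="tenant-a",
        )
    status = await quota_service.get_quota_status(tenant_id="tenant-a")
    return record, status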

View File

@ -10,6 +10,7 @@ from sqlmodel import Session
from ..config import settings
from ..domain import Job, JobReceipt
from .zk_proofs import zk_proof_service
class ReceiptService:
@ -24,12 +25,13 @@ class ReceiptService:
attest_bytes = bytes.fromhex(settings.receipt_attestation_key_hex)
self._attestation_signer = ReceiptSigner(attest_bytes)
def create_receipt(
async def create_receipt(
self,
job: Job,
miner_id: str,
job_result: Dict[str, Any] | None,
result_metrics: Dict[str, Any] | None,
privacy_level: Optional[str] = None,
) -> Dict[str, Any] | None:
if self._signer is None:
return None
@ -67,6 +69,32 @@ class ReceiptService:
attestation_payload.pop("attestations", None)
attestation_payload.pop("signature", None)
payload["attestations"].append(self._attestation_signer.sign(attestation_payload))
# Generate ZK proof if privacy is requested
if privacy_level and zk_proof_service.is_enabled():
try:
# Create receipt model for ZK proof generation
receipt_model = JobReceipt(
job_id=job.id,
receipt_id=payload["receipt_id"],
payload=payload
)
# Generate ZK proof
zk_proof = await zk_proof_service.generate_receipt_proof(
receipt=receipt_model,
job_result=job_result or {},
privacy_level=privacy_level
)
if zk_proof:
payload["zk_proof"] = zk_proof
payload["privacy_level"] = privacy_level
except Exception as e:
# Log error but don't fail receipt creation
print(f"Failed to generate ZK proof: {e}")
receipt_row = JobReceipt(job_id=job.id, receipt_id=payload["receipt_id"], payload=payload)
self.session.add(receipt_row)
return payload
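Since create_receipt is now a coroutine, call sites need to await it; a hedged sketch of the updated usage (the job, result objects, and miner identifier are placeholders):
async def issue_receipt(receipt_service, job, job_result, result_metrics):
    # privacy_level triggers optional ZK proof generation when the proof service is enabled.
    return await receipt_service.create_receipt(
        job=job,
        miner_id="miner-9",
        job_result=job_result,
        result_metrics=result_metrics,
        privacy_level="enhanced",
    )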

View File

@ -0,0 +1,690 @@
"""
Tenant management service for multi-tenant AITBC coordinator
"""
import logging
import secrets
import hashlib
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List
from sqlalchemy.orm import Session
from sqlalchemy import select, update, delete, and_, or_, func
from ..models.multitenant import (
Tenant, TenantUser, TenantQuota, TenantApiKey,
TenantAuditLog, TenantStatus
)
from ..database import get_db
from ..exceptions import TenantError, QuotaExceededError
class TenantManagementService:
"""Service for managing tenants in multi-tenant environment"""
def __init__(self, db: Session):
self.db = db
self.logger = __import__('logging').getLogger(f"aitbc.{self.__class__.__name__}")
async def create_tenant(
self,
name: str,
contact_email: str,
plan: str = "trial",
domain: Optional[str] = None,
settings: Optional[Dict[str, Any]] = None,
features: Optional[Dict[str, Any]] = None
) -> Tenant:
"""Create a new tenant"""
# Generate unique slug
slug = self._generate_slug(name)
if await self._tenant_exists(slug=slug):
raise TenantError(f"Tenant with slug '{slug}' already exists")
# Check domain uniqueness if provided
if domain and await self._tenant_exists(domain=domain):
raise TenantError(f"Domain '{domain}' is already in use")
# Create tenant
tenant = Tenant(
name=name,
slug=slug,
domain=domain,
contact_email=contact_email,
plan=plan,
status=TenantStatus.PENDING.value,
settings=settings or {},
features=features or {}
)
self.db.add(tenant)
self.db.flush()
# Create default quotas
await self._create_default_quotas(tenant.id, plan)
# Log creation
await self._log_audit_event(
tenant_id=tenant.id,
event_type="tenant_created",
event_category="lifecycle",
actor_id="system",
actor_type="system",
resource_type="tenant",
resource_id=str(tenant.id),
new_values={"name": name, "plan": plan}
)
self.db.commit()
self.logger.info(f"Created tenant: {tenant.id} ({name})")
return tenant
async def get_tenant(self, tenant_id: str) -> Optional[Tenant]:
"""Get tenant by ID"""
stmt = select(Tenant).where(Tenant.id == tenant_id)
return self.db.execute(stmt).scalar_one_or_none()
async def get_tenant_by_slug(self, slug: str) -> Optional[Tenant]:
"""Get tenant by slug"""
stmt = select(Tenant).where(Tenant.slug == slug)
return self.db.execute(stmt).scalar_one_or_none()
async def get_tenant_by_domain(self, domain: str) -> Optional[Tenant]:
"""Get tenant by domain"""
stmt = select(Tenant).where(Tenant.domain == domain)
return self.db.execute(stmt).scalar_one_or_none()
async def update_tenant(
self,
tenant_id: str,
updates: Dict[str, Any],
actor_id: str,
actor_type: str = "user"
) -> Tenant:
"""Update tenant information"""
tenant = await self.get_tenant(tenant_id)
if not tenant:
raise TenantError(f"Tenant not found: {tenant_id}")
# Store old values for audit
old_values = {
"name": tenant.name,
"contact_email": tenant.contact_email,
"billing_email": tenant.billing_email,
"settings": tenant.settings,
"features": tenant.features
}
# Apply updates
for key, value in updates.items():
if hasattr(tenant, key):
setattr(tenant, key, value)
tenant.updated_at = datetime.utcnow()
# Log update
await self._log_audit_event(
tenant_id=tenant.id,
event_type="tenant_updated",
event_category="lifecycle",
actor_id=actor_id,
actor_type=actor_type,
resource_type="tenant",
resource_id=str(tenant.id),
old_values=old_values,
new_values=updates
)
self.db.commit()
self.logger.info(f"Updated tenant: {tenant_id}")
return tenant
async def activate_tenant(
self,
tenant_id: str,
actor_id: str,
actor_type: str = "user"
) -> Tenant:
"""Activate a tenant"""
tenant = await self.get_tenant(tenant_id)
if not tenant:
raise TenantError(f"Tenant not found: {tenant_id}")
if tenant.status == TenantStatus.ACTIVE.value:
return tenant
old_status = tenant.status
tenant.status = TenantStatus.ACTIVE.value
tenant.activated_at = datetime.utcnow()
tenant.updated_at = datetime.utcnow()
# Log activation
await self._log_audit_event(
tenant_id=tenant.id,
event_type="tenant_activated",
event_category="lifecycle",
actor_id=actor_id,
actor_type=actor_type,
resource_type="tenant",
resource_id=str(tenant.id),
old_values={"status": "pending"},
new_values={"status": "active"}
)
self.db.commit()
self.logger.info(f"Activated tenant: {tenant_id}")
return tenant
async def deactivate_tenant(
self,
tenant_id: str,
reason: Optional[str] = None,
actor_id: str = "system",
actor_type: str = "system"
) -> Tenant:
"""Deactivate a tenant"""
tenant = await self.get_tenant(tenant_id)
if not tenant:
raise TenantError(f"Tenant not found: {tenant_id}")
if tenant.status == TenantStatus.INACTIVE.value:
return tenant
old_status = tenant.status
tenant.status = TenantStatus.INACTIVE.value
tenant.deactivated_at = datetime.utcnow()
tenant.updated_at = datetime.utcnow()
# Revoke all API keys
await self._revoke_all_api_keys(tenant_id)
# Log deactivation
await self._log_audit_event(
tenant_id=tenant.id,
event_type="tenant_deactivated",
event_category="lifecycle",
actor_id=actor_id,
actor_type=actor_type,
resource_type="tenant",
resource_id=str(tenant.id),
old_values={"status": old_status},
new_values={"status": "inactive", "reason": reason}
)
self.db.commit()
self.logger.info(f"Deactivated tenant: {tenant_id} (reason: {reason})")
return tenant
async def suspend_tenant(
self,
tenant_id: str,
reason: Optional[str] = None,
actor_id: str = "system",
actor_type: str = "system"
) -> Tenant:
"""Suspend a tenant temporarily"""
tenant = await self.get_tenant(tenant_id)
if not tenant:
raise TenantError(f"Tenant not found: {tenant_id}")
old_status = tenant.status
tenant.status = TenantStatus.SUSPENDED.value
tenant.updated_at = datetime.utcnow()
# Log suspension
await self._log_audit_event(
tenant_id=tenant.id,
event_type="tenant_suspended",
event_category="lifecycle",
actor_id=actor_id,
actor_type=actor_type,
resource_type="tenant",
resource_id=str(tenant.id),
old_values={"status": old_status},
new_values={"status": "suspended", "reason": reason}
)
self.db.commit()
self.logger.warning(f"Suspended tenant: {tenant_id} (reason: {reason})")
return tenant
async def add_user_to_tenant(
self,
tenant_id: str,
user_id: str,
role: str = "member",
permissions: Optional[List[str]] = None,
actor_id: str = "system"
) -> TenantUser:
"""Add a user to a tenant"""
# Check if user already exists
stmt = select(TenantUser).where(
and_(TenantUser.tenant_id == tenant_id, TenantUser.user_id == user_id)
)
existing = self.db.execute(stmt).scalar_one_or_none()
if existing:
raise TenantError(f"User {user_id} already belongs to tenant {tenant_id}")
# Create tenant user
tenant_user = TenantUser(
tenant_id=tenant_id,
user_id=user_id,
role=role,
permissions=permissions or [],
joined_at=datetime.utcnow()
)
self.db.add(tenant_user)
# Log addition
await self._log_audit_event(
tenant_id=tenant_id,
event_type="user_added",
event_category="access",
actor_id=actor_id,
actor_type="system",
resource_type="tenant_user",
resource_id=str(tenant_user.id),
new_values={"user_id": user_id, "role": role}
)
self.db.commit()
self.logger.info(f"Added user {user_id} to tenant {tenant_id}")
return tenant_user
async def remove_user_from_tenant(
self,
tenant_id: str,
user_id: str,
actor_id: str = "system"
) -> bool:
"""Remove a user from a tenant"""
stmt = select(TenantUser).where(
and_(TenantUser.tenant_id == tenant_id, TenantUser.user_id == user_id)
)
tenant_user = self.db.execute(stmt).scalar_one_or_none()
if not tenant_user:
return False
# Store for audit
old_values = {
"user_id": user_id,
"role": tenant_user.role,
"permissions": tenant_user.permissions
}
self.db.delete(tenant_user)
# Log removal
await self._log_audit_event(
tenant_id=tenant_id,
event_type="user_removed",
event_category="access",
actor_id=actor_id,
actor_type="system",
resource_type="tenant_user",
resource_id=str(tenant_user.id),
old_values=old_values
)
self.db.commit()
self.logger.info(f"Removed user {user_id} from tenant {tenant_id}")
return True
async def create_api_key(
self,
tenant_id: str,
name: str,
permissions: Optional[List[str]] = None,
rate_limit: Optional[int] = None,
allowed_ips: Optional[List[str]] = None,
expires_at: Optional[datetime] = None,
created_by: str = "system"
) -> TenantApiKey:
"""Create a new API key for a tenant"""
# Generate secure key
key_id = f"ak_{secrets.token_urlsafe(16)}"
api_key = f"ask_{secrets.token_urlsafe(32)}"
key_hash = hashlib.sha256(api_key.encode()).hexdigest()
key_prefix = api_key[:8]
# Create API key record
api_key_record = TenantApiKey(
tenant_id=tenant_id,
key_id=key_id,
key_hash=key_hash,
key_prefix=key_prefix,
name=name,
permissions=permissions or [],
rate_limit=rate_limit,
allowed_ips=allowed_ips,
expires_at=expires_at,
created_by=created_by
)
self.db.add(api_key_record)
self.db.flush()
# Log creation
await self._log_audit_event(
tenant_id=tenant_id,
event_type="api_key_created",
event_category="security",
actor_id=created_by,
actor_type="user",
resource_type="api_key",
resource_id=str(api_key_record.id),
new_values={
"key_id": key_id,
"name": name,
"permissions": permissions,
"rate_limit": rate_limit
}
)
self.db.commit()
self.logger.info(f"Created API key {key_id} for tenant {tenant_id}")
# Return the key (only time it's shown)
api_key_record.api_key = api_key
return api_key_record
async def revoke_api_key(
self,
tenant_id: str,
key_id: str,
actor_id: str = "system"
) -> bool:
"""Revoke an API key"""
stmt = select(TenantApiKey).where(
and_(
TenantApiKey.tenant_id == tenant_id,
TenantApiKey.key_id == key_id,
TenantApiKey.is_active == True
)
)
api_key = self.db.execute(stmt).scalar_one_or_none()
if not api_key:
return False
api_key.is_active = False
api_key.revoked_at = datetime.utcnow()
# Log revocation
await self._log_audit_event(
tenant_id=tenant_id,
event_type="api_key_revoked",
event_category="security",
actor_id=actor_id,
actor_type="user",
resource_type="api_key",
resource_id=str(api_key.id),
old_values={"key_id": key_id, "is_active": True}
)
self.db.commit()
self.logger.info(f"Revoked API key {key_id} for tenant {tenant_id}")
return True
async def get_tenant_usage(
self,
tenant_id: str,
resource_type: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None
) -> Dict[str, Any]:
"""Get usage statistics for a tenant"""
from ..models.multitenant import UsageRecord
# Default to last 30 days
if not end_date:
end_date = datetime.utcnow()
if not start_date:
start_date = end_date - timedelta(days=30)
# Build query
stmt = select(
UsageRecord.resource_type,
func.sum(UsageRecord.quantity).label("total_quantity"),
func.sum(UsageRecord.total_cost).label("total_cost"),
func.count(UsageRecord.id).label("record_count")
).where(
and_(
UsageRecord.tenant_id == tenant_id,
UsageRecord.usage_start >= start_date,
UsageRecord.usage_end <= end_date
)
)
if resource_type:
stmt = stmt.where(UsageRecord.resource_type == resource_type)
stmt = stmt.group_by(UsageRecord.resource_type)
results = self.db.execute(stmt).all()
# Format results
usage = {
"period": {
"start": start_date.isoformat(),
"end": end_date.isoformat()
},
"by_resource": {}
}
for result in results:
usage["by_resource"][result.resource_type] = {
"quantity": float(result.total_quantity),
"cost": float(result.total_cost),
"records": result.record_count
}
return usage
async def get_tenant_quotas(self, tenant_id: str) -> List[TenantQuota]:
"""Get all quotas for a tenant"""
stmt = select(TenantQuota).where(
and_(
TenantQuota.tenant_id == tenant_id,
TenantQuota.is_active == True
)
)
return self.db.execute(stmt).scalars().all()
async def check_quota(
self,
tenant_id: str,
resource_type: str,
quantity: float
) -> bool:
"""Check if tenant has sufficient quota for a resource"""
# Get current quota
stmt = select(TenantQuota).where(
and_(
TenantQuota.tenant_id == tenant_id,
TenantQuota.resource_type == resource_type,
TenantQuota.is_active == True,
TenantQuota.period_start <= datetime.utcnow(),
TenantQuota.period_end >= datetime.utcnow()
)
)
quota = self.db.execute(stmt).scalar_one_or_none()
if not quota:
# No quota set, deny by default
return False
# Check if usage + quantity exceeds limit
if quota.used_value + quantity > quota.limit_value:
raise QuotaExceededError(
f"Quota exceeded for {resource_type}: "
f"{quota.used_value + quantity}/{quota.limit_value}"
)
return True
async def update_quota_usage(
self,
tenant_id: str,
resource_type: str,
quantity: float
):
"""Update quota usage for a tenant"""
# Get current quota
stmt = select(TenantQuota).where(
and_(
TenantQuota.tenant_id == tenant_id,
TenantQuota.resource_type == resource_type,
TenantQuota.is_active == True,
TenantQuota.period_start <= datetime.utcnow(),
TenantQuota.period_end >= datetime.utcnow()
)
)
quota = self.db.execute(stmt).scalar_one_or_none()
if quota:
quota.used_value += quantity
self.db.commit()
# Private methods
def _generate_slug(self, name: str) -> str:
"""Generate a unique slug from name"""
import re
# Lowercase the name and collapse runs of non-alphanumeric characters into hyphens
base = re.sub(r'[^a-z0-9]+', '-', name.lower()).strip('-')
# Add random suffix for uniqueness
suffix = secrets.token_urlsafe(4)
return f"{base}-{suffix}"
async def _tenant_exists(self, slug: Optional[str] = None, domain: Optional[str] = None) -> bool:
"""Check if tenant exists by slug or domain"""
conditions = []
if slug:
conditions.append(Tenant.slug == slug)
if domain:
conditions.append(Tenant.domain == domain)
if not conditions:
return False
stmt = select(func.count(Tenant.id)).where(or_(*conditions))
count = self.db.execute(stmt).scalar()
return count > 0
async def _create_default_quotas(self, tenant_id: str, plan: str):
"""Create default quotas based on plan"""
# Define quota templates by plan
quota_templates = {
"trial": {
"gpu_hours": {"limit": 100, "period": "monthly"},
"storage_gb": {"limit": 10, "period": "monthly"},
"api_calls": {"limit": 10000, "period": "monthly"}
},
"basic": {
"gpu_hours": {"limit": 500, "period": "monthly"},
"storage_gb": {"limit": 100, "period": "monthly"},
"api_calls": {"limit": 100000, "period": "monthly"}
},
"pro": {
"gpu_hours": {"limit": 2000, "period": "monthly"},
"storage_gb": {"limit": 1000, "period": "monthly"},
"api_calls": {"limit": 1000000, "period": "monthly"}
},
"enterprise": {
"gpu_hours": {"limit": 10000, "period": "monthly"},
"storage_gb": {"limit": 10000, "period": "monthly"},
"api_calls": {"limit": 10000000, "period": "monthly"}
}
}
quotas = quota_templates.get(plan, quota_templates["trial"])
# Create quota records
now = datetime.utcnow()
period_end = now.replace(day=1) + timedelta(days=32) # Next month
period_end = period_end.replace(day=1) - timedelta(days=1) # Last day of current month
for resource_type, config in quotas.items():
quota = TenantQuota(
tenant_id=tenant_id,
resource_type=resource_type,
limit_value=config["limit"],
used_value=0,
period_type=config["period"],
period_start=now,
period_end=period_end
)
self.db.add(quota)
async def _revoke_all_api_keys(self, tenant_id: str):
"""Revoke all API keys for a tenant"""
stmt = update(TenantApiKey).where(
and_(
TenantApiKey.tenant_id == tenant_id,
TenantApiKey.is_active == True
)
).values(
is_active=False,
revoked_at=datetime.utcnow()
)
self.db.execute(stmt)
async def _log_audit_event(
self,
tenant_id: str,
event_type: str,
event_category: str,
actor_id: str,
actor_type: str,
resource_type: str,
resource_id: Optional[str] = None,
old_values: Optional[Dict[str, Any]] = None,
new_values: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, Any]] = None
):
"""Log an audit event"""
audit_log = TenantAuditLog(
tenant_id=tenant_id,
event_type=event_type,
event_category=event_category,
actor_id=actor_id,
actor_type=actor_type,
resource_type=resource_type,
resource_id=resource_id,
old_values=old_values,
new_values=new_values,
metadata=metadata
)
self.db.add(audit_log)
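To show how these pieces are meant to fit together, here is a hedged onboarding sketch built only from the methods above; the SessionLocal factory and the actor/user identifiers are assumptions rather than part of this commit.

# Illustrative onboarding flow (assumed: `SessionLocal` session factory).
async def onboard_tenant_example():
    db = SessionLocal()
    svc = TenantManagementService(db)
    tenant = await svc.create_tenant(
        name="Acme Labs",
        contact_email="ops@acme.example",
        plan="basic",  # seeds the "basic" quota template defined above
    )
    await svc.activate_tenant(tenant.id, actor_id="admin-1")
    await svc.add_user_to_tenant(tenant.id, user_id="user-42", role="admin")
    api_key = await svc.create_api_key(tenant.id, name="ci-pipeline", rate_limit=1000)
    return api_key.api_key  # plaintext key is only available at creation time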

View File

@@ -0,0 +1,654 @@
"""
Usage tracking and billing metrics service for multi-tenant AITBC coordinator
"""
from datetime import datetime, timedelta
from typing import Dict, Any, Optional, List, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import select, update, and_, or_, func, desc
from dataclasses import dataclass, asdict
from decimal import Decimal
import asyncio
from concurrent.futures import ThreadPoolExecutor
from ..models.multitenant import (
UsageRecord, Invoice, Tenant, TenantQuota,
TenantMetric
)
from ..exceptions import BillingError, TenantError
from ..middleware.tenant_context import get_current_tenant_id
@dataclass
class UsageSummary:
"""Usage summary for billing period"""
tenant_id: str
period_start: datetime
period_end: datetime
resources: Dict[str, Dict[str, Any]]
total_cost: Decimal
currency: str
@dataclass
class BillingEvent:
"""Billing event for processing"""
tenant_id: str
event_type: str # usage, quota_adjustment, credit, charge
resource_type: Optional[str]
quantity: Decimal
unit_price: Decimal
total_amount: Decimal
currency: str
timestamp: datetime
metadata: Dict[str, Any]
class UsageTrackingService:
"""Service for tracking usage and generating billing metrics"""
def __init__(self, db: Session):
self.db = db
self.logger = __import__('logging').getLogger(f"aitbc.{self.__class__.__name__}")
self.executor = ThreadPoolExecutor(max_workers=4)
# Pricing configuration
self.pricing_config = {
"gpu_hours": {"unit_price": Decimal("0.50"), "tiered": True},
"storage_gb": {"unit_price": Decimal("0.02"), "tiered": True},
"api_calls": {"unit_price": Decimal("0.0001"), "tiered": False},
"bandwidth_gb": {"unit_price": Decimal("0.01"), "tiered": False},
"compute_hours": {"unit_price": Decimal("0.30"), "tiered": True}
}
# Tier pricing thresholds
self.tier_thresholds = {
"gpu_hours": [
{"min": 0, "max": 100, "multiplier": 1.0},
{"min": 101, "max": 500, "multiplier": 0.9},
{"min": 501, "max": 2000, "multiplier": 0.8},
{"min": 2001, "max": None, "multiplier": 0.7}
],
"storage_gb": [
{"min": 0, "max": 100, "multiplier": 1.0},
{"min": 101, "max": 1000, "multiplier": 0.85},
{"min": 1001, "max": 10000, "multiplier": 0.75},
{"min": 10001, "max": None, "multiplier": 0.65}
],
"compute_hours": [
{"min": 0, "max": 200, "multiplier": 1.0},
{"min": 201, "max": 1000, "multiplier": 0.9},
{"min": 1001, "max": 5000, "multiplier": 0.8},
{"min": 5001, "max": None, "multiplier": 0.7}
]
}
async def record_usage(
self,
tenant_id: str,
resource_type: str,
quantity: Decimal,
unit_price: Optional[Decimal] = None,
job_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None
) -> UsageRecord:
"""Record usage for billing"""
# Calculate unit price if not provided
if not unit_price:
unit_price = await self._calculate_unit_price(resource_type, quantity)
# Calculate total cost
total_cost = unit_price * quantity
# Create usage record
usage_record = UsageRecord(
tenant_id=tenant_id,
resource_type=resource_type,
quantity=quantity,
unit=self._get_unit_for_resource(resource_type),
unit_price=unit_price,
total_cost=total_cost,
currency="USD",
usage_start=datetime.utcnow(),
usage_end=datetime.utcnow(),
job_id=job_id,
metadata=metadata or {}
)
self.db.add(usage_record)
self.db.commit()
# Emit billing event
await self._emit_billing_event(BillingEvent(
tenant_id=tenant_id,
event_type="usage",
resource_type=resource_type,
quantity=quantity,
unit_price=unit_price,
total_amount=total_cost,
currency="USD",
timestamp=datetime.utcnow(),
metadata=metadata or {}
))
self.logger.info(
f"Recorded usage: tenant={tenant_id}, "
f"resource={resource_type}, quantity={quantity}, cost={total_cost}"
)
return usage_record
async def get_usage_summary(
self,
tenant_id: str,
start_date: datetime,
end_date: datetime,
resource_type: Optional[str] = None
) -> UsageSummary:
"""Get usage summary for a billing period"""
# Build query
stmt = select(
UsageRecord.resource_type,
func.sum(UsageRecord.quantity).label("total_quantity"),
func.sum(UsageRecord.total_cost).label("total_cost"),
func.count(UsageRecord.id).label("record_count"),
func.avg(UsageRecord.unit_price).label("avg_unit_price")
).where(
and_(
UsageRecord.tenant_id == tenant_id,
UsageRecord.usage_start >= start_date,
UsageRecord.usage_end <= end_date
)
)
if resource_type:
stmt = stmt.where(UsageRecord.resource_type == resource_type)
stmt = stmt.group_by(UsageRecord.resource_type)
results = self.db.execute(stmt).all()
# Build summary
resources = {}
total_cost = Decimal("0")
for result in results:
resources[result.resource_type] = {
"quantity": float(result.total_quantity),
"cost": float(result.total_cost),
"records": result.record_count,
"avg_unit_price": float(result.avg_unit_price)
}
total_cost += Decimal(str(result.total_cost))
return UsageSummary(
tenant_id=tenant_id,
period_start=start_date,
period_end=end_date,
resources=resources,
total_cost=total_cost,
currency="USD"
)
async def generate_invoice(
self,
tenant_id: str,
period_start: datetime,
period_end: datetime,
due_days: int = 30
) -> Invoice:
"""Generate invoice for billing period"""
# Check if invoice already exists
existing = await self._get_existing_invoice(tenant_id, period_start, period_end)
if existing:
raise BillingError(f"Invoice already exists for period {period_start} to {period_end}")
# Get usage summary
summary = await self.get_usage_summary(tenant_id, period_start, period_end)
# Generate invoice number
invoice_number = await self._generate_invoice_number(tenant_id)
# Calculate line items
line_items = []
subtotal = Decimal("0")
for resource_type, usage in summary.resources.items():
line_item = {
"description": f"{resource_type.replace('_', ' ').title()} Usage",
"quantity": usage["quantity"],
"unit_price": usage["avg_unit_price"],
"amount": usage["cost"]
}
line_items.append(line_item)
subtotal += Decimal(str(usage["cost"]))
# Calculate tax (example: 10% for digital services)
tax_rate = Decimal("0.10")
tax_amount = subtotal * tax_rate
total_amount = subtotal + tax_amount
# Create invoice
invoice = Invoice(
tenant_id=tenant_id,
invoice_number=invoice_number,
status="draft",
period_start=period_start,
period_end=period_end,
due_date=period_end + timedelta(days=due_days),
subtotal=subtotal,
tax_amount=tax_amount,
total_amount=total_amount,
currency="USD",
line_items=line_items
)
self.db.add(invoice)
self.db.commit()
self.logger.info(
f"Generated invoice {invoice_number} for tenant {tenant_id}: "
f"${total_amount}"
)
return invoice
async def get_billing_metrics(
self,
tenant_id: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None
) -> Dict[str, Any]:
"""Get billing metrics and analytics"""
# Default to last 30 days
if not end_date:
end_date = datetime.utcnow()
if not start_date:
start_date = end_date - timedelta(days=30)
# Build base query
base_conditions = [
UsageRecord.usage_start >= start_date,
UsageRecord.usage_end <= end_date
]
if tenant_id:
base_conditions.append(UsageRecord.tenant_id == tenant_id)
# Total usage and cost
stmt = select(
func.sum(UsageRecord.quantity).label("total_quantity"),
func.sum(UsageRecord.total_cost).label("total_cost"),
func.count(UsageRecord.id).label("total_records"),
func.count(func.distinct(UsageRecord.tenant_id)).label("active_tenants")
).where(and_(*base_conditions))
totals = self.db.execute(stmt).first()
# Usage by resource type
stmt = select(
UsageRecord.resource_type,
func.sum(UsageRecord.quantity).label("quantity"),
func.sum(UsageRecord.total_cost).label("cost")
).where(and_(*base_conditions)).group_by(UsageRecord.resource_type)
by_resource = self.db.execute(stmt).all()
# Top tenants by usage
if not tenant_id:
stmt = select(
UsageRecord.tenant_id,
func.sum(UsageRecord.total_cost).label("total_cost")
).where(and_(*base_conditions)).group_by(
UsageRecord.tenant_id
).order_by(desc("total_cost")).limit(10)
top_tenants = self.db.execute(stmt).all()
else:
top_tenants = []
# Daily usage trend
stmt = select(
func.date(UsageRecord.usage_start).label("date"),
func.sum(UsageRecord.total_cost).label("daily_cost")
).where(and_(*base_conditions)).group_by(
func.date(UsageRecord.usage_start)
).order_by("date")
daily_trend = self.db.execute(stmt).all()
# Assemble metrics
metrics = {
"period": {
"start": start_date.isoformat(),
"end": end_date.isoformat()
},
"totals": {
"quantity": float(totals.total_quantity or 0),
"cost": float(totals.total_cost or 0),
"records": totals.total_records or 0,
"active_tenants": totals.active_tenants or 0
},
"by_resource": {
r.resource_type: {
"quantity": float(r.quantity),
"cost": float(r.cost)
}
for r in by_resource
},
"top_tenants": [
{
"tenant_id": str(t.tenant_id),
"cost": float(t.total_cost)
}
for t in top_tenants
],
"daily_trend": [
{
"date": d.date.isoformat(),
"cost": float(d.daily_cost)
}
for d in daily_trend
]
}
return metrics
async def process_billing_events(self, events: List[BillingEvent]) -> bool:
"""Process batch of billing events"""
try:
for event in events:
if event.event_type == "usage":
# Already recorded in record_usage
continue
elif event.event_type == "credit":
await self._apply_credit(event)
elif event.event_type == "charge":
await self._apply_charge(event)
elif event.event_type == "quota_adjustment":
await self._adjust_quota(event)
return True
except Exception as e:
self.logger.error(f"Failed to process billing events: {e}")
return False
async def export_usage_data(
self,
tenant_id: str,
start_date: datetime,
end_date: datetime,
format: str = "csv"
) -> str:
"""Export usage data in specified format"""
# Get usage records
stmt = select(UsageRecord).where(
and_(
UsageRecord.tenant_id == tenant_id,
UsageRecord.usage_start >= start_date,
UsageRecord.usage_end <= end_date
)
).order_by(UsageRecord.usage_start)
records = self.db.execute(stmt).scalars().all()
if format == "csv":
return await self._export_csv(records)
elif format == "json":
return await self._export_json(records)
else:
raise BillingError(f"Unsupported export format: {format}")
# Private methods
async def _calculate_unit_price(
self,
resource_type: str,
quantity: Decimal
) -> Decimal:
"""Calculate unit price with tiered pricing"""
config = self.pricing_config.get(resource_type)
if not config:
return Decimal("0")
base_price = config["unit_price"]
if not config.get("tiered", False):
return base_price
# Find applicable tier
tiers = self.tier_thresholds.get(resource_type, [])
quantity_float = float(quantity)
for tier in tiers:
if (tier["min"] is None or quantity_float >= tier["min"]) and \
(tier["max"] is None or quantity_float <= tier["max"]):
return base_price * Decimal(str(tier["multiplier"]))
# Fallback discount if the quantity falls outside every configured tier
return base_price * Decimal("0.5")
def _get_unit_for_resource(self, resource_type: str) -> str:
"""Get unit for resource type"""
unit_map = {
"gpu_hours": "hours",
"storage_gb": "gb",
"api_calls": "calls",
"bandwidth_gb": "gb",
"compute_hours": "hours"
}
return unit_map.get(resource_type, "units")
async def _emit_billing_event(self, event: BillingEvent):
"""Emit billing event for processing"""
# In a real implementation, this would publish to a message queue
# For now, we'll just log it
self.logger.debug(f"Emitting billing event: {event}")
async def _get_existing_invoice(
self,
tenant_id: str,
period_start: datetime,
period_end: datetime
) -> Optional[Invoice]:
"""Check if invoice already exists for period"""
stmt = select(Invoice).where(
and_(
Invoice.tenant_id == tenant_id,
Invoice.period_start == period_start,
Invoice.period_end == period_end
)
)
return self.db.execute(stmt).scalar_one_or_none()
async def _generate_invoice_number(self, tenant_id: str) -> str:
"""Generate unique invoice number"""
# Get tenant info
stmt = select(Tenant).where(Tenant.id == tenant_id)
tenant = self.db.execute(stmt).scalar_one_or_none()
if not tenant:
raise TenantError(f"Tenant not found: {tenant_id}")
# Generate number: INV-{tenant.slug}-{YYYYMMDD}-{seq}
date_str = datetime.utcnow().strftime("%Y%m%d")
# Get sequence for today
seq_key = f"invoice_seq:{tenant_id}:{date_str}"
# In a real implementation, use Redis or sequence table
# For now, use a simple counter
stmt = select(func.count(Invoice.id)).where(
and_(
Invoice.tenant_id == tenant_id,
func.date(Invoice.created_at) == func.current_date()
)
)
seq = self.db.execute(stmt).scalar() + 1
return f"INV-{tenant.slug}-{date_str}-{seq:04d}"
async def _apply_credit(self, event: BillingEvent):
"""Apply credit to tenant account"""
# TODO: Implement credit application
pass
async def _apply_charge(self, event: BillingEvent):
"""Apply charge to tenant account"""
# TODO: Implement charge application
pass
async def _adjust_quota(self, event: BillingEvent):
"""Adjust quota based on billing event"""
# TODO: Implement quota adjustment
pass
async def _export_csv(self, records: List[UsageRecord]) -> str:
"""Export records to CSV"""
import csv
import io
output = io.StringIO()
writer = csv.writer(output)
# Header
writer.writerow([
"Timestamp", "Resource Type", "Quantity", "Unit",
"Unit Price", "Total Cost", "Currency", "Job ID"
])
# Data rows
for record in records:
writer.writerow([
record.usage_start.isoformat(),
record.resource_type,
record.quantity,
record.unit,
record.unit_price,
record.total_cost,
record.currency,
record.job_id or ""
])
return output.getvalue()
async def _export_json(self, records: List[UsageRecord]) -> str:
"""Export records to JSON"""
import json
data = []
for record in records:
data.append({
"timestamp": record.usage_start.isoformat(),
"resource_type": record.resource_type,
"quantity": float(record.quantity),
"unit": record.unit,
"unit_price": float(record.unit_price),
"total_cost": float(record.total_cost),
"currency": record.currency,
"job_id": record.job_id,
"metadata": record.metadata
})
return json.dumps(data, indent=2)
class BillingScheduler:
"""Scheduler for automated billing processes"""
def __init__(self, usage_service: UsageTrackingService):
self.usage_service = usage_service
self.logger = __import__('logging').getLogger(f"aitbc.{self.__class__.__name__}")
self.running = False
async def start(self):
"""Start billing scheduler"""
if self.running:
return
self.running = True
self.logger.info("Billing scheduler started")
# Schedule daily tasks
asyncio.create_task(self._daily_tasks())
# Schedule monthly invoicing
asyncio.create_task(self._monthly_invoicing())
async def stop(self):
"""Stop billing scheduler"""
self.running = False
self.logger.info("Billing scheduler stopped")
async def _daily_tasks(self):
"""Run daily billing tasks"""
while self.running:
try:
# Reset quotas for new periods
await self._reset_daily_quotas()
# Process pending billing events
await self._process_pending_events()
# Wait until next day
now = datetime.utcnow()
next_day = (now + timedelta(days=1)).replace(
hour=0, minute=0, second=0, microsecond=0
)
sleep_seconds = (next_day - now).total_seconds()
await asyncio.sleep(sleep_seconds)
except Exception as e:
self.logger.error(f"Error in daily tasks: {e}")
await asyncio.sleep(3600) # Retry in 1 hour
async def _monthly_invoicing(self):
"""Generate monthly invoices"""
while self.running:
try:
# Wait until first day of month
now = datetime.utcnow()
if now.day != 1:
next_month = now.replace(day=1) + timedelta(days=32)
next_month = next_month.replace(day=1)
sleep_seconds = (next_month - now).total_seconds()
await asyncio.sleep(sleep_seconds)
continue
# Generate invoices for all active tenants
await self._generate_monthly_invoices()
# Wait until next month
next_month = now.replace(day=1) + timedelta(days=32)
next_month = next_month.replace(day=1)
sleep_seconds = (next_month - now).total_seconds()
await asyncio.sleep(sleep_seconds)
except Exception as e:
self.logger.error(f"Error in monthly invoicing: {e}")
await asyncio.sleep(86400) # Retry in 1 day
async def _reset_daily_quotas(self):
"""Reset daily quotas"""
# TODO: Implement daily quota reset
pass
async def _process_pending_events(self):
"""Process pending billing events"""
# TODO: Implement event processing
pass
async def _generate_monthly_invoices(self):
"""Generate invoices for all tenants"""
# TODO: Implement monthly invoice generation
pass
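To make the tiered pricing concrete, a hedged example of what record_usage is expected to produce for a 600-hour GPU job: 600 falls in the 501-2000 gpu_hours tier, so the $0.50 base rate takes the 0.8 multiplier and the record is priced at $0.40 per hour. The database session and identifiers below are assumptions.

# Illustrative usage recording (assumed: `db` is a SQLAlchemy session).
from decimal import Decimal

async def record_gpu_usage_example(db):
    svc = UsageTrackingService(db)
    record = await svc.record_usage(
        tenant_id="tenant-123",
        resource_type="gpu_hours",
        quantity=Decimal("600"),
        job_id="job-456",
    )
    # unit_price = 0.50 * 0.8 = 0.40, total_cost = 0.40 * 600 = 240
    assert record.total_cost == Decimal("240")
    return record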

View File

@@ -0,0 +1,269 @@
"""
ZK Proof generation service for privacy-preserving receipt attestation
"""
import asyncio
import json
import subprocess
from pathlib import Path
from typing import Dict, Any, Optional, List
import tempfile
import os
from ..models import Receipt, JobResult
from ..settings import settings
from ..logging import get_logger
logger = get_logger(__name__)
class ZKProofService:
"""Service for generating zero-knowledge proofs for receipts"""
def __init__(self):
self.circuits_dir = Path(__file__).parent.parent.parent.parent / "apps" / "zk-circuits"
self.zkey_path = self.circuits_dir / "receipt_0001.zkey"
self.wasm_path = self.circuits_dir / "receipt.wasm"
self.vkey_path = self.circuits_dir / "verification_key.json"
# Verify circuit files exist
if not all(p.exists() for p in [self.zkey_path, self.wasm_path, self.vkey_path]):
logger.warning("ZK circuit files not found. Proof generation disabled.")
self.enabled = False
else:
self.enabled = True
async def generate_receipt_proof(
self,
receipt: Receipt,
job_result: JobResult,
privacy_level: str = "basic"
) -> Optional[Dict[str, Any]]:
"""Generate a ZK proof for a receipt"""
if not self.enabled:
logger.warning("ZK proof generation not available")
return None
try:
# Prepare circuit inputs based on privacy level
inputs = await self._prepare_inputs(receipt, job_result, privacy_level)
# Generate proof using snarkjs
proof_data = await self._generate_proof(inputs)
# Return proof with verification data
return {
"proof": proof_data["proof"],
"public_signals": proof_data["publicSignals"],
"privacy_level": privacy_level,
"circuit_hash": await self._get_circuit_hash()
}
except Exception as e:
logger.error(f"Failed to generate ZK proof: {e}")
return None
async def _prepare_inputs(
self,
receipt: Receipt,
job_result: JobResult,
privacy_level: str
) -> Dict[str, Any]:
"""Prepare circuit inputs based on privacy level"""
if privacy_level == "basic":
# Hide computation details, reveal settlement amount
return {
"data": [
str(receipt.job_id),
str(receipt.miner_id),
str(job_result.result_hash),
str(receipt.pricing.rate)
],
"hash": await self._hash_receipt(receipt)
}
elif privacy_level == "enhanced":
# Hide all amounts, prove correctness
return {
"settlementAmount": receipt.settlement_amount,
"timestamp": receipt.timestamp,
"receipt": self._serialize_receipt(receipt),
"computationResult": job_result.result_hash,
"pricingRate": receipt.pricing.rate,
"minerReward": receipt.miner_reward,
"coordinatorFee": receipt.coordinator_fee
}
else:
raise ValueError(f"Unknown privacy level: {privacy_level}")
async def _hash_receipt(self, receipt: Receipt) -> str:
"""Hash receipt for public verification"""
# In a real implementation, use Poseidon or the same hash as circuit
import hashlib
receipt_data = {
"job_id": receipt.job_id,
"miner_id": receipt.miner_id,
"timestamp": receipt.timestamp,
"pricing": receipt.pricing.dict()
}
receipt_str = json.dumps(receipt_data, sort_keys=True)
return hashlib.sha256(receipt_str.encode()).hexdigest()
def _serialize_receipt(self, receipt: Receipt) -> List[str]:
"""Serialize receipt for circuit input"""
# Convert receipt to field elements for circuit
return [
str(receipt.job_id)[:32], # Truncate for field size
str(receipt.miner_id)[:32],
str(receipt.timestamp)[:32],
str(receipt.settlement_amount)[:32],
str(receipt.miner_reward)[:32],
str(receipt.coordinator_fee)[:32],
"0", "0" # Padding
]
async def _generate_proof(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Generate proof using snarkjs"""
# Write inputs to temporary file
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
json.dump(inputs, f)
inputs_file = f.name
try:
# Create Node.js script for proof generation
script = f"""
const snarkjs = require('snarkjs');
const fs = require('fs');
async function main() {{
try {{
// Load inputs
const inputs = JSON.parse(fs.readFileSync('{inputs_file}', 'utf8'));
// Compute the witness and generate the Groth16 proof in one call
const {{ proof, publicSignals }} = await snarkjs.groth16.fullProve(inputs, '{self.wasm_path}', '{self.zkey_path}');
// Output result
console.log(JSON.stringify({{ proof, publicSignals }}));
}} catch (error) {{
console.error('Error:', error);
process.exit(1);
}}
}}
main();
"""
# Write script to temporary file
with tempfile.NamedTemporaryFile(mode='w', suffix='.js', delete=False) as f:
f.write(script)
script_file = f.name
try:
# Run script
result = subprocess.run(
["node", script_file],
capture_output=True,
text=True,
cwd=str(self.circuits_dir)
)
if result.returncode != 0:
raise Exception(f"Proof generation failed: {result.stderr}")
# Parse result
return json.loads(result.stdout)
finally:
os.unlink(script_file)
finally:
os.unlink(inputs_file)
async def _get_circuit_hash(self) -> str:
"""Get hash of circuit for verification"""
# In a real implementation, return the hash of the circuit
# This ensures the proof is for the correct circuit version
return "0x1234567890abcdef"
async def verify_proof(
self,
proof: Dict[str, Any],
public_signals: List[str]
) -> bool:
"""Verify a ZK proof"""
if not self.enabled:
return False
try:
# Load verification key
with open(self.vkey_path) as f:
vkey = json.load(f)
# Create verification script
script = f"""
const snarkjs = require('snarkjs');
async function main() {{
try {{
const vKey = {json.dumps(vkey)};
const proof = {json.dumps(proof)};
const publicSignals = {json.dumps(public_signals)};
const verified = await snarkjs.groth16.verify(vKey, publicSignals, proof);
console.log(verified);
}} catch (error) {{
console.error('Error:', error);
process.exit(1);
}}
}}
main();
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.js', delete=False) as f:
f.write(script)
script_file = f.name
try:
result = subprocess.run(
["node", script_file],
capture_output=True,
text=True,
cwd=str(self.circuits_dir)
)
if result.returncode != 0:
logger.error(f"Proof verification failed: {result.stderr}")
return False
return result.stdout.strip() == "true"
finally:
os.unlink(script_file)
except Exception as e:
logger.error(f"Failed to verify proof: {e}")
return False
def is_enabled(self) -> bool:
"""Check if ZK proof generation is available"""
return self.enabled
# Global instance
zk_proof_service = ZKProofService()
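A hedged sketch of the intended round trip through the module-level service: generate a proof for a completed receipt and verify it later. The receipt and job_result objects are assumed to exist, and the calls only succeed when the circuit artifacts (receipt.wasm, receipt_0001.zkey, verification_key.json) are present.

# Illustrative round trip (assumed: populated `receipt` and `job_result` objects).
async def attach_and_check_proof_example(receipt, job_result):
    if not zk_proof_service.is_enabled():
        return None  # circuits not built; receipts are issued without a proof
    proof = await zk_proof_service.generate_receipt_proof(
        receipt=receipt,
        job_result=job_result,
        privacy_level="enhanced",
    )
    if proof:
        # Anyone holding the verification key can re-check the proof independently
        assert await zk_proof_service.verify_proof(proof["proof"], proof["public_signals"])
    return proof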

View File

@@ -0,0 +1,505 @@
"""
Tests for confidential transaction functionality
"""
import pytest
import asyncio
import json
import base64
from datetime import datetime, timedelta
from unittest.mock import Mock, patch, AsyncMock
from app.models import (
ConfidentialTransaction,
ConfidentialTransactionCreate,
ConfidentialAccessRequest,
KeyRegistrationRequest
)
from app.services.encryption import EncryptionService, EncryptedData
from app.services.key_management import KeyManager, FileKeyStorage
from app.services.access_control import AccessController, PolicyStore
from app.services.audit_logging import AuditLogger
class TestEncryptionService:
"""Test encryption service functionality"""
@pytest.fixture
def key_manager(self):
"""Create test key manager"""
storage = FileKeyStorage("/tmp/test_keys")
return KeyManager(storage)
@pytest.fixture
def encryption_service(self, key_manager):
"""Create test encryption service"""
return EncryptionService(key_manager)
@pytest.mark.asyncio
async def test_encrypt_decrypt_success(self, encryption_service, key_manager):
"""Test successful encryption and decryption"""
# Generate test keys
await key_manager.generate_key_pair("client-123")
await key_manager.generate_key_pair("miner-456")
# Test data
data = {
"amount": "1000",
"pricing": {"rate": "0.1", "currency": "AITBC"},
"settlement_details": {"method": "crypto", "address": "0x123..."}
}
participants = ["client-123", "miner-456"]
# Encrypt data
encrypted = encryption_service.encrypt(
data=data,
participants=participants,
include_audit=True
)
assert encrypted.ciphertext is not None
assert len(encrypted.encrypted_keys) == 3 # 2 participants + audit
assert "client-123" in encrypted.encrypted_keys
assert "miner-456" in encrypted.encrypted_keys
assert "audit" in encrypted.encrypted_keys
# Decrypt for client
decrypted = encryption_service.decrypt(
encrypted_data=encrypted,
participant_id="client-123",
purpose="settlement"
)
assert decrypted == data
# Decrypt for miner
decrypted_miner = encryption_service.decrypt(
encrypted_data=encrypted,
participant_id="miner-456",
purpose="settlement"
)
assert decrypted_miner == data
@pytest.mark.asyncio
async def test_audit_decrypt(self, encryption_service, key_manager):
"""Test audit decryption"""
# Generate keys
await key_manager.generate_key_pair("client-123")
# Create audit authorization
auth = await key_manager.create_audit_authorization(
issuer="regulator",
purpose="compliance"
)
# Encrypt data
data = {"amount": "1000", "secret": "hidden"}
encrypted = encryption_service.encrypt(
data=data,
participants=["client-123"],
include_audit=True
)
# Decrypt with audit key
decrypted = encryption_service.audit_decrypt(
encrypted_data=encrypted,
audit_authorization=auth,
purpose="compliance"
)
assert decrypted == data
def test_encrypt_no_participants(self, encryption_service):
"""Test encryption with no participants"""
data = {"test": "data"}
with pytest.raises(Exception):
encryption_service.encrypt(
data=data,
participants=[],
include_audit=True
)
class TestKeyManager:
"""Test key management functionality"""
@pytest.fixture
def key_storage(self, tmp_path):
"""Create test key storage"""
return FileKeyStorage(str(tmp_path / "keys"))
@pytest.fixture
def key_manager(self, key_storage):
"""Create test key manager"""
return KeyManager(key_storage)
@pytest.mark.asyncio
async def test_generate_key_pair(self, key_manager):
"""Test key pair generation"""
key_pair = await key_manager.generate_key_pair("test-participant")
assert key_pair.participant_id == "test-participant"
assert key_pair.algorithm == "X25519"
assert key_pair.private_key is not None
assert key_pair.public_key is not None
assert key_pair.version == 1
@pytest.mark.asyncio
async def test_key_rotation(self, key_manager):
"""Test key rotation"""
# Generate initial key
initial_key = await key_manager.generate_key_pair("test-participant")
initial_version = initial_key.version
# Rotate keys
new_key = await key_manager.rotate_keys("test-participant")
assert new_key.participant_id == "test-participant"
assert new_key.version > initial_version
assert new_key.private_key != initial_key.private_key
assert new_key.public_key != initial_key.public_key
def test_get_public_key(self, key_manager):
"""Test retrieving public key"""
# This would need a key to be pre-generated
with pytest.raises(Exception):
key_manager.get_public_key("nonexistent")
class TestAccessController:
"""Test access control functionality"""
@pytest.fixture
def policy_store(self):
"""Create test policy store"""
return PolicyStore()
@pytest.fixture
def access_controller(self, policy_store):
"""Create test access controller"""
return AccessController(policy_store)
def test_client_access_own_data(self, access_controller):
"""Test client accessing own transaction"""
request = ConfidentialAccessRequest(
transaction_id="tx-123",
requester="client-456",
purpose="settlement"
)
# Should allow access
assert access_controller.verify_access(request) is True
def test_miner_access_assigned_data(self, access_controller):
"""Test miner accessing assigned transaction"""
request = ConfidentialAccessRequest(
transaction_id="tx-123",
requester="miner-789",
purpose="settlement"
)
# Should allow access
assert access_controller.verify_access(request) is True
def test_unauthorized_access(self, access_controller):
"""Test unauthorized access attempt"""
request = ConfidentialAccessRequest(
transaction_id="tx-123",
requester="unauthorized-user",
purpose="settlement"
)
# Should deny access
assert access_controller.verify_access(request) is False
def test_audit_access(self, access_controller):
"""Test auditor access"""
request = ConfidentialAccessRequest(
transaction_id="tx-123",
requester="auditor-001",
purpose="compliance"
)
# Should allow access during business hours
assert access_controller.verify_access(request) is True
class TestAuditLogger:
"""Test audit logging functionality"""
@pytest.fixture
def audit_logger(self, tmp_path):
"""Create test audit logger"""
return AuditLogger(log_dir=str(tmp_path / "audit"))
def test_log_access(self, audit_logger):
"""Test logging access events"""
# Log access event
audit_logger.log_access(
participant_id="client-456",
transaction_id="tx-123",
action="decrypt",
outcome="success",
ip_address="192.168.1.1",
user_agent="test-client"
)
# Wait for background writer
import time
time.sleep(0.1)
# Query logs
events = audit_logger.query_logs(
participant_id="client-456",
limit=10
)
assert len(events) > 0
assert events[0].participant_id == "client-456"
assert events[0].transaction_id == "tx-123"
assert events[0].action == "decrypt"
assert events[0].outcome == "success"
def test_log_key_operation(self, audit_logger):
"""Test logging key operations"""
audit_logger.log_key_operation(
participant_id="miner-789",
operation="rotate",
key_version=2,
outcome="success"
)
# Wait for background writer
import time
time.sleep(0.1)
# Query logs
events = audit_logger.query_logs(
event_type="key_operation",
limit=10
)
assert len(events) > 0
assert events[0].event_type == "key_operation"
assert events[0].action == "rotate"
assert events[0].details["key_version"] == 2
def test_export_logs(self, audit_logger):
"""Test log export functionality"""
# Add some test events
audit_logger.log_access(
participant_id="test-user",
transaction_id="tx-456",
action="test",
outcome="success"
)
# Wait for background writer
import time
time.sleep(0.1)
# Export logs
export_data = audit_logger.export_logs(
start_time=datetime.utcnow() - timedelta(hours=1),
end_time=datetime.utcnow(),
format="json"
)
# Parse export
export = json.loads(export_data)
assert "export_metadata" in export
assert "events" in export
assert export["export_metadata"]["event_count"] > 0
class TestConfidentialTransactionAPI:
"""Test confidential transaction API endpoints"""
@pytest.mark.asyncio
async def test_create_confidential_transaction(self):
"""Test creating a confidential transaction"""
from app.routers.confidential import create_confidential_transaction
request = ConfidentialTransactionCreate(
job_id="job-123",
amount="1000",
pricing={"rate": "0.1"},
confidential=True,
participants=["client-456", "miner-789"]
)
# Mock API key
with patch('app.routers.confidential.get_api_key', return_value="test-key"):
response = await create_confidential_transaction(request)
assert response.transaction_id.startswith("ctx-")
assert response.job_id == "job-123"
assert response.confidential is True
assert response.has_encrypted_data is True
assert response.amount is None # Should be encrypted
@pytest.mark.asyncio
async def test_access_confidential_data(self):
"""Test accessing confidential transaction data"""
from app.routers.confidential import access_confidential_data
request = ConfidentialAccessRequest(
transaction_id="tx-123",
requester="client-456",
purpose="settlement"
)
# Mock dependencies
with patch('app.routers.confidential.get_api_key', return_value="test-key"), \
patch('app.routers.confidential.get_access_controller') as mock_ac, \
patch('app.routers.confidential.get_encryption_service') as mock_es:
# Mock access control
mock_ac.return_value.verify_access.return_value = True
# Mock encryption service
mock_es.return_value.decrypt.return_value = {
"amount": "1000",
"pricing": {"rate": "0.1"}
}
response = await access_confidential_data(request, "tx-123")
assert response.success is True
assert response.data is not None
assert response.data["amount"] == "1000"
@pytest.mark.asyncio
async def test_register_key(self):
"""Test key registration"""
from app.routers.confidential import register_encryption_key
# Generate test key pair
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
private_key = X25519PrivateKey.generate()
public_key = private_key.public_key()
public_key_bytes = public_key.public_bytes_raw()
request = KeyRegistrationRequest(
participant_id="test-participant",
public_key=base64.b64encode(public_key_bytes).decode()
)
with patch('app.routers.confidential.get_api_key', return_value="test-key"):
response = await register_encryption_key(request)
assert response.success is True
assert response.participant_id == "test-participant"
assert response.key_version >= 1
# Integration Tests
class TestConfidentialTransactionFlow:
"""End-to-end tests for confidential transaction flow"""
@pytest.mark.asyncio
async def test_full_confidential_flow(self):
"""Test complete confidential transaction flow"""
# Setup
key_storage = FileKeyStorage("/tmp/integration_keys")
key_manager = KeyManager(key_storage)
encryption_service = EncryptionService(key_manager)
access_controller = AccessController(PolicyStore())
# 1. Generate keys for participants
await key_manager.generate_key_pair("client-123")
await key_manager.generate_key_pair("miner-456")
# 2. Create confidential transaction
transaction_data = {
"amount": "1000",
"pricing": {"rate": "0.1", "currency": "AITBC"},
"settlement_details": {"method": "crypto"}
}
participants = ["client-123", "miner-456"]
# 3. Encrypt data
encrypted = encryption_service.encrypt(
data=transaction_data,
participants=participants,
include_audit=True
)
# 4. Store transaction (mock)
transaction = ConfidentialTransaction(
transaction_id="ctx-test-123",
job_id="job-456",
timestamp=datetime.utcnow(),
status="created",
confidential=True,
participants=participants,
encrypted_data=encrypted.to_dict()["ciphertext"],
encrypted_keys=encrypted.to_dict()["encrypted_keys"],
algorithm=encrypted.algorithm
)
# 5. Client accesses data
client_request = ConfidentialAccessRequest(
transaction_id=transaction.transaction_id,
requester="client-123",
purpose="settlement"
)
assert access_controller.verify_access(client_request) is True
client_data = encryption_service.decrypt(
encrypted_data=encrypted,
participant_id="client-123",
purpose="settlement"
)
assert client_data == transaction_data
# 6. Miner accesses data
miner_request = ConfidentialAccessRequest(
transaction_id=transaction.transaction_id,
requester="miner-456",
purpose="settlement"
)
assert access_controller.verify_access(miner_request) is True
miner_data = encryption_service.decrypt(
encrypted_data=encrypted,
participant_id="miner-456",
purpose="settlement"
)
assert miner_data == transaction_data
# 7. Unauthorized access denied
unauthorized_request = ConfidentialAccessRequest(
transaction_id=transaction.transaction_id,
requester="unauthorized",
purpose="settlement"
)
assert access_controller.verify_access(unauthorized_request) is False
# 8. Audit access
audit_auth = await key_manager.create_audit_authorization(
issuer="regulator",
purpose="compliance"
)
audit_data = encryption_service.audit_decrypt(
encrypted_data=encrypted,
audit_authorization=audit_auth,
purpose="compliance"
)
assert audit_data == transaction_data
# Cleanup
import shutil
shutil.rmtree("/tmp/integration_keys", ignore_errors=True)

View File

@@ -0,0 +1,402 @@
"""
Tests for ZK proof generation and verification
"""
import pytest
import json
from unittest.mock import Mock, patch, AsyncMock
from pathlib import Path
from app.services.zk_proofs import ZKProofService
from app.models import JobReceipt, Job, JobResult
from app.domain import ReceiptPayload
class TestZKProofService:
"""Test cases for ZK proof service"""
@pytest.fixture
def zk_service(self):
"""Create ZK proof service instance"""
with patch('app.services.zk_proofs.settings'):
service = ZKProofService()
return service
@pytest.fixture
def sample_job(self):
"""Create sample job for testing"""
return Job(
id="test-job-123",
client_id="client-456",
payload={"type": "test"},
constraints={},
requested_at=None,
completed=True
)
@pytest.fixture
def sample_job_result(self):
"""Create sample job result"""
return {
"result": "test-result",
"result_hash": "0x1234567890abcdef",
"units": 100,
"unit_type": "gpu_seconds",
"metrics": {"execution_time": 5.0}
}
@pytest.fixture
def sample_receipt(self, sample_job):
"""Create sample receipt"""
payload = ReceiptPayload(
version="1.0",
receipt_id="receipt-789",
job_id=sample_job.id,
provider="miner-001",
client=sample_job.client_id,
units=100,
unit_type="gpu_seconds",
price="0.1",
started_at=1640995200,
completed_at=1640995800,
metadata={}
)
return JobReceipt(
job_id=sample_job.id,
receipt_id=payload.receipt_id,
payload=payload.dict()
)
def test_service_initialization_with_files(self):
"""Test service initialization when circuit files exist"""
# Patch Path.exists directly so every circuit artifact reports as present
with patch('app.services.zk_proofs.Path.exists', return_value=True):
service = ZKProofService()
assert service.enabled is True
def test_service_initialization_without_files(self):
"""Test service initialization when circuit files are missing"""
# Patching the Path class itself would leave exists() returning a truthy MagicMock,
# so patch the exists method instead
with patch('app.services.zk_proofs.Path.exists', return_value=False):
service = ZKProofService()
assert service.enabled is False
@pytest.mark.asyncio
async def test_generate_proof_basic_privacy(self, zk_service, sample_receipt, sample_job_result):
"""Test generating proof with basic privacy level"""
if not zk_service.enabled:
pytest.skip("ZK circuits not available")
# Mock subprocess calls
with patch('subprocess.run') as mock_run:
# Mock successful proof generation
mock_run.return_value.returncode = 0
mock_run.return_value.stdout = json.dumps({
"proof": {"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
"publicSignals": ["0x1234", "1000", "1640995800"]
})
# Generate proof
proof = await zk_service.generate_receipt_proof(
receipt=sample_receipt,
job_result=sample_job_result,
privacy_level="basic"
)
assert proof is not None
assert "proof" in proof
assert "public_signals" in proof
assert proof["privacy_level"] == "basic"
assert "circuit_hash" in proof
@pytest.mark.asyncio
async def test_generate_proof_enhanced_privacy(self, zk_service, sample_receipt, sample_job_result):
"""Test generating proof with enhanced privacy level"""
if not zk_service.enabled:
pytest.skip("ZK circuits not available")
with patch('subprocess.run') as mock_run:
mock_run.return_value.returncode = 0
mock_run.return_value.stdout = json.dumps({
"proof": {"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
"publicSignals": ["1000", "1640995800"]
})
proof = await zk_service.generate_receipt_proof(
receipt=sample_receipt,
job_result=sample_job_result,
privacy_level="enhanced"
)
assert proof is not None
assert proof["privacy_level"] == "enhanced"
@pytest.mark.asyncio
async def test_generate_proof_service_disabled(self, zk_service, sample_receipt, sample_job_result):
"""Test proof generation when service is disabled"""
zk_service.enabled = False
proof = await zk_service.generate_receipt_proof(
receipt=sample_receipt,
job_result=sample_job_result,
privacy_level="basic"
)
assert proof is None
@pytest.mark.asyncio
async def test_generate_proof_invalid_privacy_level(self, zk_service, sample_receipt, sample_job_result):
"""Test proof generation with invalid privacy level"""
if not zk_service.enabled:
pytest.skip("ZK circuits not available")
with pytest.raises(ValueError, match="Unknown privacy level"):
await zk_service.generate_receipt_proof(
receipt=sample_receipt,
job_result=sample_job_result,
privacy_level="invalid"
)
@pytest.mark.asyncio
async def test_verify_proof_success(self, zk_service):
"""Test successful proof verification"""
if not zk_service.enabled:
pytest.skip("ZK circuits not available")
with patch('subprocess.run') as mock_run, \
patch('builtins.open', mock_open(read_data='{"key": "value"}')):
mock_run.return_value.returncode = 0
mock_run.return_value.stdout = "true"
result = await zk_service.verify_proof(
proof={"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
public_signals=["0x1234", "1000"]
)
assert result is True
@pytest.mark.asyncio
async def test_verify_proof_failure(self, zk_service):
"""Test proof verification failure"""
if not zk_service.enabled:
pytest.skip("ZK circuits not available")
with patch('subprocess.run') as mock_run, \
patch('builtins.open', mock_open(read_data='{"key": "value"}')):
mock_run.return_value.returncode = 1
mock_run.return_value.stderr = "Verification failed"
result = await zk_service.verify_proof(
proof={"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
public_signals=["0x1234", "1000"]
)
assert result is False
@pytest.mark.asyncio
async def test_verify_proof_service_disabled(self, zk_service):
"""Test proof verification when service is disabled"""
zk_service.enabled = False
result = await zk_service.verify_proof(
proof={"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
public_signals=["0x1234", "1000"]
)
assert result is False
@pytest.mark.asyncio
async def test_hash_receipt(self, zk_service, sample_receipt):
"""Test receipt hashing"""
# _hash_receipt is a coroutine, so it has to be awaited
receipt_hash = await zk_service._hash_receipt(sample_receipt)
assert isinstance(receipt_hash, str)
assert len(receipt_hash) == 64 # SHA256 hex length
assert all(c in '0123456789abcdef' for c in receipt_hash)
def test_serialize_receipt(self, zk_service, sample_receipt):
"""Test receipt serialization for circuit"""
serialized = zk_service._serialize_receipt(sample_receipt)
assert isinstance(serialized, list)
assert len(serialized) == 8
assert all(isinstance(x, str) for x in serialized)
class TestZKProofIntegration:
"""Integration tests for ZK proof system"""
@pytest.mark.asyncio
async def test_receipt_creation_with_zk_proof(self):
"""Test receipt creation with ZK proof generation"""
from app.services.receipts import ReceiptService
from sqlmodel import Session
# Create mock session
session = Mock(spec=Session)
# Create receipt service
receipt_service = ReceiptService(session)
# Create sample job
job = Job(
id="test-job-123",
client_id="client-456",
payload={"type": "test"},
constraints={},
requested_at=None,
completed=True
)
# Mock ZK proof service
with patch('app.services.receipts.zk_proof_service') as mock_zk:
mock_zk.is_enabled.return_value = True
mock_zk.generate_receipt_proof = AsyncMock(return_value={
"proof": {"a": ["1", "2"]},
"public_signals": ["0x1234"],
"privacy_level": "basic"
})
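            # The AsyncMock return value mimics the proof bundle the real
            # service would attach to the receipt payload.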
# Create receipt with privacy
receipt = await receipt_service.create_receipt(
job=job,
miner_id="miner-001",
job_result={"result": "test"},
result_metrics={"units": 100},
privacy_level="basic"
)
assert receipt is not None
assert "zk_proof" in receipt
assert receipt["privacy_level"] == "basic"
@pytest.mark.asyncio
async def test_settlement_with_zk_proof(self):
"""Test cross-chain settlement with ZK proof"""
from aitbc.settlement.hooks import SettlementHook
from aitbc.settlement.manager import BridgeManager
# Create mock bridge manager
bridge_manager = Mock(spec=BridgeManager)
# Create settlement hook
settlement_hook = SettlementHook(bridge_manager)
# Create sample job with ZK proof
job = Job(
id="test-job-123",
client_id="client-456",
payload={"type": "test"},
constraints={},
requested_at=None,
completed=True,
target_chain=2
)
# Create receipt with ZK proof
receipt_payload = {
"version": "1.0",
"receipt_id": "receipt-789",
"job_id": job.id,
"provider": "miner-001",
"client": job.client_id,
"zk_proof": {
"proof": {"a": ["1", "2"]},
"public_signals": ["0x1234"]
}
}
job.receipt = JobReceipt(
job_id=job.id,
receipt_id=receipt_payload["receipt_id"],
payload=receipt_payload
)
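        # The hook is expected to lift the embedded zk_proof out of the
        # receipt payload when the use_zk_proof option is set.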
# Test settlement message creation
message = await settlement_hook._create_settlement_message(
job,
options={"use_zk_proof": True, "privacy_level": "basic"}
)
assert message.zk_proof is not None
assert message.privacy_level == "basic"
# Helper function for mocking file operations
def mock_open(read_data=""):
    """Thin wrapper around unittest.mock.mock_open, exposed at module level
    so the tests above can call mock_open(read_data=...) directly."""
    # Import locally and alias it so the stdlib helper is not shadowed by
    # this function's own name inside its body.
    from unittest.mock import mock_open as _mock_open
    return _mock_open(read_data=read_data)
# Benchmark tests
class TestZKProofPerformance:
"""Performance benchmarks for ZK proof operations"""
@pytest.mark.asyncio
async def test_proof_generation_time(self):
"""Benchmark proof generation time"""
import time
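        # Circuit artifacts are assumed to live under apps/zk-circuits,
        # relative to the directory where pytest is invoked.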
if not Path("apps/zk-circuits/receipt.wasm").exists():
pytest.skip("ZK circuits not built")
service = ZKProofService()
if not service.enabled:
pytest.skip("ZK service not enabled")
# Create test data
receipt = JobReceipt(
job_id="benchmark-job",
receipt_id="benchmark-receipt",
payload={"test": "data"}
)
job_result = {"result": "benchmark"}
# Measure proof generation time
start_time = time.time()
proof = await service.generate_receipt_proof(
receipt=receipt,
job_result=job_result,
privacy_level="basic"
)
end_time = time.time()
generation_time = end_time - start_time
assert proof is not None
assert generation_time < 30 # Should complete within 30 seconds
print(f"Proof generation time: {generation_time:.2f} seconds")
@pytest.mark.asyncio
async def test_proof_verification_time(self):
"""Benchmark proof verification time"""
import time
service = ZKProofService()
if not service.enabled:
pytest.skip("ZK service not enabled")
# Create test proof
proof = {"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]}
public_signals = ["0x1234", "1000"]
# Measure verification time
start_time = time.time()
result = await service.verify_proof(proof, public_signals)
end_time = time.time()
verification_time = end_time - start_time
assert isinstance(result, bool)
assert verification_time < 1 # Should complete within 1 second
print(f"Proof verification time: {verification_time:.3f} seconds")