feat: add transaction hash search to blockchain explorer and cleanup settlement storage
Blockchain Explorer: - Add transaction hash search support (64-char hex pattern validation) - Fetch and display transaction details in modal (hash, type, from/to, amount, fee, block) - Fix regex escape sequence in block height validation - Update search placeholder text to mention both search types - Add blank lines between function definitions for PEP 8 compliance Settlement Storage: - Add timedelta import for future
This commit is contained in:
@@ -3,30 +3,26 @@ Storage layer for cross-chain settlements
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
import asyncio
|
||||
from dataclasses import asdict
|
||||
|
||||
from .bridges.base import (
|
||||
SettlementMessage,
|
||||
SettlementResult,
|
||||
BridgeStatus
|
||||
)
|
||||
from .bridges.base import SettlementMessage, SettlementResult, BridgeStatus
|
||||
|
||||
|
||||
class SettlementStorage:
|
||||
"""Storage interface for settlement data"""
|
||||
|
||||
|
||||
def __init__(self, db_connection):
|
||||
self.db = db_connection
|
||||
|
||||
|
||||
async def store_settlement(
|
||||
self,
|
||||
message_id: str,
|
||||
message: SettlementMessage,
|
||||
bridge_name: str,
|
||||
status: BridgeStatus
|
||||
status: BridgeStatus,
|
||||
) -> None:
|
||||
"""Store a new settlement record"""
|
||||
query = """
|
||||
@@ -38,93 +34,96 @@ class SettlementStorage:
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13
|
||||
)
|
||||
"""
|
||||
|
||||
await self.db.execute(query, (
|
||||
message_id,
|
||||
message.job_id,
|
||||
message.source_chain_id,
|
||||
message.target_chain_id,
|
||||
message.receipt_hash,
|
||||
json.dumps(message.proof_data),
|
||||
message.payment_amount,
|
||||
message.payment_token,
|
||||
message.nonce,
|
||||
message.signature,
|
||||
bridge_name,
|
||||
status.value,
|
||||
message.created_at or datetime.utcnow()
|
||||
))
|
||||
|
||||
|
||||
await self.db.execute(
|
||||
query,
|
||||
(
|
||||
message_id,
|
||||
message.job_id,
|
||||
message.source_chain_id,
|
||||
message.target_chain_id,
|
||||
message.receipt_hash,
|
||||
json.dumps(message.proof_data),
|
||||
message.payment_amount,
|
||||
message.payment_token,
|
||||
message.nonce,
|
||||
message.signature,
|
||||
bridge_name,
|
||||
status.value,
|
||||
message.created_at or datetime.utcnow(),
|
||||
),
|
||||
)
|
||||
|
||||
async def update_settlement(
|
||||
self,
|
||||
message_id: str,
|
||||
status: Optional[BridgeStatus] = None,
|
||||
transaction_hash: Optional[str] = None,
|
||||
error_message: Optional[str] = None,
|
||||
completed_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None,
|
||||
) -> None:
|
||||
"""Update settlement record"""
|
||||
updates = []
|
||||
params = []
|
||||
param_count = 1
|
||||
|
||||
|
||||
if status is not None:
|
||||
updates.append(f"status = ${param_count}")
|
||||
params.append(status.value)
|
||||
param_count += 1
|
||||
|
||||
|
||||
if transaction_hash is not None:
|
||||
updates.append(f"transaction_hash = ${param_count}")
|
||||
params.append(transaction_hash)
|
||||
param_count += 1
|
||||
|
||||
|
||||
if error_message is not None:
|
||||
updates.append(f"error_message = ${param_count}")
|
||||
params.append(error_message)
|
||||
param_count += 1
|
||||
|
||||
|
||||
if completed_at is not None:
|
||||
updates.append(f"completed_at = ${param_count}")
|
||||
params.append(completed_at)
|
||||
param_count += 1
|
||||
|
||||
|
||||
if not updates:
|
||||
return
|
||||
|
||||
|
||||
updates.append(f"updated_at = ${param_count}")
|
||||
params.append(datetime.utcnow())
|
||||
param_count += 1
|
||||
|
||||
|
||||
params.append(message_id)
|
||||
|
||||
|
||||
query = f"""
|
||||
UPDATE settlements
|
||||
SET {', '.join(updates)}
|
||||
SET {", ".join(updates)}
|
||||
WHERE message_id = ${param_count}
|
||||
"""
|
||||
|
||||
|
||||
await self.db.execute(query, params)
|
||||
|
||||
|
||||
async def get_settlement(self, message_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get settlement by message ID"""
|
||||
query = """
|
||||
SELECT * FROM settlements WHERE message_id = $1
|
||||
"""
|
||||
|
||||
|
||||
result = await self.db.fetchrow(query, message_id)
|
||||
|
||||
|
||||
if not result:
|
||||
return None
|
||||
|
||||
|
||||
# Convert to dict
|
||||
settlement = dict(result)
|
||||
|
||||
|
||||
# Parse JSON fields
|
||||
if settlement['proof_data']:
|
||||
settlement['proof_data'] = json.loads(settlement['proof_data'])
|
||||
|
||||
if settlement["proof_data"]:
|
||||
settlement["proof_data"] = json.loads(settlement["proof_data"])
|
||||
|
||||
return settlement
|
||||
|
||||
|
||||
async def get_settlements_by_job(self, job_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get all settlements for a job"""
|
||||
query = """
|
||||
@@ -132,65 +131,67 @@ class SettlementStorage:
|
||||
WHERE job_id = $1
|
||||
ORDER BY created_at DESC
|
||||
"""
|
||||
|
||||
|
||||
results = await self.db.fetch(query, job_id)
|
||||
|
||||
|
||||
settlements = []
|
||||
for result in results:
|
||||
settlement = dict(result)
|
||||
if settlement['proof_data']:
|
||||
settlement['proof_data'] = json.loads(settlement['proof_data'])
|
||||
if settlement["proof_data"]:
|
||||
settlement["proof_data"] = json.loads(settlement["proof_data"])
|
||||
settlements.append(settlement)
|
||||
|
||||
|
||||
return settlements
|
||||
|
||||
async def get_pending_settlements(self, bridge_name: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
|
||||
async def get_pending_settlements(
|
||||
self, bridge_name: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get all pending settlements"""
|
||||
query = """
|
||||
SELECT * FROM settlements
|
||||
WHERE status = 'pending' OR status = 'in_progress'
|
||||
"""
|
||||
params = []
|
||||
|
||||
|
||||
if bridge_name:
|
||||
query += " AND bridge_name = $1"
|
||||
params.append(bridge_name)
|
||||
|
||||
|
||||
query += " ORDER BY created_at ASC"
|
||||
|
||||
|
||||
results = await self.db.fetch(query, *params)
|
||||
|
||||
|
||||
settlements = []
|
||||
for result in results:
|
||||
settlement = dict(result)
|
||||
if settlement['proof_data']:
|
||||
settlement['proof_data'] = json.loads(settlement['proof_data'])
|
||||
if settlement["proof_data"]:
|
||||
settlement["proof_data"] = json.loads(settlement["proof_data"])
|
||||
settlements.append(settlement)
|
||||
|
||||
|
||||
return settlements
|
||||
|
||||
|
||||
async def get_settlement_stats(
|
||||
self,
|
||||
bridge_name: Optional[str] = None,
|
||||
time_range: Optional[int] = None # hours
|
||||
time_range: Optional[int] = None, # hours
|
||||
) -> Dict[str, Any]:
|
||||
"""Get settlement statistics"""
|
||||
conditions = []
|
||||
params = []
|
||||
param_count = 1
|
||||
|
||||
|
||||
if bridge_name:
|
||||
conditions.append(f"bridge_name = ${param_count}")
|
||||
params.append(bridge_name)
|
||||
param_count += 1
|
||||
|
||||
|
||||
if time_range:
|
||||
conditions.append(f"created_at > NOW() - INTERVAL '${param_count} hours'")
|
||||
params.append(time_range)
|
||||
param_count += 1
|
||||
|
||||
|
||||
where_clause = "WHERE " + " AND ".join(conditions) if conditions else ""
|
||||
|
||||
|
||||
query = f"""
|
||||
SELECT
|
||||
bridge_name,
|
||||
@@ -202,23 +203,27 @@ class SettlementStorage:
|
||||
{where_clause}
|
||||
GROUP BY bridge_name, status
|
||||
"""
|
||||
|
||||
|
||||
results = await self.db.fetch(query, *params)
|
||||
|
||||
|
||||
stats = {}
|
||||
for result in results:
|
||||
bridge = result['bridge_name']
|
||||
bridge = result["bridge_name"]
|
||||
if bridge not in stats:
|
||||
stats[bridge] = {}
|
||||
|
||||
stats[bridge][result['status']] = {
|
||||
'count': result['count'],
|
||||
'avg_amount': float(result['avg_amount']) if result['avg_amount'] else 0,
|
||||
'total_amount': float(result['total_amount']) if result['total_amount'] else 0
|
||||
|
||||
stats[bridge][result["status"]] = {
|
||||
"count": result["count"],
|
||||
"avg_amount": float(result["avg_amount"])
|
||||
if result["avg_amount"]
|
||||
else 0,
|
||||
"total_amount": float(result["total_amount"])
|
||||
if result["total_amount"]
|
||||
else 0,
|
||||
}
|
||||
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
async def cleanup_old_settlements(self, days: int = 30) -> int:
|
||||
"""Clean up old completed settlements"""
|
||||
query = """
|
||||
@@ -226,7 +231,7 @@ class SettlementStorage:
|
||||
WHERE status IN ('completed', 'failed')
|
||||
AND created_at < NOW() - INTERVAL $1 days
|
||||
"""
|
||||
|
||||
|
||||
result = await self.db.execute(query, days)
|
||||
return result.split()[-1] # Return number of deleted rows
|
||||
|
||||
@@ -234,134 +239,139 @@ class SettlementStorage:
|
||||
# In-memory implementation for testing
|
||||
class InMemorySettlementStorage(SettlementStorage):
|
||||
"""In-memory storage implementation for testing"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.settlements: Dict[str, Dict[str, Any]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def store_settlement(
|
||||
self,
|
||||
message_id: str,
|
||||
message: SettlementMessage,
|
||||
bridge_name: str,
|
||||
status: BridgeStatus
|
||||
status: BridgeStatus,
|
||||
) -> None:
|
||||
async with self._lock:
|
||||
self.settlements[message_id] = {
|
||||
'message_id': message_id,
|
||||
'job_id': message.job_id,
|
||||
'source_chain_id': message.source_chain_id,
|
||||
'target_chain_id': message.target_chain_id,
|
||||
'receipt_hash': message.receipt_hash,
|
||||
'proof_data': message.proof_data,
|
||||
'payment_amount': message.payment_amount,
|
||||
'payment_token': message.payment_token,
|
||||
'nonce': message.nonce,
|
||||
'signature': message.signature,
|
||||
'bridge_name': bridge_name,
|
||||
'status': status.value,
|
||||
'created_at': message.created_at or datetime.utcnow(),
|
||||
'updated_at': datetime.utcnow()
|
||||
"message_id": message_id,
|
||||
"job_id": message.job_id,
|
||||
"source_chain_id": message.source_chain_id,
|
||||
"target_chain_id": message.target_chain_id,
|
||||
"receipt_hash": message.receipt_hash,
|
||||
"proof_data": message.proof_data,
|
||||
"payment_amount": message.payment_amount,
|
||||
"payment_token": message.payment_token,
|
||||
"nonce": message.nonce,
|
||||
"signature": message.signature,
|
||||
"bridge_name": bridge_name,
|
||||
"status": status.value,
|
||||
"created_at": message.created_at or datetime.utcnow(),
|
||||
"updated_at": datetime.utcnow(),
|
||||
}
|
||||
|
||||
|
||||
async def update_settlement(
|
||||
self,
|
||||
message_id: str,
|
||||
status: Optional[BridgeStatus] = None,
|
||||
transaction_hash: Optional[str] = None,
|
||||
error_message: Optional[str] = None,
|
||||
completed_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None,
|
||||
) -> None:
|
||||
async with self._lock:
|
||||
if message_id not in self.settlements:
|
||||
return
|
||||
|
||||
|
||||
settlement = self.settlements[message_id]
|
||||
|
||||
|
||||
if status is not None:
|
||||
settlement['status'] = status.value
|
||||
settlement["status"] = status.value
|
||||
if transaction_hash is not None:
|
||||
settlement['transaction_hash'] = transaction_hash
|
||||
settlement["transaction_hash"] = transaction_hash
|
||||
if error_message is not None:
|
||||
settlement['error_message'] = error_message
|
||||
settlement["error_message"] = error_message
|
||||
if completed_at is not None:
|
||||
settlement['completed_at'] = completed_at
|
||||
|
||||
settlement['updated_at'] = datetime.utcnow()
|
||||
|
||||
settlement["completed_at"] = completed_at
|
||||
|
||||
settlement["updated_at"] = datetime.utcnow()
|
||||
|
||||
async def get_settlement(self, message_id: str) -> Optional[Dict[str, Any]]:
|
||||
async with self._lock:
|
||||
return self.settlements.get(message_id)
|
||||
|
||||
|
||||
async def get_settlements_by_job(self, job_id: str) -> List[Dict[str, Any]]:
|
||||
async with self._lock:
|
||||
return [
|
||||
s for s in self.settlements.values()
|
||||
if s['job_id'] == job_id
|
||||
]
|
||||
|
||||
async def get_pending_settlements(self, bridge_name: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
return [s for s in self.settlements.values() if s["job_id"] == job_id]
|
||||
|
||||
async def get_pending_settlements(
|
||||
self, bridge_name: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
async with self._lock:
|
||||
pending = [
|
||||
s for s in self.settlements.values()
|
||||
if s['status'] in ['pending', 'in_progress']
|
||||
s
|
||||
for s in self.settlements.values()
|
||||
if s["status"] in ["pending", "in_progress"]
|
||||
]
|
||||
|
||||
|
||||
if bridge_name:
|
||||
pending = [s for s in pending if s['bridge_name'] == bridge_name]
|
||||
|
||||
pending = [s for s in pending if s["bridge_name"] == bridge_name]
|
||||
|
||||
return pending
|
||||
|
||||
|
||||
async def get_settlement_stats(
|
||||
self,
|
||||
bridge_name: Optional[str] = None,
|
||||
time_range: Optional[int] = None
|
||||
self, bridge_name: Optional[str] = None, time_range: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
async with self._lock:
|
||||
stats = {}
|
||||
|
||||
|
||||
for settlement in self.settlements.values():
|
||||
if bridge_name and settlement['bridge_name'] != bridge_name:
|
||||
if bridge_name and settlement["bridge_name"] != bridge_name:
|
||||
continue
|
||||
|
||||
# TODO: Implement time range filtering
|
||||
|
||||
bridge = settlement['bridge_name']
|
||||
|
||||
# Time range filtering
|
||||
if time_range is not None:
|
||||
cutoff = datetime.utcnow() - timedelta(hours=time_range)
|
||||
if settlement["created_at"] < cutoff:
|
||||
continue
|
||||
|
||||
bridge = settlement["bridge_name"]
|
||||
if bridge not in stats:
|
||||
stats[bridge] = {}
|
||||
|
||||
status = settlement['status']
|
||||
|
||||
status = settlement["status"]
|
||||
if status not in stats[bridge]:
|
||||
stats[bridge][status] = {
|
||||
'count': 0,
|
||||
'avg_amount': 0,
|
||||
'total_amount': 0
|
||||
"count": 0,
|
||||
"avg_amount": 0,
|
||||
"total_amount": 0,
|
||||
}
|
||||
|
||||
stats[bridge][status]['count'] += 1
|
||||
stats[bridge][status]['total_amount'] += settlement['payment_amount']
|
||||
|
||||
|
||||
stats[bridge][status]["count"] += 1
|
||||
stats[bridge][status]["total_amount"] += settlement["payment_amount"]
|
||||
|
||||
# Calculate averages
|
||||
for bridge_data in stats.values():
|
||||
for status_data in bridge_data.values():
|
||||
if status_data['count'] > 0:
|
||||
status_data['avg_amount'] = status_data['total_amount'] / status_data['count']
|
||||
|
||||
if status_data["count"] > 0:
|
||||
status_data["avg_amount"] = (
|
||||
status_data["total_amount"] / status_data["count"]
|
||||
)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
async def cleanup_old_settlements(self, days: int = 30) -> int:
|
||||
async with self._lock:
|
||||
cutoff = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
|
||||
to_delete = [
|
||||
msg_id for msg_id, settlement in self.settlements.items()
|
||||
msg_id
|
||||
for msg_id, settlement in self.settlements.items()
|
||||
if (
|
||||
settlement['status'] in ['completed', 'failed'] and
|
||||
settlement['created_at'] < cutoff
|
||||
settlement["status"] in ["completed", "failed"]
|
||||
and settlement["created_at"] < cutoff
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
for msg_id in to_delete:
|
||||
del self.settlements[msg_id]
|
||||
|
||||
|
||||
return len(to_delete)
|
||||
|
||||
@@ -4,115 +4,137 @@ Unified configuration for AITBC Coordinator API
|
||||
Provides environment-based adapter selection and consolidated settings.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from typing import List, Optional
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
|
||||
class DatabaseConfig(BaseSettings):
|
||||
"""Database configuration with adapter selection."""
|
||||
|
||||
adapter: str = "sqlite" # sqlite, postgresql
|
||||
url: Optional[str] = None
|
||||
pool_size: int = 10
|
||||
max_overflow: int = 20
|
||||
pool_pre_ping: bool = True
|
||||
|
||||
|
||||
@property
|
||||
def effective_url(self) -> str:
|
||||
"""Get the effective database URL."""
|
||||
if self.url:
|
||||
return self.url
|
||||
|
||||
|
||||
# Default SQLite path
|
||||
if self.adapter == "sqlite":
|
||||
return "sqlite:///./coordinator.db"
|
||||
|
||||
|
||||
# Default PostgreSQL connection string
|
||||
return f"{self.adapter}://localhost:5432/coordinator"
|
||||
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
extra="allow"
|
||||
env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="allow"
|
||||
)
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Unified application settings with environment-based configuration."""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
extra="allow"
|
||||
env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="allow"
|
||||
)
|
||||
|
||||
# Environment
|
||||
app_env: str = "dev"
|
||||
app_host: str = "127.0.0.1"
|
||||
app_port: int = 8011
|
||||
|
||||
audit_log_dir: str = "/var/log/aitbc/audit"
|
||||
|
||||
# Database
|
||||
database: DatabaseConfig = DatabaseConfig()
|
||||
|
||||
|
||||
# API Keys
|
||||
client_api_keys: List[str] = []
|
||||
miner_api_keys: List[str] = []
|
||||
admin_api_keys: List[str] = []
|
||||
|
||||
|
||||
# Security
|
||||
hmac_secret: Optional[str] = None
|
||||
jwt_secret: Optional[str] = None
|
||||
jwt_algorithm: str = "HS256"
|
||||
jwt_expiration_hours: int = 24
|
||||
|
||||
|
||||
# CORS
|
||||
allow_origins: List[str] = [
|
||||
"http://localhost:3000",
|
||||
"http://localhost:8080",
|
||||
"http://localhost:8000",
|
||||
"http://localhost:8011"
|
||||
"http://localhost:8011",
|
||||
]
|
||||
|
||||
|
||||
# Job Configuration
|
||||
job_ttl_seconds: int = 900
|
||||
heartbeat_interval_seconds: int = 10
|
||||
heartbeat_timeout_seconds: int = 30
|
||||
|
||||
|
||||
# Rate Limiting
|
||||
rate_limit_requests: int = 60
|
||||
rate_limit_window_seconds: int = 60
|
||||
|
||||
|
||||
# Receipt Signing
|
||||
receipt_signing_key_hex: Optional[str] = None
|
||||
receipt_attestation_key_hex: Optional[str] = None
|
||||
|
||||
|
||||
# Logging
|
||||
log_level: str = "INFO"
|
||||
log_format: str = "json" # json or text
|
||||
|
||||
|
||||
# Mempool
|
||||
mempool_backend: str = "database" # database, memory
|
||||
|
||||
|
||||
# Blockchain RPC
|
||||
blockchain_rpc_url: str = "http://localhost:8082"
|
||||
|
||||
# Test Configuration
|
||||
test_mode: bool = False
|
||||
test_database_url: Optional[str] = None
|
||||
|
||||
def validate_secrets(self) -> None:
|
||||
"""Validate that all required secrets are provided."""
|
||||
if self.app_env == "production":
|
||||
if not self.jwt_secret:
|
||||
raise ValueError("JWT_SECRET environment variable is required in production")
|
||||
raise ValueError(
|
||||
"JWT_SECRET environment variable is required in production"
|
||||
)
|
||||
if self.jwt_secret == "change-me-in-production":
|
||||
raise ValueError("JWT_SECRET must be changed from default value")
|
||||
|
||||
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
"""Get the database URL (backward compatibility)."""
|
||||
# Use test database if in test mode and test_database_url is set
|
||||
if self.test_mode and self.test_database_url:
|
||||
return self.test_database_url
|
||||
if self.database.url:
|
||||
return self.database.url
|
||||
# Default SQLite path for backward compatibility
|
||||
return f"sqlite:///./aitbc_coordinator.db"
|
||||
|
||||
@database_url.setter
|
||||
def database_url(self, value: str):
|
||||
"""Allow setting database URL for tests"""
|
||||
if not self.test_mode:
|
||||
raise RuntimeError("Cannot set database_url outside of test mode")
|
||||
self.test_database_url = value
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
# Enable test mode if environment variable is set
|
||||
if os.getenv("TEST_MODE") == "true":
|
||||
settings.test_mode = True
|
||||
if os.getenv("TEST_DATABASE_URL"):
|
||||
settings.test_database_url = os.getenv("TEST_DATABASE_URL")
|
||||
|
||||
# Validate secrets on import
|
||||
settings.validate_secrets()
|
||||
|
||||
@@ -52,6 +52,7 @@ from ..schemas import (
|
||||
from ..domain import (
|
||||
Job,
|
||||
Miner,
|
||||
JobReceipt,
|
||||
MarketplaceOffer,
|
||||
MarketplaceBid,
|
||||
User,
|
||||
@@ -93,6 +94,7 @@ __all__ = [
|
||||
"Constraints",
|
||||
"Job",
|
||||
"Miner",
|
||||
"JobReceipt",
|
||||
"MarketplaceOffer",
|
||||
"MarketplaceBid",
|
||||
"ServiceType",
|
||||
|
||||
@@ -22,6 +22,7 @@ logger = get_logger(__name__)
|
||||
@dataclass
|
||||
class AuditEvent:
|
||||
"""Structured audit event"""
|
||||
|
||||
event_id: str
|
||||
timestamp: datetime
|
||||
event_type: str
|
||||
@@ -39,27 +40,38 @@ class AuditEvent:
|
||||
|
||||
class AuditLogger:
|
||||
"""Tamper-evident audit logging for privacy compliance"""
|
||||
|
||||
def __init__(self, log_dir: str = "/var/log/aitbc/audit"):
|
||||
self.log_dir = Path(log_dir)
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def __init__(self, log_dir: str = None):
|
||||
# Use test-specific directory if in test environment
|
||||
if os.getenv("PYTEST_CURRENT_TEST"):
|
||||
# Use project logs directory for tests
|
||||
# Navigate from coordinator-api/src/app/services/audit_logging.py to project root
|
||||
# Path: coordinator-api/src/app/services/audit_logging.py -> apps/coordinator-api/src -> apps/coordinator-api -> apps -> project_root
|
||||
project_root = Path(__file__).resolve().parent.parent.parent.parent.parent.parent
|
||||
test_log_dir = project_root / "logs" / "audit"
|
||||
log_path = log_dir or str(test_log_dir)
|
||||
else:
|
||||
log_path = log_dir or settings.audit_log_dir
|
||||
|
||||
self.log_dir = Path(log_path)
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Current log file
|
||||
self.current_file = None
|
||||
self.current_hash = None
|
||||
|
||||
|
||||
# Async writer task
|
||||
self.write_queue = asyncio.Queue(maxsize=10000)
|
||||
self.writer_task = None
|
||||
|
||||
|
||||
# Chain of hashes for integrity
|
||||
self.chain_hash = self._load_chain_hash()
|
||||
|
||||
|
||||
async def start(self):
|
||||
"""Start the background writer task"""
|
||||
if self.writer_task is None:
|
||||
self.writer_task = asyncio.create_task(self._background_writer())
|
||||
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the background writer task"""
|
||||
if self.writer_task:
|
||||
@@ -69,7 +81,7 @@ class AuditLogger:
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self.writer_task = None
|
||||
|
||||
|
||||
async def log_access(
|
||||
self,
|
||||
participant_id: str,
|
||||
@@ -79,7 +91,7 @@ class AuditLogger:
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
authorization: Optional[str] = None
|
||||
authorization: Optional[str] = None,
|
||||
):
|
||||
"""Log access to confidential data"""
|
||||
event = AuditEvent(
|
||||
@@ -95,22 +107,22 @@ class AuditLogger:
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
authorization=authorization,
|
||||
signature=None
|
||||
signature=None,
|
||||
)
|
||||
|
||||
|
||||
# Add signature for tamper-evidence
|
||||
event.signature = self._sign_event(event)
|
||||
|
||||
|
||||
# Queue for writing
|
||||
await self.write_queue.put(event)
|
||||
|
||||
|
||||
async def log_key_operation(
|
||||
self,
|
||||
participant_id: str,
|
||||
operation: str,
|
||||
key_version: int,
|
||||
outcome: str,
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""Log key management operations"""
|
||||
event = AuditEvent(
|
||||
@@ -126,19 +138,19 @@ class AuditLogger:
|
||||
ip_address=None,
|
||||
user_agent=None,
|
||||
authorization=None,
|
||||
signature=None
|
||||
signature=None,
|
||||
)
|
||||
|
||||
|
||||
event.signature = self._sign_event(event)
|
||||
await self.write_queue.put(event)
|
||||
|
||||
|
||||
async def log_policy_change(
|
||||
self,
|
||||
participant_id: str,
|
||||
policy_id: str,
|
||||
change_type: str,
|
||||
outcome: str,
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""Log access policy changes"""
|
||||
event = AuditEvent(
|
||||
@@ -154,12 +166,12 @@ class AuditLogger:
|
||||
ip_address=None,
|
||||
user_agent=None,
|
||||
authorization=None,
|
||||
signature=None
|
||||
signature=None,
|
||||
)
|
||||
|
||||
|
||||
event.signature = self._sign_event(event)
|
||||
await self.write_queue.put(event)
|
||||
|
||||
|
||||
def query_logs(
|
||||
self,
|
||||
participant_id: Optional[str] = None,
|
||||
@@ -167,14 +179,14 @@ class AuditLogger:
|
||||
event_type: Optional[str] = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
limit: int = 100
|
||||
limit: int = 100,
|
||||
) -> List[AuditEvent]:
|
||||
"""Query audit logs"""
|
||||
results = []
|
||||
|
||||
|
||||
# Get list of log files to search
|
||||
log_files = self._get_log_files(start_time, end_time)
|
||||
|
||||
|
||||
for log_file in log_files:
|
||||
try:
|
||||
# Read and decompress if needed
|
||||
@@ -182,7 +194,14 @@ class AuditLogger:
|
||||
with gzip.open(log_file, "rt") as f:
|
||||
for line in f:
|
||||
event = self._parse_log_line(line.strip())
|
||||
if self._matches_query(event, participant_id, transaction_id, event_type, start_time, end_time):
|
||||
if self._matches_query(
|
||||
event,
|
||||
participant_id,
|
||||
transaction_id,
|
||||
event_type,
|
||||
start_time,
|
||||
end_time,
|
||||
):
|
||||
results.append(event)
|
||||
if len(results) >= limit:
|
||||
return results
|
||||
@@ -190,75 +209,79 @@ class AuditLogger:
|
||||
with open(log_file, "r") as f:
|
||||
for line in f:
|
||||
event = self._parse_log_line(line.strip())
|
||||
if self._matches_query(event, participant_id, transaction_id, event_type, start_time, end_time):
|
||||
if self._matches_query(
|
||||
event,
|
||||
participant_id,
|
||||
transaction_id,
|
||||
event_type,
|
||||
start_time,
|
||||
end_time,
|
||||
):
|
||||
results.append(event)
|
||||
if len(results) >= limit:
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read log file {log_file}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# Sort by timestamp (newest first)
|
||||
results.sort(key=lambda x: x.timestamp, reverse=True)
|
||||
|
||||
|
||||
return results[:limit]
|
||||
|
||||
|
||||
def verify_integrity(self, start_date: Optional[datetime] = None) -> Dict[str, Any]:
|
||||
"""Verify integrity of audit logs"""
|
||||
if start_date is None:
|
||||
start_date = datetime.utcnow() - timedelta(days=30)
|
||||
|
||||
|
||||
results = {
|
||||
"verified_files": 0,
|
||||
"total_files": 0,
|
||||
"integrity_violations": [],
|
||||
"chain_valid": True
|
||||
"chain_valid": True,
|
||||
}
|
||||
|
||||
|
||||
log_files = self._get_log_files(start_date)
|
||||
|
||||
|
||||
for log_file in log_files:
|
||||
results["total_files"] += 1
|
||||
|
||||
|
||||
try:
|
||||
# Verify file hash
|
||||
file_hash = self._calculate_file_hash(log_file)
|
||||
stored_hash = self._get_stored_hash(log_file)
|
||||
|
||||
|
||||
if file_hash != stored_hash:
|
||||
results["integrity_violations"].append({
|
||||
"file": str(log_file),
|
||||
"expected": stored_hash,
|
||||
"actual": file_hash
|
||||
})
|
||||
results["integrity_violations"].append(
|
||||
{
|
||||
"file": str(log_file),
|
||||
"expected": stored_hash,
|
||||
"actual": file_hash,
|
||||
}
|
||||
)
|
||||
results["chain_valid"] = False
|
||||
else:
|
||||
results["verified_files"] += 1
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to verify {log_file}: {e}")
|
||||
results["integrity_violations"].append({
|
||||
"file": str(log_file),
|
||||
"error": str(e)
|
||||
})
|
||||
results["integrity_violations"].append(
|
||||
{"file": str(log_file), "error": str(e)}
|
||||
)
|
||||
results["chain_valid"] = False
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def export_logs(
|
||||
self,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
format: str = "json",
|
||||
include_signatures: bool = True
|
||||
include_signatures: bool = True,
|
||||
) -> str:
|
||||
"""Export audit logs for compliance reporting"""
|
||||
events = self.query_logs(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
limit=10000
|
||||
)
|
||||
|
||||
events = self.query_logs(start_time=start_time, end_time=end_time, limit=10000)
|
||||
|
||||
if format == "json":
|
||||
export_data = {
|
||||
"export_metadata": {
|
||||
@@ -266,39 +289,46 @@ class AuditLogger:
|
||||
"end_time": end_time.isoformat(),
|
||||
"event_count": len(events),
|
||||
"exported_at": datetime.utcnow().isoformat(),
|
||||
"include_signatures": include_signatures
|
||||
"include_signatures": include_signatures,
|
||||
},
|
||||
"events": []
|
||||
"events": [],
|
||||
}
|
||||
|
||||
|
||||
for event in events:
|
||||
event_dict = asdict(event)
|
||||
event_dict["timestamp"] = event.timestamp.isoformat()
|
||||
|
||||
|
||||
if not include_signatures:
|
||||
event_dict.pop("signature", None)
|
||||
|
||||
|
||||
export_data["events"].append(event_dict)
|
||||
|
||||
|
||||
return json.dumps(export_data, indent=2)
|
||||
|
||||
|
||||
elif format == "csv":
|
||||
import csv
|
||||
import io
|
||||
|
||||
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
|
||||
|
||||
# Header
|
||||
header = [
|
||||
"event_id", "timestamp", "event_type", "participant_id",
|
||||
"transaction_id", "action", "resource", "outcome",
|
||||
"ip_address", "user_agent"
|
||||
"event_id",
|
||||
"timestamp",
|
||||
"event_type",
|
||||
"participant_id",
|
||||
"transaction_id",
|
||||
"action",
|
||||
"resource",
|
||||
"outcome",
|
||||
"ip_address",
|
||||
"user_agent",
|
||||
]
|
||||
if include_signatures:
|
||||
header.append("signature")
|
||||
writer.writerow(header)
|
||||
|
||||
|
||||
# Events
|
||||
for event in events:
|
||||
row = [
|
||||
@@ -311,17 +341,17 @@ class AuditLogger:
|
||||
event.resource,
|
||||
event.outcome,
|
||||
event.ip_address,
|
||||
event.user_agent
|
||||
event.user_agent,
|
||||
]
|
||||
if include_signatures:
|
||||
row.append(event.signature)
|
||||
writer.writerow(row)
|
||||
|
||||
|
||||
return output.getvalue()
|
||||
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported export format: {format}")
|
||||
|
||||
|
||||
async def _background_writer(self):
|
||||
"""Background task for writing audit events"""
|
||||
while True:
|
||||
@@ -332,51 +362,50 @@ class AuditLogger:
|
||||
try:
|
||||
# Use asyncio.wait_for for timeout
|
||||
event = await asyncio.wait_for(
|
||||
self.write_queue.get(),
|
||||
timeout=1.0
|
||||
self.write_queue.get(), timeout=1.0
|
||||
)
|
||||
events.append(event)
|
||||
except asyncio.TimeoutError:
|
||||
if events:
|
||||
break
|
||||
continue
|
||||
|
||||
|
||||
# Write events
|
||||
if events:
|
||||
self._write_events(events)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Background writer error: {e}")
|
||||
# Brief pause to avoid error loops
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
def _write_events(self, events: List[AuditEvent]):
|
||||
"""Write events to current log file"""
|
||||
try:
|
||||
self._rotate_if_needed()
|
||||
|
||||
|
||||
with open(self.current_file, "a") as f:
|
||||
for event in events:
|
||||
# Convert to JSON line
|
||||
event_dict = asdict(event)
|
||||
event_dict["timestamp"] = event.timestamp.isoformat()
|
||||
|
||||
|
||||
# Write with signature
|
||||
line = json.dumps(event_dict, separators=(",", ":")) + "\n"
|
||||
f.write(line)
|
||||
f.flush()
|
||||
|
||||
|
||||
# Update chain hash
|
||||
self._update_chain_hash(events[-1])
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write audit events: {e}")
|
||||
|
||||
|
||||
def _rotate_if_needed(self):
|
||||
"""Rotate log file if needed"""
|
||||
now = datetime.utcnow()
|
||||
today = now.date()
|
||||
|
||||
|
||||
# Check if we need a new file
|
||||
if self.current_file is None:
|
||||
self._new_log_file(today)
|
||||
@@ -384,31 +413,31 @@ class AuditLogger:
|
||||
file_date = datetime.fromisoformat(
|
||||
self.current_file.stem.split("_")[1]
|
||||
).date()
|
||||
|
||||
|
||||
if file_date != today:
|
||||
self._new_log_file(today)
|
||||
|
||||
|
||||
def _new_log_file(self, date):
|
||||
"""Create new log file for date"""
|
||||
filename = f"audit_{date.isoformat()}.log"
|
||||
self.current_file = self.log_dir / filename
|
||||
|
||||
|
||||
# Write header with metadata
|
||||
if not self.current_file.exists():
|
||||
header = {
|
||||
"created_at": datetime.utcnow().isoformat(),
|
||||
"version": "1.0",
|
||||
"format": "jsonl",
|
||||
"previous_hash": self.chain_hash
|
||||
"previous_hash": self.chain_hash,
|
||||
}
|
||||
|
||||
|
||||
with open(self.current_file, "w") as f:
|
||||
f.write(f"# {json.dumps(header)}\n")
|
||||
|
||||
|
||||
def _generate_event_id(self) -> str:
|
||||
"""Generate unique event ID"""
|
||||
return f"evt_{datetime.utcnow().timestamp()}_{os.urandom(4).hex()}"
|
||||
|
||||
|
||||
def _sign_event(self, event: AuditEvent) -> str:
|
||||
"""Sign event for tamper-evidence"""
|
||||
# Create canonical representation
|
||||
@@ -417,24 +446,24 @@ class AuditLogger:
|
||||
"timestamp": event.timestamp.isoformat(),
|
||||
"participant_id": event.participant_id,
|
||||
"action": event.action,
|
||||
"outcome": event.outcome
|
||||
"outcome": event.outcome,
|
||||
}
|
||||
|
||||
|
||||
# Hash with previous chain hash
|
||||
data = json.dumps(event_data, separators=(",", ":"), sort_keys=True)
|
||||
combined = f"{self.chain_hash}:{data}".encode()
|
||||
|
||||
|
||||
return hashlib.sha256(combined).hexdigest()
|
||||
|
||||
|
||||
def _update_chain_hash(self, last_event: AuditEvent):
|
||||
"""Update chain hash with new event"""
|
||||
self.chain_hash = last_event.signature or self.chain_hash
|
||||
|
||||
|
||||
# Store chain hash for integrity checking
|
||||
chain_file = self.log_dir / "chain.hash"
|
||||
with open(chain_file, "w") as f:
|
||||
f.write(self.chain_hash)
|
||||
|
||||
|
||||
def _load_chain_hash(self) -> str:
|
||||
"""Load previous chain hash"""
|
||||
chain_file = self.log_dir / "chain.hash"
|
||||
@@ -442,35 +471,38 @@ class AuditLogger:
|
||||
with open(chain_file, "r") as f:
|
||||
return f.read().strip()
|
||||
return "0" * 64 # Initial hash
|
||||
|
||||
def _get_log_files(self, start_time: Optional[datetime], end_time: Optional[datetime]) -> List[Path]:
|
||||
|
||||
def _get_log_files(
|
||||
self, start_time: Optional[datetime], end_time: Optional[datetime]
|
||||
) -> List[Path]:
|
||||
"""Get list of log files to search"""
|
||||
files = []
|
||||
|
||||
|
||||
for file in self.log_dir.glob("audit_*.log*"):
|
||||
try:
|
||||
# Extract date from filename
|
||||
date_str = file.stem.split("_")[1]
|
||||
file_date = datetime.fromisoformat(date_str).date()
|
||||
|
||||
|
||||
# Check if file is in range
|
||||
file_start = datetime.combine(file_date, datetime.min.time())
|
||||
file_end = file_start + timedelta(days=1)
|
||||
|
||||
if (not start_time or file_end >= start_time) and \
|
||||
(not end_time or file_start <= end_time):
|
||||
|
||||
if (not start_time or file_end >= start_time) and (
|
||||
not end_time or file_start <= end_time
|
||||
):
|
||||
files.append(file)
|
||||
|
||||
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
return sorted(files)
|
||||
|
||||
|
||||
def _parse_log_line(self, line: str) -> Optional[AuditEvent]:
|
||||
"""Parse log line into event"""
|
||||
if line.startswith("#"):
|
||||
return None # Skip header
|
||||
|
||||
|
||||
try:
|
||||
data = json.loads(line)
|
||||
data["timestamp"] = datetime.fromisoformat(data["timestamp"])
|
||||
@@ -478,7 +510,7 @@ class AuditLogger:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse log line: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _matches_query(
|
||||
self,
|
||||
event: Optional[AuditEvent],
|
||||
@@ -486,39 +518,39 @@ class AuditLogger:
|
||||
transaction_id: Optional[str],
|
||||
event_type: Optional[str],
|
||||
start_time: Optional[datetime],
|
||||
end_time: Optional[datetime]
|
||||
end_time: Optional[datetime],
|
||||
) -> bool:
|
||||
"""Check if event matches query criteria"""
|
||||
if not event:
|
||||
return False
|
||||
|
||||
|
||||
if participant_id and event.participant_id != participant_id:
|
||||
return False
|
||||
|
||||
|
||||
if transaction_id and event.transaction_id != transaction_id:
|
||||
return False
|
||||
|
||||
|
||||
if event_type and event.event_type != event_type:
|
||||
return False
|
||||
|
||||
|
||||
if start_time and event.timestamp < start_time:
|
||||
return False
|
||||
|
||||
|
||||
if end_time and event.timestamp > end_time:
|
||||
return False
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _calculate_file_hash(self, file_path: Path) -> str:
|
||||
"""Calculate SHA-256 hash of file"""
|
||||
hash_sha256 = hashlib.sha256()
|
||||
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(4096), b""):
|
||||
hash_sha256.update(chunk)
|
||||
|
||||
|
||||
return hash_sha256.hexdigest()
|
||||
|
||||
|
||||
def _get_stored_hash(self, file_path: Path) -> str:
|
||||
"""Get stored hash for file"""
|
||||
hash_file = file_path.with_suffix(".hash")
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
Confidential Transaction Service - Wrapper for existing confidential functionality
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
from ..services.encryption import EncryptionService
|
||||
from ..services.key_management import KeyManager
|
||||
from ..models.confidential import ConfidentialTransaction, ViewingKey
|
||||
|
||||
|
||||
class ConfidentialTransactionService:
|
||||
"""Service for handling confidential transactions using existing encryption and key management"""
|
||||
|
||||
def __init__(self):
|
||||
self.encryption_service = EncryptionService()
|
||||
self.key_manager = KeyManager()
|
||||
|
||||
def create_confidential_transaction(
|
||||
self,
|
||||
sender: str,
|
||||
recipient: str,
|
||||
amount: int,
|
||||
viewing_key: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> ConfidentialTransaction:
|
||||
"""Create a new confidential transaction"""
|
||||
# Generate viewing key if not provided
|
||||
if not viewing_key:
|
||||
viewing_key = self.key_manager.generate_viewing_key()
|
||||
|
||||
# Encrypt transaction data
|
||||
encrypted_data = self.encryption_service.encrypt_transaction_data({
|
||||
"sender": sender,
|
||||
"recipient": recipient,
|
||||
"amount": amount,
|
||||
"metadata": metadata or {}
|
||||
})
|
||||
|
||||
return ConfidentialTransaction(
|
||||
sender=sender,
|
||||
recipient=recipient,
|
||||
encrypted_payload=encrypted_data,
|
||||
viewing_key=viewing_key,
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
def decrypt_transaction(
|
||||
self,
|
||||
transaction: ConfidentialTransaction,
|
||||
viewing_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Decrypt a confidential transaction using viewing key"""
|
||||
return self.encryption_service.decrypt_transaction_data(
|
||||
transaction.encrypted_payload,
|
||||
viewing_key
|
||||
)
|
||||
|
||||
def verify_transaction_access(
|
||||
self,
|
||||
transaction: ConfidentialTransaction,
|
||||
requester: str
|
||||
) -> bool:
|
||||
"""Verify if requester has access to view transaction"""
|
||||
return requester in [transaction.sender, transaction.recipient]
|
||||
|
||||
def get_transaction_summary(
|
||||
self,
|
||||
transaction: ConfidentialTransaction,
|
||||
viewer: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Get transaction summary based on viewer permissions"""
|
||||
if self.verify_transaction_access(transaction, viewer):
|
||||
return self.decrypt_transaction(transaction, transaction.viewing_key)
|
||||
else:
|
||||
return {
|
||||
"transaction_id": transaction.id,
|
||||
"encrypted": True,
|
||||
"accessible": False
|
||||
}
|
||||
@@ -11,10 +11,18 @@ from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
|
||||
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat, PrivateFormat, NoEncryption
|
||||
from cryptography.hazmat.primitives.asymmetric.x25519 import (
|
||||
X25519PrivateKey,
|
||||
X25519PublicKey,
|
||||
)
|
||||
from cryptography.hazmat.primitives.serialization import (
|
||||
Encoding,
|
||||
PublicFormat,
|
||||
PrivateFormat,
|
||||
NoEncryption,
|
||||
)
|
||||
|
||||
from ..schemas import ConfidentialTransaction, AccessLog
|
||||
from ..schemas import ConfidentialTransaction, ConfidentialAccessLog
|
||||
from ..config import settings
|
||||
from ..logging import get_logger
|
||||
|
||||
@@ -23,21 +31,21 @@ logger = get_logger(__name__)
|
||||
|
||||
class EncryptedData:
|
||||
"""Container for encrypted data and keys"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ciphertext: bytes,
|
||||
encrypted_keys: Dict[str, bytes],
|
||||
algorithm: str = "AES-256-GCM+X25519",
|
||||
nonce: Optional[bytes] = None,
|
||||
tag: Optional[bytes] = None
|
||||
tag: Optional[bytes] = None,
|
||||
):
|
||||
self.ciphertext = ciphertext
|
||||
self.encrypted_keys = encrypted_keys
|
||||
self.algorithm = algorithm
|
||||
self.nonce = nonce
|
||||
self.tag = tag
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for storage"""
|
||||
return {
|
||||
@@ -48,9 +56,9 @@ class EncryptedData:
|
||||
},
|
||||
"algorithm": self.algorithm,
|
||||
"nonce": base64.b64encode(self.nonce).decode() if self.nonce else None,
|
||||
"tag": base64.b64encode(self.tag).decode() if self.tag else None
|
||||
"tag": base64.b64encode(self.tag).decode() if self.tag else None,
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "EncryptedData":
|
||||
"""Create from dictionary"""
|
||||
@@ -62,31 +70,28 @@ class EncryptedData:
|
||||
},
|
||||
algorithm=data["algorithm"],
|
||||
nonce=base64.b64decode(data["nonce"]) if data.get("nonce") else None,
|
||||
tag=base64.b64decode(data["tag"]) if data.get("tag") else None
|
||||
tag=base64.b64decode(data["tag"]) if data.get("tag") else None,
|
||||
)
|
||||
|
||||
|
||||
class EncryptionService:
|
||||
"""Service for encrypting/decrypting confidential transaction data"""
|
||||
|
||||
|
||||
def __init__(self, key_manager: "KeyManager"):
|
||||
self.key_manager = key_manager
|
||||
self.backend = default_backend()
|
||||
self.algorithm = "AES-256-GCM+X25519"
|
||||
|
||||
|
||||
def encrypt(
|
||||
self,
|
||||
data: Dict[str, Any],
|
||||
participants: List[str],
|
||||
include_audit: bool = True
|
||||
self, data: Dict[str, Any], participants: List[str], include_audit: bool = True
|
||||
) -> EncryptedData:
|
||||
"""Encrypt data for multiple participants
|
||||
|
||||
|
||||
Args:
|
||||
data: Data to encrypt
|
||||
participants: List of participant IDs who can decrypt
|
||||
include_audit: Whether to include audit escrow key
|
||||
|
||||
|
||||
Returns:
|
||||
EncryptedData container with ciphertext and encrypted keys
|
||||
"""
|
||||
@@ -94,16 +99,16 @@ class EncryptionService:
|
||||
# Generate random DEK (Data Encryption Key)
|
||||
dek = os.urandom(32) # 256-bit key for AES-256
|
||||
nonce = os.urandom(12) # 96-bit nonce for GCM
|
||||
|
||||
|
||||
# Serialize and encrypt data
|
||||
plaintext = json.dumps(data, separators=(",", ":")).encode()
|
||||
aesgcm = AESGCM(dek)
|
||||
ciphertext = aesgcm.encrypt(nonce, plaintext, None)
|
||||
|
||||
|
||||
# Extract tag (included in ciphertext for GCM)
|
||||
tag = ciphertext[-16:]
|
||||
actual_ciphertext = ciphertext[:-16]
|
||||
|
||||
|
||||
# Encrypt DEK for each participant
|
||||
encrypted_keys = {}
|
||||
for participant in participants:
|
||||
@@ -112,9 +117,11 @@ class EncryptionService:
|
||||
encrypted_dek = self._encrypt_dek(dek, public_key)
|
||||
encrypted_keys[participant] = encrypted_dek
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to encrypt DEK for participant {participant}: {e}")
|
||||
logger.error(
|
||||
f"Failed to encrypt DEK for participant {participant}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
|
||||
# Add audit escrow if requested
|
||||
if include_audit:
|
||||
try:
|
||||
@@ -123,67 +130,67 @@ class EncryptionService:
|
||||
encrypted_keys["audit"] = encrypted_dek
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to encrypt DEK for audit: {e}")
|
||||
|
||||
|
||||
return EncryptedData(
|
||||
ciphertext=actual_ciphertext,
|
||||
encrypted_keys=encrypted_keys,
|
||||
algorithm=self.algorithm,
|
||||
nonce=nonce,
|
||||
tag=tag
|
||||
tag=tag,
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Encryption failed: {e}")
|
||||
raise EncryptionError(f"Failed to encrypt data: {e}")
|
||||
|
||||
|
||||
def decrypt(
|
||||
self,
|
||||
encrypted_data: EncryptedData,
|
||||
participant_id: str,
|
||||
purpose: str = "access"
|
||||
purpose: str = "access",
|
||||
) -> Dict[str, Any]:
|
||||
"""Decrypt data for a specific participant
|
||||
|
||||
|
||||
Args:
|
||||
encrypted_data: The encrypted data container
|
||||
participant_id: ID of the participant requesting decryption
|
||||
purpose: Purpose of decryption for audit logging
|
||||
|
||||
|
||||
Returns:
|
||||
Decrypted data as dictionary
|
||||
"""
|
||||
try:
|
||||
# Get participant's private key
|
||||
private_key = self.key_manager.get_private_key(participant_id)
|
||||
|
||||
|
||||
# Get encrypted DEK for participant
|
||||
if participant_id not in encrypted_data.encrypted_keys:
|
||||
raise AccessDeniedError(f"Participant {participant_id} not authorized")
|
||||
|
||||
|
||||
encrypted_dek = encrypted_data.encrypted_keys[participant_id]
|
||||
|
||||
|
||||
# Decrypt DEK
|
||||
dek = self._decrypt_dek(encrypted_dek, private_key)
|
||||
|
||||
|
||||
# Reconstruct ciphertext with tag
|
||||
full_ciphertext = encrypted_data.ciphertext + encrypted_data.tag
|
||||
|
||||
|
||||
# Decrypt data
|
||||
aesgcm = AESGCM(dek)
|
||||
plaintext = aesgcm.decrypt(encrypted_data.nonce, full_ciphertext, None)
|
||||
|
||||
|
||||
data = json.loads(plaintext.decode())
|
||||
|
||||
|
||||
# Log access
|
||||
self._log_access(
|
||||
transaction_id=None, # Will be set by caller
|
||||
participant_id=participant_id,
|
||||
purpose=purpose,
|
||||
success=True
|
||||
success=True,
|
||||
)
|
||||
|
||||
|
||||
return data
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Decryption failed for participant {participant_id}: {e}")
|
||||
self._log_access(
|
||||
@@ -191,23 +198,23 @@ class EncryptionService:
|
||||
participant_id=participant_id,
|
||||
purpose=purpose,
|
||||
success=False,
|
||||
error=str(e)
|
||||
error=str(e),
|
||||
)
|
||||
raise DecryptionError(f"Failed to decrypt data: {e}")
|
||||
|
||||
|
||||
def audit_decrypt(
|
||||
self,
|
||||
encrypted_data: EncryptedData,
|
||||
audit_authorization: str,
|
||||
purpose: str = "audit"
|
||||
purpose: str = "audit",
|
||||
) -> Dict[str, Any]:
|
||||
"""Decrypt data for audit purposes
|
||||
|
||||
|
||||
Args:
|
||||
encrypted_data: The encrypted data container
|
||||
audit_authorization: Authorization token for audit access
|
||||
purpose: Purpose of decryption
|
||||
|
||||
|
||||
Returns:
|
||||
Decrypted data as dictionary
|
||||
"""
|
||||
@@ -215,97 +222,101 @@ class EncryptionService:
|
||||
# Verify audit authorization
|
||||
if not self.key_manager.verify_audit_authorization(audit_authorization):
|
||||
raise AccessDeniedError("Invalid audit authorization")
|
||||
|
||||
|
||||
# Get audit private key
|
||||
audit_private_key = self.key_manager.get_audit_private_key(audit_authorization)
|
||||
|
||||
audit_private_key = self.key_manager.get_audit_private_key(
|
||||
audit_authorization
|
||||
)
|
||||
|
||||
# Decrypt using audit key
|
||||
if "audit" not in encrypted_data.encrypted_keys:
|
||||
raise AccessDeniedError("Audit escrow not available")
|
||||
|
||||
|
||||
encrypted_dek = encrypted_data.encrypted_keys["audit"]
|
||||
dek = self._decrypt_dek(encrypted_dek, audit_private_key)
|
||||
|
||||
|
||||
# Decrypt data
|
||||
full_ciphertext = encrypted_data.ciphertext + encrypted_data.tag
|
||||
aesgcm = AESGCM(dek)
|
||||
plaintext = aesgcm.decrypt(encrypted_data.nonce, full_ciphertext, None)
|
||||
|
||||
|
||||
data = json.loads(plaintext.decode())
|
||||
|
||||
|
||||
# Log audit access
|
||||
self._log_access(
|
||||
transaction_id=None,
|
||||
participant_id="audit",
|
||||
purpose=f"audit:{purpose}",
|
||||
success=True,
|
||||
authorization=audit_authorization
|
||||
authorization=audit_authorization,
|
||||
)
|
||||
|
||||
|
||||
return data
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Audit decryption failed: {e}")
|
||||
raise DecryptionError(f"Failed to decrypt for audit: {e}")
|
||||
|
||||
|
||||
def _encrypt_dek(self, dek: bytes, public_key: X25519PublicKey) -> bytes:
|
||||
"""Encrypt DEK using ECIES with X25519"""
|
||||
# Generate ephemeral key pair
|
||||
ephemeral_private = X25519PrivateKey.generate()
|
||||
ephemeral_public = ephemeral_private.public_key()
|
||||
|
||||
|
||||
# Perform ECDH
|
||||
shared_key = ephemeral_private.exchange(public_key)
|
||||
|
||||
|
||||
# Derive encryption key from shared secret
|
||||
derived_key = HKDF(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=None,
|
||||
info=b"AITBC-DEK-Encryption",
|
||||
backend=self.backend
|
||||
backend=self.backend,
|
||||
).derive(shared_key)
|
||||
|
||||
|
||||
# Encrypt DEK with AES-GCM
|
||||
aesgcm = AESGCM(derived_key)
|
||||
nonce = os.urandom(12)
|
||||
encrypted_dek = aesgcm.encrypt(nonce, dek, None)
|
||||
|
||||
|
||||
# Return ephemeral public key + nonce + encrypted DEK
|
||||
return (
|
||||
ephemeral_public.public_bytes(Encoding.Raw, PublicFormat.Raw) +
|
||||
nonce +
|
||||
encrypted_dek
|
||||
ephemeral_public.public_bytes(Encoding.Raw, PublicFormat.Raw)
|
||||
+ nonce
|
||||
+ encrypted_dek
|
||||
)
|
||||
|
||||
def _decrypt_dek(self, encrypted_dek: bytes, private_key: X25519PrivateKey) -> bytes:
|
||||
|
||||
def _decrypt_dek(
|
||||
self, encrypted_dek: bytes, private_key: X25519PrivateKey
|
||||
) -> bytes:
|
||||
"""Decrypt DEK using ECIES with X25519"""
|
||||
# Extract components
|
||||
ephemeral_public_bytes = encrypted_dek[:32]
|
||||
nonce = encrypted_dek[32:44]
|
||||
dek_ciphertext = encrypted_dek[44:]
|
||||
|
||||
|
||||
# Reconstruct ephemeral public key
|
||||
ephemeral_public = X25519PublicKey.from_public_bytes(ephemeral_public_bytes)
|
||||
|
||||
|
||||
# Perform ECDH
|
||||
shared_key = private_key.exchange(ephemeral_public)
|
||||
|
||||
|
||||
# Derive decryption key
|
||||
derived_key = HKDF(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=None,
|
||||
info=b"AITBC-DEK-Encryption",
|
||||
backend=self.backend
|
||||
backend=self.backend,
|
||||
).derive(shared_key)
|
||||
|
||||
|
||||
# Decrypt DEK
|
||||
aesgcm = AESGCM(derived_key)
|
||||
dek = aesgcm.decrypt(nonce, dek_ciphertext, None)
|
||||
|
||||
|
||||
return dek
|
||||
|
||||
|
||||
def _log_access(
|
||||
self,
|
||||
transaction_id: Optional[str],
|
||||
@@ -313,7 +324,7 @@ class EncryptionService:
|
||||
purpose: str,
|
||||
success: bool,
|
||||
error: Optional[str] = None,
|
||||
authorization: Optional[str] = None
|
||||
authorization: Optional[str] = None,
|
||||
):
|
||||
"""Log access to confidential data"""
|
||||
try:
|
||||
@@ -324,26 +335,29 @@ class EncryptionService:
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"success": success,
|
||||
"error": error,
|
||||
"authorization": authorization
|
||||
"authorization": authorization,
|
||||
}
|
||||
|
||||
|
||||
# In production, this would go to secure audit log
|
||||
logger.info(f"Confidential data access: {json.dumps(log_entry)}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to log access: {e}")
|
||||
|
||||
|
||||
class EncryptionError(Exception):
|
||||
"""Base exception for encryption errors"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DecryptionError(EncryptionError):
|
||||
"""Exception for decryption errors"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AccessDeniedError(EncryptionError):
|
||||
"""Exception for access denied errors"""
|
||||
|
||||
pass
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Optional
|
||||
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from ..config import settings
|
||||
from ..domain import Job, JobReceipt
|
||||
from ..schemas import (
|
||||
BlockListResponse,
|
||||
@@ -39,29 +40,45 @@ class ExplorerService:
|
||||
self.session = session
|
||||
|
||||
def list_blocks(self, *, limit: int = 20, offset: int = 0) -> BlockListResponse:
|
||||
# Fetch real blockchain data from RPC API
|
||||
# Fetch real blockchain data via /rpc/head and /rpc/blocks-range
|
||||
rpc_base = settings.blockchain_rpc_url.rstrip("/")
|
||||
try:
|
||||
# Use the blockchain RPC API running on localhost:8082
|
||||
with httpx.Client(timeout=10.0) as client:
|
||||
response = client.get("http://localhost:8082/rpc/blocks", params={"limit": limit, "offset": offset})
|
||||
response.raise_for_status()
|
||||
rpc_data = response.json()
|
||||
|
||||
head_resp = client.get(f"{rpc_base}/rpc/head")
|
||||
if head_resp.status_code == 404:
|
||||
return BlockListResponse(items=[], next_offset=None)
|
||||
head_resp.raise_for_status()
|
||||
head = head_resp.json()
|
||||
height = head.get("height", 0)
|
||||
start = max(0, height - offset - limit + 1)
|
||||
end = height - offset
|
||||
if start > end:
|
||||
return BlockListResponse(items=[], next_offset=None)
|
||||
range_resp = client.get(
|
||||
f"{rpc_base}/rpc/blocks-range",
|
||||
params={"start": start, "end": end},
|
||||
)
|
||||
range_resp.raise_for_status()
|
||||
rpc_data = range_resp.json()
|
||||
raw_blocks = rpc_data.get("blocks", [])
|
||||
# Node returns ascending by height; explorer expects newest first
|
||||
raw_blocks = list(reversed(raw_blocks))
|
||||
items: list[BlockSummary] = []
|
||||
for block in rpc_data.get("blocks", []):
|
||||
for block in raw_blocks:
|
||||
ts = block.get("timestamp")
|
||||
if isinstance(ts, str):
|
||||
ts = datetime.fromisoformat(ts.replace("Z", "+00:00"))
|
||||
items.append(
|
||||
BlockSummary(
|
||||
height=block["height"],
|
||||
hash=block["hash"],
|
||||
timestamp=datetime.fromisoformat(block["timestamp"]),
|
||||
txCount=block["tx_count"],
|
||||
proposer=block["proposer"],
|
||||
timestamp=ts,
|
||||
txCount=block.get("tx_count", 0),
|
||||
proposer=block.get("proposer", "—"),
|
||||
)
|
||||
)
|
||||
|
||||
next_offset: Optional[int] = offset + len(items) if len(items) == limit else None
|
||||
next_offset = offset + len(items) if len(items) == limit else None
|
||||
return BlockListResponse(items=items, next_offset=next_offset)
|
||||
|
||||
except Exception as e:
|
||||
# Fallback to fake data if RPC is unavailable
|
||||
print(f"Warning: Failed to fetch blocks from RPC: {e}, falling back to fake data")
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Ensure coordinator-api src is on sys.path for all tests in this directory."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
_src = str(Path(__file__).resolve().parent.parent / "src")
|
||||
@@ -15,3 +17,9 @@ if _app_mod and hasattr(_app_mod, "__file__") and _app_mod.__file__ and _src not
|
||||
|
||||
if _src not in sys.path:
|
||||
sys.path.insert(0, _src)
|
||||
|
||||
# Set up test environment
|
||||
os.environ["TEST_MODE"] = "true"
|
||||
project_root = Path(__file__).resolve().parent.parent.parent
|
||||
os.environ["AUDIT_LOG_DIR"] = str(project_root / "logs" / "audit")
|
||||
os.environ["TEST_DATABASE_URL"] = "sqlite:///:memory:"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import pytest
|
||||
from sqlmodel import Session, delete
|
||||
from sqlmodel import Session, delete, text
|
||||
|
||||
from app.domain import Job, Miner
|
||||
from app.models import JobCreate
|
||||
@@ -14,7 +14,26 @@ def _init_db(tmp_path_factory):
|
||||
from app.config import settings
|
||||
|
||||
settings.database_url = f"sqlite:///{db_file}"
|
||||
|
||||
# Initialize database and create tables
|
||||
init_db()
|
||||
|
||||
# Ensure payment_id column exists (handle schema migration)
|
||||
with session_scope() as sess:
|
||||
try:
|
||||
# Check if columns exist and add them if needed
|
||||
result = sess.exec(text("PRAGMA table_info(job)"))
|
||||
columns = [row[1] for row in result.fetchall()]
|
||||
|
||||
if 'payment_id' not in columns:
|
||||
sess.exec(text("ALTER TABLE job ADD COLUMN payment_id TEXT"))
|
||||
if 'payment_status' not in columns:
|
||||
sess.exec(text("ALTER TABLE job ADD COLUMN payment_status TEXT"))
|
||||
sess.commit()
|
||||
except Exception as e:
|
||||
print(f"Schema migration error: {e}")
|
||||
sess.rollback()
|
||||
|
||||
yield
|
||||
|
||||
|
||||
|
||||
@@ -9,19 +9,18 @@ from pathlib import Path
|
||||
|
||||
from app.services.zk_proofs import ZKProofService
|
||||
from app.models import JobReceipt, Job, JobResult
|
||||
from app.domain import ReceiptPayload
|
||||
|
||||
|
||||
class TestZKProofService:
|
||||
"""Test cases for ZK proof service"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def zk_service(self):
|
||||
"""Create ZK proof service instance"""
|
||||
with patch('app.services.zk_proofs.settings'):
|
||||
with patch("app.services.zk_proofs.settings"):
|
||||
service = ZKProofService()
|
||||
return service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_job(self):
|
||||
"""Create sample job for testing"""
|
||||
@@ -31,9 +30,9 @@ class TestZKProofService:
|
||||
payload={"type": "test"},
|
||||
constraints={},
|
||||
requested_at=None,
|
||||
completed=True
|
||||
completed=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_job_result(self):
|
||||
"""Create sample job result"""
|
||||
@@ -42,9 +41,9 @@ class TestZKProofService:
|
||||
"result_hash": "0x1234567890abcdef",
|
||||
"units": 100,
|
||||
"unit_type": "gpu_seconds",
|
||||
"metrics": {"execution_time": 5.0}
|
||||
"metrics": {"execution_time": 5.0},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_receipt(self, sample_job):
|
||||
"""Create sample receipt"""
|
||||
@@ -59,171 +58,187 @@ class TestZKProofService:
|
||||
price="0.1",
|
||||
started_at=1640995200,
|
||||
completed_at=1640995800,
|
||||
metadata={}
|
||||
metadata={},
|
||||
)
|
||||
|
||||
|
||||
return JobReceipt(
|
||||
job_id=sample_job.id,
|
||||
receipt_id=payload.receipt_id,
|
||||
payload=payload.dict()
|
||||
job_id=sample_job.id, receipt_id=payload.receipt_id, payload=payload.dict()
|
||||
)
|
||||
|
||||
|
||||
def test_service_initialization_with_files(self):
|
||||
"""Test service initialization when circuit files exist"""
|
||||
with patch('app.services.zk_proofs.Path') as mock_path:
|
||||
with patch("app.services.zk_proofs.Path") as mock_path:
|
||||
# Mock file existence
|
||||
mock_path.return_value.exists.return_value = True
|
||||
|
||||
|
||||
service = ZKProofService()
|
||||
assert service.enabled is True
|
||||
|
||||
|
||||
def test_service_initialization_without_files(self):
|
||||
"""Test service initialization when circuit files are missing"""
|
||||
with patch('app.services.zk_proofs.Path') as mock_path:
|
||||
with patch("app.services.zk_proofs.Path") as mock_path:
|
||||
# Mock file non-existence
|
||||
mock_path.return_value.exists.return_value = False
|
||||
|
||||
|
||||
service = ZKProofService()
|
||||
assert service.enabled is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_proof_basic_privacy(self, zk_service, sample_receipt, sample_job_result):
|
||||
async def test_generate_proof_basic_privacy(
|
||||
self, zk_service, sample_receipt, sample_job_result
|
||||
):
|
||||
"""Test generating proof with basic privacy level"""
|
||||
if not zk_service.enabled:
|
||||
pytest.skip("ZK circuits not available")
|
||||
|
||||
|
||||
# Mock subprocess calls
|
||||
with patch('subprocess.run') as mock_run:
|
||||
with patch("subprocess.run") as mock_run:
|
||||
# Mock successful proof generation
|
||||
mock_run.return_value.returncode = 0
|
||||
mock_run.return_value.stdout = json.dumps({
|
||||
"proof": {"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
|
||||
"publicSignals": ["0x1234", "1000", "1640995800"]
|
||||
})
|
||||
|
||||
mock_run.return_value.stdout = json.dumps(
|
||||
{
|
||||
"proof": {
|
||||
"a": ["1", "2"],
|
||||
"b": [["1", "2"], ["1", "2"]],
|
||||
"c": ["1", "2"],
|
||||
},
|
||||
"publicSignals": ["0x1234", "1000", "1640995800"],
|
||||
}
|
||||
)
|
||||
|
||||
# Generate proof
|
||||
proof = await zk_service.generate_receipt_proof(
|
||||
receipt=sample_receipt,
|
||||
job_result=sample_job_result,
|
||||
privacy_level="basic"
|
||||
privacy_level="basic",
|
||||
)
|
||||
|
||||
|
||||
assert proof is not None
|
||||
assert "proof" in proof
|
||||
assert "public_signals" in proof
|
||||
assert proof["privacy_level"] == "basic"
|
||||
assert "circuit_hash" in proof
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_proof_enhanced_privacy(self, zk_service, sample_receipt, sample_job_result):
|
||||
async def test_generate_proof_enhanced_privacy(
|
||||
self, zk_service, sample_receipt, sample_job_result
|
||||
):
|
||||
"""Test generating proof with enhanced privacy level"""
|
||||
if not zk_service.enabled:
|
||||
pytest.skip("ZK circuits not available")
|
||||
|
||||
with patch('subprocess.run') as mock_run:
|
||||
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value.returncode = 0
|
||||
mock_run.return_value.stdout = json.dumps({
|
||||
"proof": {"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
|
||||
"publicSignals": ["1000", "1640995800"]
|
||||
})
|
||||
|
||||
mock_run.return_value.stdout = json.dumps(
|
||||
{
|
||||
"proof": {
|
||||
"a": ["1", "2"],
|
||||
"b": [["1", "2"], ["1", "2"]],
|
||||
"c": ["1", "2"],
|
||||
},
|
||||
"publicSignals": ["1000", "1640995800"],
|
||||
}
|
||||
)
|
||||
|
||||
proof = await zk_service.generate_receipt_proof(
|
||||
receipt=sample_receipt,
|
||||
job_result=sample_job_result,
|
||||
privacy_level="enhanced"
|
||||
privacy_level="enhanced",
|
||||
)
|
||||
|
||||
|
||||
assert proof is not None
|
||||
assert proof["privacy_level"] == "enhanced"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_proof_service_disabled(self, zk_service, sample_receipt, sample_job_result):
|
||||
async def test_generate_proof_service_disabled(
|
||||
self, zk_service, sample_receipt, sample_job_result
|
||||
):
|
||||
"""Test proof generation when service is disabled"""
|
||||
zk_service.enabled = False
|
||||
|
||||
|
||||
proof = await zk_service.generate_receipt_proof(
|
||||
receipt=sample_receipt,
|
||||
job_result=sample_job_result,
|
||||
privacy_level="basic"
|
||||
receipt=sample_receipt, job_result=sample_job_result, privacy_level="basic"
|
||||
)
|
||||
|
||||
|
||||
assert proof is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_proof_invalid_privacy_level(self, zk_service, sample_receipt, sample_job_result):
|
||||
async def test_generate_proof_invalid_privacy_level(
|
||||
self, zk_service, sample_receipt, sample_job_result
|
||||
):
|
||||
"""Test proof generation with invalid privacy level"""
|
||||
if not zk_service.enabled:
|
||||
pytest.skip("ZK circuits not available")
|
||||
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown privacy level"):
|
||||
await zk_service.generate_receipt_proof(
|
||||
receipt=sample_receipt,
|
||||
job_result=sample_job_result,
|
||||
privacy_level="invalid"
|
||||
privacy_level="invalid",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_proof_success(self, zk_service):
|
||||
"""Test successful proof verification"""
|
||||
if not zk_service.enabled:
|
||||
pytest.skip("ZK circuits not available")
|
||||
|
||||
with patch('subprocess.run') as mock_run, \
|
||||
patch('builtins.open', mock_open(read_data='{"key": "value"}')):
|
||||
|
||||
|
||||
with patch("subprocess.run") as mock_run, patch(
|
||||
"builtins.open", mock_open(read_data='{"key": "value"}')
|
||||
):
|
||||
mock_run.return_value.returncode = 0
|
||||
mock_run.return_value.stdout = "true"
|
||||
|
||||
|
||||
result = await zk_service.verify_proof(
|
||||
proof={"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
|
||||
public_signals=["0x1234", "1000"]
|
||||
public_signals=["0x1234", "1000"],
|
||||
)
|
||||
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_proof_failure(self, zk_service):
|
||||
"""Test proof verification failure"""
|
||||
if not zk_service.enabled:
|
||||
pytest.skip("ZK circuits not available")
|
||||
|
||||
with patch('subprocess.run') as mock_run, \
|
||||
patch('builtins.open', mock_open(read_data='{"key": "value"}')):
|
||||
|
||||
|
||||
with patch("subprocess.run") as mock_run, patch(
|
||||
"builtins.open", mock_open(read_data='{"key": "value"}')
|
||||
):
|
||||
mock_run.return_value.returncode = 1
|
||||
mock_run.return_value.stderr = "Verification failed"
|
||||
|
||||
|
||||
result = await zk_service.verify_proof(
|
||||
proof={"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
|
||||
public_signals=["0x1234", "1000"]
|
||||
public_signals=["0x1234", "1000"],
|
||||
)
|
||||
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_proof_service_disabled(self, zk_service):
|
||||
"""Test proof verification when service is disabled"""
|
||||
zk_service.enabled = False
|
||||
|
||||
|
||||
result = await zk_service.verify_proof(
|
||||
proof={"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
|
||||
public_signals=["0x1234", "1000"]
|
||||
public_signals=["0x1234", "1000"],
|
||||
)
|
||||
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_hash_receipt(self, zk_service, sample_receipt):
|
||||
"""Test receipt hashing"""
|
||||
receipt_hash = zk_service._hash_receipt(sample_receipt)
|
||||
|
||||
|
||||
assert isinstance(receipt_hash, str)
|
||||
assert len(receipt_hash) == 64 # SHA256 hex length
|
||||
assert all(c in '0123456789abcdef' for c in receipt_hash)
|
||||
|
||||
assert all(c in "0123456789abcdef" for c in receipt_hash)
|
||||
|
||||
def test_serialize_receipt(self, zk_service, sample_receipt):
|
||||
"""Test receipt serialization for circuit"""
|
||||
serialized = zk_service._serialize_receipt(sample_receipt)
|
||||
|
||||
|
||||
assert isinstance(serialized, list)
|
||||
assert len(serialized) == 8
|
||||
assert all(isinstance(x, str) for x in serialized)
|
||||
@@ -231,19 +246,19 @@ class TestZKProofService:
|
||||
|
||||
class TestZKProofIntegration:
|
||||
"""Integration tests for ZK proof system"""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receipt_creation_with_zk_proof(self):
|
||||
"""Test receipt creation with ZK proof generation"""
|
||||
from app.services.receipts import ReceiptService
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
# Create mock session
|
||||
session = Mock(spec=Session)
|
||||
|
||||
|
||||
# Create receipt service
|
||||
receipt_service = ReceiptService(session)
|
||||
|
||||
|
||||
# Create sample job
|
||||
job = Job(
|
||||
id="test-job-123",
|
||||
@@ -251,43 +266,45 @@ class TestZKProofIntegration:
|
||||
payload={"type": "test"},
|
||||
constraints={},
|
||||
requested_at=None,
|
||||
completed=True
|
||||
completed=True,
|
||||
)
|
||||
|
||||
|
||||
# Mock ZK proof service
|
||||
with patch('app.services.receipts.zk_proof_service') as mock_zk:
|
||||
with patch("app.services.receipts.zk_proof_service") as mock_zk:
|
||||
mock_zk.is_enabled.return_value = True
|
||||
mock_zk.generate_receipt_proof = AsyncMock(return_value={
|
||||
"proof": {"a": ["1", "2"]},
|
||||
"public_signals": ["0x1234"],
|
||||
"privacy_level": "basic"
|
||||
})
|
||||
|
||||
mock_zk.generate_receipt_proof = AsyncMock(
|
||||
return_value={
|
||||
"proof": {"a": ["1", "2"]},
|
||||
"public_signals": ["0x1234"],
|
||||
"privacy_level": "basic",
|
||||
}
|
||||
)
|
||||
|
||||
# Create receipt with privacy
|
||||
receipt = await receipt_service.create_receipt(
|
||||
job=job,
|
||||
miner_id="miner-001",
|
||||
job_result={"result": "test"},
|
||||
result_metrics={"units": 100},
|
||||
privacy_level="basic"
|
||||
privacy_level="basic",
|
||||
)
|
||||
|
||||
|
||||
assert receipt is not None
|
||||
assert "zk_proof" in receipt
|
||||
assert receipt["privacy_level"] == "basic"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_settlement_with_zk_proof(self):
|
||||
"""Test cross-chain settlement with ZK proof"""
|
||||
from aitbc.settlement.hooks import SettlementHook
|
||||
from aitbc.settlement.manager import BridgeManager
|
||||
|
||||
|
||||
# Create mock bridge manager
|
||||
bridge_manager = Mock(spec=BridgeManager)
|
||||
|
||||
|
||||
# Create settlement hook
|
||||
settlement_hook = SettlementHook(bridge_manager)
|
||||
|
||||
|
||||
# Create sample job with ZK proof
|
||||
job = Job(
|
||||
id="test-job-123",
|
||||
@@ -296,9 +313,9 @@ class TestZKProofIntegration:
|
||||
constraints={},
|
||||
requested_at=None,
|
||||
completed=True,
|
||||
target_chain=2
|
||||
target_chain=2,
|
||||
)
|
||||
|
||||
|
||||
# Create receipt with ZK proof
|
||||
receipt_payload = {
|
||||
"version": "1.0",
|
||||
@@ -306,24 +323,20 @@ class TestZKProofIntegration:
|
||||
"job_id": job.id,
|
||||
"provider": "miner-001",
|
||||
"client": job.client_id,
|
||||
"zk_proof": {
|
||||
"proof": {"a": ["1", "2"]},
|
||||
"public_signals": ["0x1234"]
|
||||
}
|
||||
"zk_proof": {"proof": {"a": ["1", "2"]}, "public_signals": ["0x1234"]},
|
||||
}
|
||||
|
||||
|
||||
job.receipt = JobReceipt(
|
||||
job_id=job.id,
|
||||
receipt_id=receipt_payload["receipt_id"],
|
||||
payload=receipt_payload
|
||||
payload=receipt_payload,
|
||||
)
|
||||
|
||||
|
||||
# Test settlement message creation
|
||||
message = await settlement_hook._create_settlement_message(
|
||||
job,
|
||||
options={"use_zk_proof": True, "privacy_level": "basic"}
|
||||
job, options={"use_zk_proof": True, "privacy_level": "basic"}
|
||||
)
|
||||
|
||||
|
||||
assert message.zk_proof is not None
|
||||
assert message.privacy_level == "basic"
|
||||
|
||||
@@ -332,71 +345,70 @@ class TestZKProofIntegration:
|
||||
def mock_open(read_data=""):
|
||||
"""Mock open function for file operations"""
|
||||
from unittest.mock import mock_open
|
||||
|
||||
return mock_open(read_data=read_data)
|
||||
|
||||
|
||||
# Benchmark tests
|
||||
class TestZKProofPerformance:
|
||||
"""Performance benchmarks for ZK proof operations"""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proof_generation_time(self):
|
||||
"""Benchmark proof generation time"""
|
||||
import time
|
||||
|
||||
|
||||
if not Path("apps/zk-circuits/receipt.wasm").exists():
|
||||
pytest.skip("ZK circuits not built")
|
||||
|
||||
|
||||
service = ZKProofService()
|
||||
if not service.enabled:
|
||||
pytest.skip("ZK service not enabled")
|
||||
|
||||
|
||||
# Create test data
|
||||
receipt = JobReceipt(
|
||||
job_id="benchmark-job",
|
||||
receipt_id="benchmark-receipt",
|
||||
payload={"test": "data"}
|
||||
payload={"test": "data"},
|
||||
)
|
||||
|
||||
|
||||
job_result = {"result": "benchmark"}
|
||||
|
||||
|
||||
# Measure proof generation time
|
||||
start_time = time.time()
|
||||
proof = await service.generate_receipt_proof(
|
||||
receipt=receipt,
|
||||
job_result=job_result,
|
||||
privacy_level="basic"
|
||||
receipt=receipt, job_result=job_result, privacy_level="basic"
|
||||
)
|
||||
end_time = time.time()
|
||||
|
||||
|
||||
generation_time = end_time - start_time
|
||||
|
||||
|
||||
assert proof is not None
|
||||
assert generation_time < 30 # Should complete within 30 seconds
|
||||
|
||||
|
||||
print(f"Proof generation time: {generation_time:.2f} seconds")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proof_verification_time(self):
|
||||
"""Benchmark proof verification time"""
|
||||
import time
|
||||
|
||||
|
||||
service = ZKProofService()
|
||||
if not service.enabled:
|
||||
pytest.skip("ZK service not enabled")
|
||||
|
||||
|
||||
# Create test proof
|
||||
proof = {"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]}
|
||||
public_signals = ["0x1234", "1000"]
|
||||
|
||||
|
||||
# Measure verification time
|
||||
start_time = time.time()
|
||||
result = await service.verify_proof(proof, public_signals)
|
||||
end_time = time.time()
|
||||
|
||||
|
||||
verification_time = end_time - start_time
|
||||
|
||||
|
||||
assert isinstance(result, bool)
|
||||
assert verification_time < 1 # Should complete within 1 second
|
||||
|
||||
|
||||
print(f"Proof verification time: {verification_time:.3f} seconds")
|
||||
|
||||
Reference in New Issue
Block a user