feat: add transaction hash search to blockchain explorer and cleanup settlement storage

Blockchain Explorer:
- Add transaction hash search support (64-char hex pattern validation)
- Fetch and display transaction details in modal (hash, type, from/to, amount, fee, block)
- Fix regex escape sequence in block height validation
- Update search placeholder text to mention both search types
- Add blank lines between function definitions for PEP 8 compliance

Settlement Storage:
- Add timedelta import for future
This commit is contained in:
oib
2026-02-17 14:34:12 +01:00
parent 31d3d70836
commit 421191ccaf
34 changed files with 2176 additions and 5660 deletions

View File

@@ -3,30 +3,26 @@ Storage layer for cross-chain settlements
"""
from typing import Dict, Any, Optional, List
from datetime import datetime
from datetime import datetime, timedelta
import json
import asyncio
from dataclasses import asdict
from .bridges.base import (
SettlementMessage,
SettlementResult,
BridgeStatus
)
from .bridges.base import SettlementMessage, SettlementResult, BridgeStatus
class SettlementStorage:
"""Storage interface for settlement data"""
def __init__(self, db_connection):
self.db = db_connection
async def store_settlement(
self,
message_id: str,
message: SettlementMessage,
bridge_name: str,
status: BridgeStatus
status: BridgeStatus,
) -> None:
"""Store a new settlement record"""
query = """
@@ -38,93 +34,96 @@ class SettlementStorage:
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13
)
"""
await self.db.execute(query, (
message_id,
message.job_id,
message.source_chain_id,
message.target_chain_id,
message.receipt_hash,
json.dumps(message.proof_data),
message.payment_amount,
message.payment_token,
message.nonce,
message.signature,
bridge_name,
status.value,
message.created_at or datetime.utcnow()
))
await self.db.execute(
query,
(
message_id,
message.job_id,
message.source_chain_id,
message.target_chain_id,
message.receipt_hash,
json.dumps(message.proof_data),
message.payment_amount,
message.payment_token,
message.nonce,
message.signature,
bridge_name,
status.value,
message.created_at or datetime.utcnow(),
),
)
async def update_settlement(
self,
message_id: str,
status: Optional[BridgeStatus] = None,
transaction_hash: Optional[str] = None,
error_message: Optional[str] = None,
completed_at: Optional[datetime] = None
completed_at: Optional[datetime] = None,
) -> None:
"""Update settlement record"""
updates = []
params = []
param_count = 1
if status is not None:
updates.append(f"status = ${param_count}")
params.append(status.value)
param_count += 1
if transaction_hash is not None:
updates.append(f"transaction_hash = ${param_count}")
params.append(transaction_hash)
param_count += 1
if error_message is not None:
updates.append(f"error_message = ${param_count}")
params.append(error_message)
param_count += 1
if completed_at is not None:
updates.append(f"completed_at = ${param_count}")
params.append(completed_at)
param_count += 1
if not updates:
return
updates.append(f"updated_at = ${param_count}")
params.append(datetime.utcnow())
param_count += 1
params.append(message_id)
query = f"""
UPDATE settlements
SET {', '.join(updates)}
SET {", ".join(updates)}
WHERE message_id = ${param_count}
"""
await self.db.execute(query, params)
async def get_settlement(self, message_id: str) -> Optional[Dict[str, Any]]:
"""Get settlement by message ID"""
query = """
SELECT * FROM settlements WHERE message_id = $1
"""
result = await self.db.fetchrow(query, message_id)
if not result:
return None
# Convert to dict
settlement = dict(result)
# Parse JSON fields
if settlement['proof_data']:
settlement['proof_data'] = json.loads(settlement['proof_data'])
if settlement["proof_data"]:
settlement["proof_data"] = json.loads(settlement["proof_data"])
return settlement
async def get_settlements_by_job(self, job_id: str) -> List[Dict[str, Any]]:
"""Get all settlements for a job"""
query = """
@@ -132,65 +131,67 @@ class SettlementStorage:
WHERE job_id = $1
ORDER BY created_at DESC
"""
results = await self.db.fetch(query, job_id)
settlements = []
for result in results:
settlement = dict(result)
if settlement['proof_data']:
settlement['proof_data'] = json.loads(settlement['proof_data'])
if settlement["proof_data"]:
settlement["proof_data"] = json.loads(settlement["proof_data"])
settlements.append(settlement)
return settlements
async def get_pending_settlements(self, bridge_name: Optional[str] = None) -> List[Dict[str, Any]]:
async def get_pending_settlements(
self, bridge_name: Optional[str] = None
) -> List[Dict[str, Any]]:
"""Get all pending settlements"""
query = """
SELECT * FROM settlements
WHERE status = 'pending' OR status = 'in_progress'
"""
params = []
if bridge_name:
query += " AND bridge_name = $1"
params.append(bridge_name)
query += " ORDER BY created_at ASC"
results = await self.db.fetch(query, *params)
settlements = []
for result in results:
settlement = dict(result)
if settlement['proof_data']:
settlement['proof_data'] = json.loads(settlement['proof_data'])
if settlement["proof_data"]:
settlement["proof_data"] = json.loads(settlement["proof_data"])
settlements.append(settlement)
return settlements
async def get_settlement_stats(
self,
bridge_name: Optional[str] = None,
time_range: Optional[int] = None # hours
time_range: Optional[int] = None, # hours
) -> Dict[str, Any]:
"""Get settlement statistics"""
conditions = []
params = []
param_count = 1
if bridge_name:
conditions.append(f"bridge_name = ${param_count}")
params.append(bridge_name)
param_count += 1
if time_range:
conditions.append(f"created_at > NOW() - INTERVAL '${param_count} hours'")
params.append(time_range)
param_count += 1
where_clause = "WHERE " + " AND ".join(conditions) if conditions else ""
query = f"""
SELECT
bridge_name,
@@ -202,23 +203,27 @@ class SettlementStorage:
{where_clause}
GROUP BY bridge_name, status
"""
results = await self.db.fetch(query, *params)
stats = {}
for result in results:
bridge = result['bridge_name']
bridge = result["bridge_name"]
if bridge not in stats:
stats[bridge] = {}
stats[bridge][result['status']] = {
'count': result['count'],
'avg_amount': float(result['avg_amount']) if result['avg_amount'] else 0,
'total_amount': float(result['total_amount']) if result['total_amount'] else 0
stats[bridge][result["status"]] = {
"count": result["count"],
"avg_amount": float(result["avg_amount"])
if result["avg_amount"]
else 0,
"total_amount": float(result["total_amount"])
if result["total_amount"]
else 0,
}
return stats
async def cleanup_old_settlements(self, days: int = 30) -> int:
"""Clean up old completed settlements"""
query = """
@@ -226,7 +231,7 @@ class SettlementStorage:
WHERE status IN ('completed', 'failed')
AND created_at < NOW() - INTERVAL $1 days
"""
result = await self.db.execute(query, days)
return result.split()[-1] # Return number of deleted rows
@@ -234,134 +239,139 @@ class SettlementStorage:
# In-memory implementation for testing
class InMemorySettlementStorage(SettlementStorage):
"""In-memory storage implementation for testing"""
def __init__(self):
self.settlements: Dict[str, Dict[str, Any]] = {}
self._lock = asyncio.Lock()
async def store_settlement(
self,
message_id: str,
message: SettlementMessage,
bridge_name: str,
status: BridgeStatus
status: BridgeStatus,
) -> None:
async with self._lock:
self.settlements[message_id] = {
'message_id': message_id,
'job_id': message.job_id,
'source_chain_id': message.source_chain_id,
'target_chain_id': message.target_chain_id,
'receipt_hash': message.receipt_hash,
'proof_data': message.proof_data,
'payment_amount': message.payment_amount,
'payment_token': message.payment_token,
'nonce': message.nonce,
'signature': message.signature,
'bridge_name': bridge_name,
'status': status.value,
'created_at': message.created_at or datetime.utcnow(),
'updated_at': datetime.utcnow()
"message_id": message_id,
"job_id": message.job_id,
"source_chain_id": message.source_chain_id,
"target_chain_id": message.target_chain_id,
"receipt_hash": message.receipt_hash,
"proof_data": message.proof_data,
"payment_amount": message.payment_amount,
"payment_token": message.payment_token,
"nonce": message.nonce,
"signature": message.signature,
"bridge_name": bridge_name,
"status": status.value,
"created_at": message.created_at or datetime.utcnow(),
"updated_at": datetime.utcnow(),
}
async def update_settlement(
self,
message_id: str,
status: Optional[BridgeStatus] = None,
transaction_hash: Optional[str] = None,
error_message: Optional[str] = None,
completed_at: Optional[datetime] = None
completed_at: Optional[datetime] = None,
) -> None:
async with self._lock:
if message_id not in self.settlements:
return
settlement = self.settlements[message_id]
if status is not None:
settlement['status'] = status.value
settlement["status"] = status.value
if transaction_hash is not None:
settlement['transaction_hash'] = transaction_hash
settlement["transaction_hash"] = transaction_hash
if error_message is not None:
settlement['error_message'] = error_message
settlement["error_message"] = error_message
if completed_at is not None:
settlement['completed_at'] = completed_at
settlement['updated_at'] = datetime.utcnow()
settlement["completed_at"] = completed_at
settlement["updated_at"] = datetime.utcnow()
async def get_settlement(self, message_id: str) -> Optional[Dict[str, Any]]:
async with self._lock:
return self.settlements.get(message_id)
async def get_settlements_by_job(self, job_id: str) -> List[Dict[str, Any]]:
async with self._lock:
return [
s for s in self.settlements.values()
if s['job_id'] == job_id
]
async def get_pending_settlements(self, bridge_name: Optional[str] = None) -> List[Dict[str, Any]]:
return [s for s in self.settlements.values() if s["job_id"] == job_id]
async def get_pending_settlements(
self, bridge_name: Optional[str] = None
) -> List[Dict[str, Any]]:
async with self._lock:
pending = [
s for s in self.settlements.values()
if s['status'] in ['pending', 'in_progress']
s
for s in self.settlements.values()
if s["status"] in ["pending", "in_progress"]
]
if bridge_name:
pending = [s for s in pending if s['bridge_name'] == bridge_name]
pending = [s for s in pending if s["bridge_name"] == bridge_name]
return pending
async def get_settlement_stats(
self,
bridge_name: Optional[str] = None,
time_range: Optional[int] = None
self, bridge_name: Optional[str] = None, time_range: Optional[int] = None
) -> Dict[str, Any]:
async with self._lock:
stats = {}
for settlement in self.settlements.values():
if bridge_name and settlement['bridge_name'] != bridge_name:
if bridge_name and settlement["bridge_name"] != bridge_name:
continue
# TODO: Implement time range filtering
bridge = settlement['bridge_name']
# Time range filtering
if time_range is not None:
cutoff = datetime.utcnow() - timedelta(hours=time_range)
if settlement["created_at"] < cutoff:
continue
bridge = settlement["bridge_name"]
if bridge not in stats:
stats[bridge] = {}
status = settlement['status']
status = settlement["status"]
if status not in stats[bridge]:
stats[bridge][status] = {
'count': 0,
'avg_amount': 0,
'total_amount': 0
"count": 0,
"avg_amount": 0,
"total_amount": 0,
}
stats[bridge][status]['count'] += 1
stats[bridge][status]['total_amount'] += settlement['payment_amount']
stats[bridge][status]["count"] += 1
stats[bridge][status]["total_amount"] += settlement["payment_amount"]
# Calculate averages
for bridge_data in stats.values():
for status_data in bridge_data.values():
if status_data['count'] > 0:
status_data['avg_amount'] = status_data['total_amount'] / status_data['count']
if status_data["count"] > 0:
status_data["avg_amount"] = (
status_data["total_amount"] / status_data["count"]
)
return stats
async def cleanup_old_settlements(self, days: int = 30) -> int:
async with self._lock:
cutoff = datetime.utcnow() - timedelta(days=days)
to_delete = [
msg_id for msg_id, settlement in self.settlements.items()
msg_id
for msg_id, settlement in self.settlements.items()
if (
settlement['status'] in ['completed', 'failed'] and
settlement['created_at'] < cutoff
settlement["status"] in ["completed", "failed"]
and settlement["created_at"] < cutoff
)
]
for msg_id in to_delete:
del self.settlements[msg_id]
return len(to_delete)

View File

@@ -4,115 +4,137 @@ Unified configuration for AITBC Coordinator API
Provides environment-based adapter selection and consolidated settings.
"""
import os
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import List, Optional
from pathlib import Path
import os
class DatabaseConfig(BaseSettings):
"""Database configuration with adapter selection."""
adapter: str = "sqlite" # sqlite, postgresql
url: Optional[str] = None
pool_size: int = 10
max_overflow: int = 20
pool_pre_ping: bool = True
@property
def effective_url(self) -> str:
"""Get the effective database URL."""
if self.url:
return self.url
# Default SQLite path
if self.adapter == "sqlite":
return "sqlite:///./coordinator.db"
# Default PostgreSQL connection string
return f"{self.adapter}://localhost:5432/coordinator"
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
case_sensitive=False,
extra="allow"
env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="allow"
)
class Settings(BaseSettings):
"""Unified application settings with environment-based configuration."""
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
case_sensitive=False,
extra="allow"
env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="allow"
)
# Environment
app_env: str = "dev"
app_host: str = "127.0.0.1"
app_port: int = 8011
audit_log_dir: str = "/var/log/aitbc/audit"
# Database
database: DatabaseConfig = DatabaseConfig()
# API Keys
client_api_keys: List[str] = []
miner_api_keys: List[str] = []
admin_api_keys: List[str] = []
# Security
hmac_secret: Optional[str] = None
jwt_secret: Optional[str] = None
jwt_algorithm: str = "HS256"
jwt_expiration_hours: int = 24
# CORS
allow_origins: List[str] = [
"http://localhost:3000",
"http://localhost:8080",
"http://localhost:8000",
"http://localhost:8011"
"http://localhost:8011",
]
# Job Configuration
job_ttl_seconds: int = 900
heartbeat_interval_seconds: int = 10
heartbeat_timeout_seconds: int = 30
# Rate Limiting
rate_limit_requests: int = 60
rate_limit_window_seconds: int = 60
# Receipt Signing
receipt_signing_key_hex: Optional[str] = None
receipt_attestation_key_hex: Optional[str] = None
# Logging
log_level: str = "INFO"
log_format: str = "json" # json or text
# Mempool
mempool_backend: str = "database" # database, memory
# Blockchain RPC
blockchain_rpc_url: str = "http://localhost:8082"
# Test Configuration
test_mode: bool = False
test_database_url: Optional[str] = None
def validate_secrets(self) -> None:
"""Validate that all required secrets are provided."""
if self.app_env == "production":
if not self.jwt_secret:
raise ValueError("JWT_SECRET environment variable is required in production")
raise ValueError(
"JWT_SECRET environment variable is required in production"
)
if self.jwt_secret == "change-me-in-production":
raise ValueError("JWT_SECRET must be changed from default value")
@property
def database_url(self) -> str:
"""Get the database URL (backward compatibility)."""
# Use test database if in test mode and test_database_url is set
if self.test_mode and self.test_database_url:
return self.test_database_url
if self.database.url:
return self.database.url
# Default SQLite path for backward compatibility
return f"sqlite:///./aitbc_coordinator.db"
@database_url.setter
def database_url(self, value: str):
"""Allow setting database URL for tests"""
if not self.test_mode:
raise RuntimeError("Cannot set database_url outside of test mode")
self.test_database_url = value
settings = Settings()
# Enable test mode if environment variable is set
if os.getenv("TEST_MODE") == "true":
settings.test_mode = True
if os.getenv("TEST_DATABASE_URL"):
settings.test_database_url = os.getenv("TEST_DATABASE_URL")
# Validate secrets on import
settings.validate_secrets()

View File

@@ -52,6 +52,7 @@ from ..schemas import (
from ..domain import (
Job,
Miner,
JobReceipt,
MarketplaceOffer,
MarketplaceBid,
User,
@@ -93,6 +94,7 @@ __all__ = [
"Constraints",
"Job",
"Miner",
"JobReceipt",
"MarketplaceOffer",
"MarketplaceBid",
"ServiceType",

View File

@@ -22,6 +22,7 @@ logger = get_logger(__name__)
@dataclass
class AuditEvent:
"""Structured audit event"""
event_id: str
timestamp: datetime
event_type: str
@@ -39,27 +40,38 @@ class AuditEvent:
class AuditLogger:
"""Tamper-evident audit logging for privacy compliance"""
def __init__(self, log_dir: str = "/var/log/aitbc/audit"):
self.log_dir = Path(log_dir)
self.log_dir.mkdir(parents=True, exist_ok=True)
def __init__(self, log_dir: str = None):
# Use test-specific directory if in test environment
if os.getenv("PYTEST_CURRENT_TEST"):
# Use project logs directory for tests
# Navigate from coordinator-api/src/app/services/audit_logging.py to project root
# Path: coordinator-api/src/app/services/audit_logging.py -> apps/coordinator-api/src -> apps/coordinator-api -> apps -> project_root
project_root = Path(__file__).resolve().parent.parent.parent.parent.parent.parent
test_log_dir = project_root / "logs" / "audit"
log_path = log_dir or str(test_log_dir)
else:
log_path = log_dir or settings.audit_log_dir
self.log_dir = Path(log_path)
self.log_dir.mkdir(parents=True, exist_ok=True)
# Current log file
self.current_file = None
self.current_hash = None
# Async writer task
self.write_queue = asyncio.Queue(maxsize=10000)
self.writer_task = None
# Chain of hashes for integrity
self.chain_hash = self._load_chain_hash()
async def start(self):
"""Start the background writer task"""
if self.writer_task is None:
self.writer_task = asyncio.create_task(self._background_writer())
async def stop(self):
"""Stop the background writer task"""
if self.writer_task:
@@ -69,7 +81,7 @@ class AuditLogger:
except asyncio.CancelledError:
pass
self.writer_task = None
async def log_access(
self,
participant_id: str,
@@ -79,7 +91,7 @@ class AuditLogger:
details: Optional[Dict[str, Any]] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
authorization: Optional[str] = None
authorization: Optional[str] = None,
):
"""Log access to confidential data"""
event = AuditEvent(
@@ -95,22 +107,22 @@ class AuditLogger:
ip_address=ip_address,
user_agent=user_agent,
authorization=authorization,
signature=None
signature=None,
)
# Add signature for tamper-evidence
event.signature = self._sign_event(event)
# Queue for writing
await self.write_queue.put(event)
async def log_key_operation(
self,
participant_id: str,
operation: str,
key_version: int,
outcome: str,
details: Optional[Dict[str, Any]] = None
details: Optional[Dict[str, Any]] = None,
):
"""Log key management operations"""
event = AuditEvent(
@@ -126,19 +138,19 @@ class AuditLogger:
ip_address=None,
user_agent=None,
authorization=None,
signature=None
signature=None,
)
event.signature = self._sign_event(event)
await self.write_queue.put(event)
async def log_policy_change(
self,
participant_id: str,
policy_id: str,
change_type: str,
outcome: str,
details: Optional[Dict[str, Any]] = None
details: Optional[Dict[str, Any]] = None,
):
"""Log access policy changes"""
event = AuditEvent(
@@ -154,12 +166,12 @@ class AuditLogger:
ip_address=None,
user_agent=None,
authorization=None,
signature=None
signature=None,
)
event.signature = self._sign_event(event)
await self.write_queue.put(event)
def query_logs(
self,
participant_id: Optional[str] = None,
@@ -167,14 +179,14 @@ class AuditLogger:
event_type: Optional[str] = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
limit: int = 100
limit: int = 100,
) -> List[AuditEvent]:
"""Query audit logs"""
results = []
# Get list of log files to search
log_files = self._get_log_files(start_time, end_time)
for log_file in log_files:
try:
# Read and decompress if needed
@@ -182,7 +194,14 @@ class AuditLogger:
with gzip.open(log_file, "rt") as f:
for line in f:
event = self._parse_log_line(line.strip())
if self._matches_query(event, participant_id, transaction_id, event_type, start_time, end_time):
if self._matches_query(
event,
participant_id,
transaction_id,
event_type,
start_time,
end_time,
):
results.append(event)
if len(results) >= limit:
return results
@@ -190,75 +209,79 @@ class AuditLogger:
with open(log_file, "r") as f:
for line in f:
event = self._parse_log_line(line.strip())
if self._matches_query(event, participant_id, transaction_id, event_type, start_time, end_time):
if self._matches_query(
event,
participant_id,
transaction_id,
event_type,
start_time,
end_time,
):
results.append(event)
if len(results) >= limit:
return results
except Exception as e:
logger.error(f"Failed to read log file {log_file}: {e}")
continue
# Sort by timestamp (newest first)
results.sort(key=lambda x: x.timestamp, reverse=True)
return results[:limit]
def verify_integrity(self, start_date: Optional[datetime] = None) -> Dict[str, Any]:
"""Verify integrity of audit logs"""
if start_date is None:
start_date = datetime.utcnow() - timedelta(days=30)
results = {
"verified_files": 0,
"total_files": 0,
"integrity_violations": [],
"chain_valid": True
"chain_valid": True,
}
log_files = self._get_log_files(start_date)
for log_file in log_files:
results["total_files"] += 1
try:
# Verify file hash
file_hash = self._calculate_file_hash(log_file)
stored_hash = self._get_stored_hash(log_file)
if file_hash != stored_hash:
results["integrity_violations"].append({
"file": str(log_file),
"expected": stored_hash,
"actual": file_hash
})
results["integrity_violations"].append(
{
"file": str(log_file),
"expected": stored_hash,
"actual": file_hash,
}
)
results["chain_valid"] = False
else:
results["verified_files"] += 1
except Exception as e:
logger.error(f"Failed to verify {log_file}: {e}")
results["integrity_violations"].append({
"file": str(log_file),
"error": str(e)
})
results["integrity_violations"].append(
{"file": str(log_file), "error": str(e)}
)
results["chain_valid"] = False
return results
def export_logs(
self,
start_time: datetime,
end_time: datetime,
format: str = "json",
include_signatures: bool = True
include_signatures: bool = True,
) -> str:
"""Export audit logs for compliance reporting"""
events = self.query_logs(
start_time=start_time,
end_time=end_time,
limit=10000
)
events = self.query_logs(start_time=start_time, end_time=end_time, limit=10000)
if format == "json":
export_data = {
"export_metadata": {
@@ -266,39 +289,46 @@ class AuditLogger:
"end_time": end_time.isoformat(),
"event_count": len(events),
"exported_at": datetime.utcnow().isoformat(),
"include_signatures": include_signatures
"include_signatures": include_signatures,
},
"events": []
"events": [],
}
for event in events:
event_dict = asdict(event)
event_dict["timestamp"] = event.timestamp.isoformat()
if not include_signatures:
event_dict.pop("signature", None)
export_data["events"].append(event_dict)
return json.dumps(export_data, indent=2)
elif format == "csv":
import csv
import io
output = io.StringIO()
writer = csv.writer(output)
# Header
header = [
"event_id", "timestamp", "event_type", "participant_id",
"transaction_id", "action", "resource", "outcome",
"ip_address", "user_agent"
"event_id",
"timestamp",
"event_type",
"participant_id",
"transaction_id",
"action",
"resource",
"outcome",
"ip_address",
"user_agent",
]
if include_signatures:
header.append("signature")
writer.writerow(header)
# Events
for event in events:
row = [
@@ -311,17 +341,17 @@ class AuditLogger:
event.resource,
event.outcome,
event.ip_address,
event.user_agent
event.user_agent,
]
if include_signatures:
row.append(event.signature)
writer.writerow(row)
return output.getvalue()
else:
raise ValueError(f"Unsupported export format: {format}")
async def _background_writer(self):
"""Background task for writing audit events"""
while True:
@@ -332,51 +362,50 @@ class AuditLogger:
try:
# Use asyncio.wait_for for timeout
event = await asyncio.wait_for(
self.write_queue.get(),
timeout=1.0
self.write_queue.get(), timeout=1.0
)
events.append(event)
except asyncio.TimeoutError:
if events:
break
continue
# Write events
if events:
self._write_events(events)
except Exception as e:
logger.error(f"Background writer error: {e}")
# Brief pause to avoid error loops
await asyncio.sleep(1)
def _write_events(self, events: List[AuditEvent]):
"""Write events to current log file"""
try:
self._rotate_if_needed()
with open(self.current_file, "a") as f:
for event in events:
# Convert to JSON line
event_dict = asdict(event)
event_dict["timestamp"] = event.timestamp.isoformat()
# Write with signature
line = json.dumps(event_dict, separators=(",", ":")) + "\n"
f.write(line)
f.flush()
# Update chain hash
self._update_chain_hash(events[-1])
except Exception as e:
logger.error(f"Failed to write audit events: {e}")
def _rotate_if_needed(self):
"""Rotate log file if needed"""
now = datetime.utcnow()
today = now.date()
# Check if we need a new file
if self.current_file is None:
self._new_log_file(today)
@@ -384,31 +413,31 @@ class AuditLogger:
file_date = datetime.fromisoformat(
self.current_file.stem.split("_")[1]
).date()
if file_date != today:
self._new_log_file(today)
def _new_log_file(self, date):
"""Create new log file for date"""
filename = f"audit_{date.isoformat()}.log"
self.current_file = self.log_dir / filename
# Write header with metadata
if not self.current_file.exists():
header = {
"created_at": datetime.utcnow().isoformat(),
"version": "1.0",
"format": "jsonl",
"previous_hash": self.chain_hash
"previous_hash": self.chain_hash,
}
with open(self.current_file, "w") as f:
f.write(f"# {json.dumps(header)}\n")
def _generate_event_id(self) -> str:
"""Generate unique event ID"""
return f"evt_{datetime.utcnow().timestamp()}_{os.urandom(4).hex()}"
def _sign_event(self, event: AuditEvent) -> str:
"""Sign event for tamper-evidence"""
# Create canonical representation
@@ -417,24 +446,24 @@ class AuditLogger:
"timestamp": event.timestamp.isoformat(),
"participant_id": event.participant_id,
"action": event.action,
"outcome": event.outcome
"outcome": event.outcome,
}
# Hash with previous chain hash
data = json.dumps(event_data, separators=(",", ":"), sort_keys=True)
combined = f"{self.chain_hash}:{data}".encode()
return hashlib.sha256(combined).hexdigest()
def _update_chain_hash(self, last_event: AuditEvent):
"""Update chain hash with new event"""
self.chain_hash = last_event.signature or self.chain_hash
# Store chain hash for integrity checking
chain_file = self.log_dir / "chain.hash"
with open(chain_file, "w") as f:
f.write(self.chain_hash)
def _load_chain_hash(self) -> str:
"""Load previous chain hash"""
chain_file = self.log_dir / "chain.hash"
@@ -442,35 +471,38 @@ class AuditLogger:
with open(chain_file, "r") as f:
return f.read().strip()
return "0" * 64 # Initial hash
def _get_log_files(self, start_time: Optional[datetime], end_time: Optional[datetime]) -> List[Path]:
def _get_log_files(
self, start_time: Optional[datetime], end_time: Optional[datetime]
) -> List[Path]:
"""Get list of log files to search"""
files = []
for file in self.log_dir.glob("audit_*.log*"):
try:
# Extract date from filename
date_str = file.stem.split("_")[1]
file_date = datetime.fromisoformat(date_str).date()
# Check if file is in range
file_start = datetime.combine(file_date, datetime.min.time())
file_end = file_start + timedelta(days=1)
if (not start_time or file_end >= start_time) and \
(not end_time or file_start <= end_time):
if (not start_time or file_end >= start_time) and (
not end_time or file_start <= end_time
):
files.append(file)
except Exception:
continue
return sorted(files)
def _parse_log_line(self, line: str) -> Optional[AuditEvent]:
"""Parse log line into event"""
if line.startswith("#"):
return None # Skip header
try:
data = json.loads(line)
data["timestamp"] = datetime.fromisoformat(data["timestamp"])
@@ -478,7 +510,7 @@ class AuditLogger:
except Exception as e:
logger.error(f"Failed to parse log line: {e}")
return None
def _matches_query(
self,
event: Optional[AuditEvent],
@@ -486,39 +518,39 @@ class AuditLogger:
transaction_id: Optional[str],
event_type: Optional[str],
start_time: Optional[datetime],
end_time: Optional[datetime]
end_time: Optional[datetime],
) -> bool:
"""Check if event matches query criteria"""
if not event:
return False
if participant_id and event.participant_id != participant_id:
return False
if transaction_id and event.transaction_id != transaction_id:
return False
if event_type and event.event_type != event_type:
return False
if start_time and event.timestamp < start_time:
return False
if end_time and event.timestamp > end_time:
return False
return True
def _calculate_file_hash(self, file_path: Path) -> str:
"""Calculate SHA-256 hash of file"""
hash_sha256 = hashlib.sha256()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_sha256.update(chunk)
return hash_sha256.hexdigest()
def _get_stored_hash(self, file_path: Path) -> str:
"""Get stored hash for file"""
hash_file = file_path.with_suffix(".hash")

View File

@@ -0,0 +1,80 @@
"""
Confidential Transaction Service - Wrapper for existing confidential functionality
"""
from typing import Optional, List, Dict, Any
from datetime import datetime
from ..services.encryption import EncryptionService
from ..services.key_management import KeyManager
from ..models.confidential import ConfidentialTransaction, ViewingKey
class ConfidentialTransactionService:
"""Service for handling confidential transactions using existing encryption and key management"""
def __init__(self):
self.encryption_service = EncryptionService()
self.key_manager = KeyManager()
def create_confidential_transaction(
self,
sender: str,
recipient: str,
amount: int,
viewing_key: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None
) -> ConfidentialTransaction:
"""Create a new confidential transaction"""
# Generate viewing key if not provided
if not viewing_key:
viewing_key = self.key_manager.generate_viewing_key()
# Encrypt transaction data
encrypted_data = self.encryption_service.encrypt_transaction_data({
"sender": sender,
"recipient": recipient,
"amount": amount,
"metadata": metadata or {}
})
return ConfidentialTransaction(
sender=sender,
recipient=recipient,
encrypted_payload=encrypted_data,
viewing_key=viewing_key,
created_at=datetime.utcnow()
)
def decrypt_transaction(
self,
transaction: ConfidentialTransaction,
viewing_key: str
) -> Dict[str, Any]:
"""Decrypt a confidential transaction using viewing key"""
return self.encryption_service.decrypt_transaction_data(
transaction.encrypted_payload,
viewing_key
)
def verify_transaction_access(
self,
transaction: ConfidentialTransaction,
requester: str
) -> bool:
"""Verify if requester has access to view transaction"""
return requester in [transaction.sender, transaction.recipient]
def get_transaction_summary(
self,
transaction: ConfidentialTransaction,
viewer: str
) -> Dict[str, Any]:
"""Get transaction summary based on viewer permissions"""
if self.verify_transaction_access(transaction, viewer):
return self.decrypt_transaction(transaction, transaction.viewing_key)
else:
return {
"transaction_id": transaction.id,
"encrypted": True,
"accessible": False
}

View File

@@ -11,10 +11,18 @@ from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat, PrivateFormat, NoEncryption
from cryptography.hazmat.primitives.asymmetric.x25519 import (
X25519PrivateKey,
X25519PublicKey,
)
from cryptography.hazmat.primitives.serialization import (
Encoding,
PublicFormat,
PrivateFormat,
NoEncryption,
)
from ..schemas import ConfidentialTransaction, AccessLog
from ..schemas import ConfidentialTransaction, ConfidentialAccessLog
from ..config import settings
from ..logging import get_logger
@@ -23,21 +31,21 @@ logger = get_logger(__name__)
class EncryptedData:
"""Container for encrypted data and keys"""
def __init__(
self,
ciphertext: bytes,
encrypted_keys: Dict[str, bytes],
algorithm: str = "AES-256-GCM+X25519",
nonce: Optional[bytes] = None,
tag: Optional[bytes] = None
tag: Optional[bytes] = None,
):
self.ciphertext = ciphertext
self.encrypted_keys = encrypted_keys
self.algorithm = algorithm
self.nonce = nonce
self.tag = tag
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for storage"""
return {
@@ -48,9 +56,9 @@ class EncryptedData:
},
"algorithm": self.algorithm,
"nonce": base64.b64encode(self.nonce).decode() if self.nonce else None,
"tag": base64.b64encode(self.tag).decode() if self.tag else None
"tag": base64.b64encode(self.tag).decode() if self.tag else None,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "EncryptedData":
"""Create from dictionary"""
@@ -62,31 +70,28 @@ class EncryptedData:
},
algorithm=data["algorithm"],
nonce=base64.b64decode(data["nonce"]) if data.get("nonce") else None,
tag=base64.b64decode(data["tag"]) if data.get("tag") else None
tag=base64.b64decode(data["tag"]) if data.get("tag") else None,
)
class EncryptionService:
"""Service for encrypting/decrypting confidential transaction data"""
def __init__(self, key_manager: "KeyManager"):
self.key_manager = key_manager
self.backend = default_backend()
self.algorithm = "AES-256-GCM+X25519"
def encrypt(
self,
data: Dict[str, Any],
participants: List[str],
include_audit: bool = True
self, data: Dict[str, Any], participants: List[str], include_audit: bool = True
) -> EncryptedData:
"""Encrypt data for multiple participants
Args:
data: Data to encrypt
participants: List of participant IDs who can decrypt
include_audit: Whether to include audit escrow key
Returns:
EncryptedData container with ciphertext and encrypted keys
"""
@@ -94,16 +99,16 @@ class EncryptionService:
# Generate random DEK (Data Encryption Key)
dek = os.urandom(32) # 256-bit key for AES-256
nonce = os.urandom(12) # 96-bit nonce for GCM
# Serialize and encrypt data
plaintext = json.dumps(data, separators=(",", ":")).encode()
aesgcm = AESGCM(dek)
ciphertext = aesgcm.encrypt(nonce, plaintext, None)
# Extract tag (included in ciphertext for GCM)
tag = ciphertext[-16:]
actual_ciphertext = ciphertext[:-16]
# Encrypt DEK for each participant
encrypted_keys = {}
for participant in participants:
@@ -112,9 +117,11 @@ class EncryptionService:
encrypted_dek = self._encrypt_dek(dek, public_key)
encrypted_keys[participant] = encrypted_dek
except Exception as e:
logger.error(f"Failed to encrypt DEK for participant {participant}: {e}")
logger.error(
f"Failed to encrypt DEK for participant {participant}: {e}"
)
continue
# Add audit escrow if requested
if include_audit:
try:
@@ -123,67 +130,67 @@ class EncryptionService:
encrypted_keys["audit"] = encrypted_dek
except Exception as e:
logger.error(f"Failed to encrypt DEK for audit: {e}")
return EncryptedData(
ciphertext=actual_ciphertext,
encrypted_keys=encrypted_keys,
algorithm=self.algorithm,
nonce=nonce,
tag=tag
tag=tag,
)
except Exception as e:
logger.error(f"Encryption failed: {e}")
raise EncryptionError(f"Failed to encrypt data: {e}")
def decrypt(
self,
encrypted_data: EncryptedData,
participant_id: str,
purpose: str = "access"
purpose: str = "access",
) -> Dict[str, Any]:
"""Decrypt data for a specific participant
Args:
encrypted_data: The encrypted data container
participant_id: ID of the participant requesting decryption
purpose: Purpose of decryption for audit logging
Returns:
Decrypted data as dictionary
"""
try:
# Get participant's private key
private_key = self.key_manager.get_private_key(participant_id)
# Get encrypted DEK for participant
if participant_id not in encrypted_data.encrypted_keys:
raise AccessDeniedError(f"Participant {participant_id} not authorized")
encrypted_dek = encrypted_data.encrypted_keys[participant_id]
# Decrypt DEK
dek = self._decrypt_dek(encrypted_dek, private_key)
# Reconstruct ciphertext with tag
full_ciphertext = encrypted_data.ciphertext + encrypted_data.tag
# Decrypt data
aesgcm = AESGCM(dek)
plaintext = aesgcm.decrypt(encrypted_data.nonce, full_ciphertext, None)
data = json.loads(plaintext.decode())
# Log access
self._log_access(
transaction_id=None, # Will be set by caller
participant_id=participant_id,
purpose=purpose,
success=True
success=True,
)
return data
except Exception as e:
logger.error(f"Decryption failed for participant {participant_id}: {e}")
self._log_access(
@@ -191,23 +198,23 @@ class EncryptionService:
participant_id=participant_id,
purpose=purpose,
success=False,
error=str(e)
error=str(e),
)
raise DecryptionError(f"Failed to decrypt data: {e}")
def audit_decrypt(
self,
encrypted_data: EncryptedData,
audit_authorization: str,
purpose: str = "audit"
purpose: str = "audit",
) -> Dict[str, Any]:
"""Decrypt data for audit purposes
Args:
encrypted_data: The encrypted data container
audit_authorization: Authorization token for audit access
purpose: Purpose of decryption
Returns:
Decrypted data as dictionary
"""
@@ -215,97 +222,101 @@ class EncryptionService:
# Verify audit authorization
if not self.key_manager.verify_audit_authorization(audit_authorization):
raise AccessDeniedError("Invalid audit authorization")
# Get audit private key
audit_private_key = self.key_manager.get_audit_private_key(audit_authorization)
audit_private_key = self.key_manager.get_audit_private_key(
audit_authorization
)
# Decrypt using audit key
if "audit" not in encrypted_data.encrypted_keys:
raise AccessDeniedError("Audit escrow not available")
encrypted_dek = encrypted_data.encrypted_keys["audit"]
dek = self._decrypt_dek(encrypted_dek, audit_private_key)
# Decrypt data
full_ciphertext = encrypted_data.ciphertext + encrypted_data.tag
aesgcm = AESGCM(dek)
plaintext = aesgcm.decrypt(encrypted_data.nonce, full_ciphertext, None)
data = json.loads(plaintext.decode())
# Log audit access
self._log_access(
transaction_id=None,
participant_id="audit",
purpose=f"audit:{purpose}",
success=True,
authorization=audit_authorization
authorization=audit_authorization,
)
return data
except Exception as e:
logger.error(f"Audit decryption failed: {e}")
raise DecryptionError(f"Failed to decrypt for audit: {e}")
def _encrypt_dek(self, dek: bytes, public_key: X25519PublicKey) -> bytes:
"""Encrypt DEK using ECIES with X25519"""
# Generate ephemeral key pair
ephemeral_private = X25519PrivateKey.generate()
ephemeral_public = ephemeral_private.public_key()
# Perform ECDH
shared_key = ephemeral_private.exchange(public_key)
# Derive encryption key from shared secret
derived_key = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=None,
info=b"AITBC-DEK-Encryption",
backend=self.backend
backend=self.backend,
).derive(shared_key)
# Encrypt DEK with AES-GCM
aesgcm = AESGCM(derived_key)
nonce = os.urandom(12)
encrypted_dek = aesgcm.encrypt(nonce, dek, None)
# Return ephemeral public key + nonce + encrypted DEK
return (
ephemeral_public.public_bytes(Encoding.Raw, PublicFormat.Raw) +
nonce +
encrypted_dek
ephemeral_public.public_bytes(Encoding.Raw, PublicFormat.Raw)
+ nonce
+ encrypted_dek
)
def _decrypt_dek(self, encrypted_dek: bytes, private_key: X25519PrivateKey) -> bytes:
def _decrypt_dek(
self, encrypted_dek: bytes, private_key: X25519PrivateKey
) -> bytes:
"""Decrypt DEK using ECIES with X25519"""
# Extract components
ephemeral_public_bytes = encrypted_dek[:32]
nonce = encrypted_dek[32:44]
dek_ciphertext = encrypted_dek[44:]
# Reconstruct ephemeral public key
ephemeral_public = X25519PublicKey.from_public_bytes(ephemeral_public_bytes)
# Perform ECDH
shared_key = private_key.exchange(ephemeral_public)
# Derive decryption key
derived_key = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=None,
info=b"AITBC-DEK-Encryption",
backend=self.backend
backend=self.backend,
).derive(shared_key)
# Decrypt DEK
aesgcm = AESGCM(derived_key)
dek = aesgcm.decrypt(nonce, dek_ciphertext, None)
return dek
def _log_access(
self,
transaction_id: Optional[str],
@@ -313,7 +324,7 @@ class EncryptionService:
purpose: str,
success: bool,
error: Optional[str] = None,
authorization: Optional[str] = None
authorization: Optional[str] = None,
):
"""Log access to confidential data"""
try:
@@ -324,26 +335,29 @@ class EncryptionService:
"timestamp": datetime.utcnow().isoformat(),
"success": success,
"error": error,
"authorization": authorization
"authorization": authorization,
}
# In production, this would go to secure audit log
logger.info(f"Confidential data access: {json.dumps(log_entry)}")
except Exception as e:
logger.error(f"Failed to log access: {e}")
class EncryptionError(Exception):
"""Base exception for encryption errors"""
pass
class DecryptionError(EncryptionError):
"""Exception for decryption errors"""
pass
class AccessDeniedError(EncryptionError):
"""Exception for access denied errors"""
pass

View File

@@ -7,6 +7,7 @@ from typing import Optional
from sqlmodel import Session, select
from ..config import settings
from ..domain import Job, JobReceipt
from ..schemas import (
BlockListResponse,
@@ -39,29 +40,45 @@ class ExplorerService:
self.session = session
def list_blocks(self, *, limit: int = 20, offset: int = 0) -> BlockListResponse:
# Fetch real blockchain data from RPC API
# Fetch real blockchain data via /rpc/head and /rpc/blocks-range
rpc_base = settings.blockchain_rpc_url.rstrip("/")
try:
# Use the blockchain RPC API running on localhost:8082
with httpx.Client(timeout=10.0) as client:
response = client.get("http://localhost:8082/rpc/blocks", params={"limit": limit, "offset": offset})
response.raise_for_status()
rpc_data = response.json()
head_resp = client.get(f"{rpc_base}/rpc/head")
if head_resp.status_code == 404:
return BlockListResponse(items=[], next_offset=None)
head_resp.raise_for_status()
head = head_resp.json()
height = head.get("height", 0)
start = max(0, height - offset - limit + 1)
end = height - offset
if start > end:
return BlockListResponse(items=[], next_offset=None)
range_resp = client.get(
f"{rpc_base}/rpc/blocks-range",
params={"start": start, "end": end},
)
range_resp.raise_for_status()
rpc_data = range_resp.json()
raw_blocks = rpc_data.get("blocks", [])
# Node returns ascending by height; explorer expects newest first
raw_blocks = list(reversed(raw_blocks))
items: list[BlockSummary] = []
for block in rpc_data.get("blocks", []):
for block in raw_blocks:
ts = block.get("timestamp")
if isinstance(ts, str):
ts = datetime.fromisoformat(ts.replace("Z", "+00:00"))
items.append(
BlockSummary(
height=block["height"],
hash=block["hash"],
timestamp=datetime.fromisoformat(block["timestamp"]),
txCount=block["tx_count"],
proposer=block["proposer"],
timestamp=ts,
txCount=block.get("tx_count", 0),
proposer=block.get("proposer", ""),
)
)
next_offset: Optional[int] = offset + len(items) if len(items) == limit else None
next_offset = offset + len(items) if len(items) == limit else None
return BlockListResponse(items=items, next_offset=next_offset)
except Exception as e:
# Fallback to fake data if RPC is unavailable
print(f"Warning: Failed to fetch blocks from RPC: {e}, falling back to fake data")

View File

@@ -1,6 +1,8 @@
"""Ensure coordinator-api src is on sys.path for all tests in this directory."""
import sys
import os
import tempfile
from pathlib import Path
_src = str(Path(__file__).resolve().parent.parent / "src")
@@ -15,3 +17,9 @@ if _app_mod and hasattr(_app_mod, "__file__") and _app_mod.__file__ and _src not
if _src not in sys.path:
sys.path.insert(0, _src)
# Set up test environment
os.environ["TEST_MODE"] = "true"
project_root = Path(__file__).resolve().parent.parent.parent
os.environ["AUDIT_LOG_DIR"] = str(project_root / "logs" / "audit")
os.environ["TEST_DATABASE_URL"] = "sqlite:///:memory:"

View File

@@ -1,5 +1,5 @@
import pytest
from sqlmodel import Session, delete
from sqlmodel import Session, delete, text
from app.domain import Job, Miner
from app.models import JobCreate
@@ -14,7 +14,26 @@ def _init_db(tmp_path_factory):
from app.config import settings
settings.database_url = f"sqlite:///{db_file}"
# Initialize database and create tables
init_db()
# Ensure payment_id column exists (handle schema migration)
with session_scope() as sess:
try:
# Check if columns exist and add them if needed
result = sess.exec(text("PRAGMA table_info(job)"))
columns = [row[1] for row in result.fetchall()]
if 'payment_id' not in columns:
sess.exec(text("ALTER TABLE job ADD COLUMN payment_id TEXT"))
if 'payment_status' not in columns:
sess.exec(text("ALTER TABLE job ADD COLUMN payment_status TEXT"))
sess.commit()
except Exception as e:
print(f"Schema migration error: {e}")
sess.rollback()
yield

View File

@@ -9,19 +9,18 @@ from pathlib import Path
from app.services.zk_proofs import ZKProofService
from app.models import JobReceipt, Job, JobResult
from app.domain import ReceiptPayload
class TestZKProofService:
"""Test cases for ZK proof service"""
@pytest.fixture
def zk_service(self):
"""Create ZK proof service instance"""
with patch('app.services.zk_proofs.settings'):
with patch("app.services.zk_proofs.settings"):
service = ZKProofService()
return service
@pytest.fixture
def sample_job(self):
"""Create sample job for testing"""
@@ -31,9 +30,9 @@ class TestZKProofService:
payload={"type": "test"},
constraints={},
requested_at=None,
completed=True
completed=True,
)
@pytest.fixture
def sample_job_result(self):
"""Create sample job result"""
@@ -42,9 +41,9 @@ class TestZKProofService:
"result_hash": "0x1234567890abcdef",
"units": 100,
"unit_type": "gpu_seconds",
"metrics": {"execution_time": 5.0}
"metrics": {"execution_time": 5.0},
}
@pytest.fixture
def sample_receipt(self, sample_job):
"""Create sample receipt"""
@@ -59,171 +58,187 @@ class TestZKProofService:
price="0.1",
started_at=1640995200,
completed_at=1640995800,
metadata={}
metadata={},
)
return JobReceipt(
job_id=sample_job.id,
receipt_id=payload.receipt_id,
payload=payload.dict()
job_id=sample_job.id, receipt_id=payload.receipt_id, payload=payload.dict()
)
def test_service_initialization_with_files(self):
"""Test service initialization when circuit files exist"""
with patch('app.services.zk_proofs.Path') as mock_path:
with patch("app.services.zk_proofs.Path") as mock_path:
# Mock file existence
mock_path.return_value.exists.return_value = True
service = ZKProofService()
assert service.enabled is True
def test_service_initialization_without_files(self):
"""Test service initialization when circuit files are missing"""
with patch('app.services.zk_proofs.Path') as mock_path:
with patch("app.services.zk_proofs.Path") as mock_path:
# Mock file non-existence
mock_path.return_value.exists.return_value = False
service = ZKProofService()
assert service.enabled is False
@pytest.mark.asyncio
async def test_generate_proof_basic_privacy(self, zk_service, sample_receipt, sample_job_result):
async def test_generate_proof_basic_privacy(
self, zk_service, sample_receipt, sample_job_result
):
"""Test generating proof with basic privacy level"""
if not zk_service.enabled:
pytest.skip("ZK circuits not available")
# Mock subprocess calls
with patch('subprocess.run') as mock_run:
with patch("subprocess.run") as mock_run:
# Mock successful proof generation
mock_run.return_value.returncode = 0
mock_run.return_value.stdout = json.dumps({
"proof": {"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
"publicSignals": ["0x1234", "1000", "1640995800"]
})
mock_run.return_value.stdout = json.dumps(
{
"proof": {
"a": ["1", "2"],
"b": [["1", "2"], ["1", "2"]],
"c": ["1", "2"],
},
"publicSignals": ["0x1234", "1000", "1640995800"],
}
)
# Generate proof
proof = await zk_service.generate_receipt_proof(
receipt=sample_receipt,
job_result=sample_job_result,
privacy_level="basic"
privacy_level="basic",
)
assert proof is not None
assert "proof" in proof
assert "public_signals" in proof
assert proof["privacy_level"] == "basic"
assert "circuit_hash" in proof
@pytest.mark.asyncio
async def test_generate_proof_enhanced_privacy(self, zk_service, sample_receipt, sample_job_result):
async def test_generate_proof_enhanced_privacy(
self, zk_service, sample_receipt, sample_job_result
):
"""Test generating proof with enhanced privacy level"""
if not zk_service.enabled:
pytest.skip("ZK circuits not available")
with patch('subprocess.run') as mock_run:
with patch("subprocess.run") as mock_run:
mock_run.return_value.returncode = 0
mock_run.return_value.stdout = json.dumps({
"proof": {"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
"publicSignals": ["1000", "1640995800"]
})
mock_run.return_value.stdout = json.dumps(
{
"proof": {
"a": ["1", "2"],
"b": [["1", "2"], ["1", "2"]],
"c": ["1", "2"],
},
"publicSignals": ["1000", "1640995800"],
}
)
proof = await zk_service.generate_receipt_proof(
receipt=sample_receipt,
job_result=sample_job_result,
privacy_level="enhanced"
privacy_level="enhanced",
)
assert proof is not None
assert proof["privacy_level"] == "enhanced"
@pytest.mark.asyncio
async def test_generate_proof_service_disabled(self, zk_service, sample_receipt, sample_job_result):
async def test_generate_proof_service_disabled(
self, zk_service, sample_receipt, sample_job_result
):
"""Test proof generation when service is disabled"""
zk_service.enabled = False
proof = await zk_service.generate_receipt_proof(
receipt=sample_receipt,
job_result=sample_job_result,
privacy_level="basic"
receipt=sample_receipt, job_result=sample_job_result, privacy_level="basic"
)
assert proof is None
@pytest.mark.asyncio
async def test_generate_proof_invalid_privacy_level(self, zk_service, sample_receipt, sample_job_result):
async def test_generate_proof_invalid_privacy_level(
self, zk_service, sample_receipt, sample_job_result
):
"""Test proof generation with invalid privacy level"""
if not zk_service.enabled:
pytest.skip("ZK circuits not available")
with pytest.raises(ValueError, match="Unknown privacy level"):
await zk_service.generate_receipt_proof(
receipt=sample_receipt,
job_result=sample_job_result,
privacy_level="invalid"
privacy_level="invalid",
)
@pytest.mark.asyncio
async def test_verify_proof_success(self, zk_service):
"""Test successful proof verification"""
if not zk_service.enabled:
pytest.skip("ZK circuits not available")
with patch('subprocess.run') as mock_run, \
patch('builtins.open', mock_open(read_data='{"key": "value"}')):
with patch("subprocess.run") as mock_run, patch(
"builtins.open", mock_open(read_data='{"key": "value"}')
):
mock_run.return_value.returncode = 0
mock_run.return_value.stdout = "true"
result = await zk_service.verify_proof(
proof={"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
public_signals=["0x1234", "1000"]
public_signals=["0x1234", "1000"],
)
assert result is True
@pytest.mark.asyncio
async def test_verify_proof_failure(self, zk_service):
"""Test proof verification failure"""
if not zk_service.enabled:
pytest.skip("ZK circuits not available")
with patch('subprocess.run') as mock_run, \
patch('builtins.open', mock_open(read_data='{"key": "value"}')):
with patch("subprocess.run") as mock_run, patch(
"builtins.open", mock_open(read_data='{"key": "value"}')
):
mock_run.return_value.returncode = 1
mock_run.return_value.stderr = "Verification failed"
result = await zk_service.verify_proof(
proof={"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
public_signals=["0x1234", "1000"]
public_signals=["0x1234", "1000"],
)
assert result is False
@pytest.mark.asyncio
async def test_verify_proof_service_disabled(self, zk_service):
"""Test proof verification when service is disabled"""
zk_service.enabled = False
result = await zk_service.verify_proof(
proof={"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
public_signals=["0x1234", "1000"]
public_signals=["0x1234", "1000"],
)
assert result is False
def test_hash_receipt(self, zk_service, sample_receipt):
"""Test receipt hashing"""
receipt_hash = zk_service._hash_receipt(sample_receipt)
assert isinstance(receipt_hash, str)
assert len(receipt_hash) == 64 # SHA256 hex length
assert all(c in '0123456789abcdef' for c in receipt_hash)
assert all(c in "0123456789abcdef" for c in receipt_hash)
def test_serialize_receipt(self, zk_service, sample_receipt):
"""Test receipt serialization for circuit"""
serialized = zk_service._serialize_receipt(sample_receipt)
assert isinstance(serialized, list)
assert len(serialized) == 8
assert all(isinstance(x, str) for x in serialized)
@@ -231,19 +246,19 @@ class TestZKProofService:
class TestZKProofIntegration:
"""Integration tests for ZK proof system"""
@pytest.mark.asyncio
async def test_receipt_creation_with_zk_proof(self):
"""Test receipt creation with ZK proof generation"""
from app.services.receipts import ReceiptService
from sqlmodel import Session
# Create mock session
session = Mock(spec=Session)
# Create receipt service
receipt_service = ReceiptService(session)
# Create sample job
job = Job(
id="test-job-123",
@@ -251,43 +266,45 @@ class TestZKProofIntegration:
payload={"type": "test"},
constraints={},
requested_at=None,
completed=True
completed=True,
)
# Mock ZK proof service
with patch('app.services.receipts.zk_proof_service') as mock_zk:
with patch("app.services.receipts.zk_proof_service") as mock_zk:
mock_zk.is_enabled.return_value = True
mock_zk.generate_receipt_proof = AsyncMock(return_value={
"proof": {"a": ["1", "2"]},
"public_signals": ["0x1234"],
"privacy_level": "basic"
})
mock_zk.generate_receipt_proof = AsyncMock(
return_value={
"proof": {"a": ["1", "2"]},
"public_signals": ["0x1234"],
"privacy_level": "basic",
}
)
# Create receipt with privacy
receipt = await receipt_service.create_receipt(
job=job,
miner_id="miner-001",
job_result={"result": "test"},
result_metrics={"units": 100},
privacy_level="basic"
privacy_level="basic",
)
assert receipt is not None
assert "zk_proof" in receipt
assert receipt["privacy_level"] == "basic"
@pytest.mark.asyncio
async def test_settlement_with_zk_proof(self):
"""Test cross-chain settlement with ZK proof"""
from aitbc.settlement.hooks import SettlementHook
from aitbc.settlement.manager import BridgeManager
# Create mock bridge manager
bridge_manager = Mock(spec=BridgeManager)
# Create settlement hook
settlement_hook = SettlementHook(bridge_manager)
# Create sample job with ZK proof
job = Job(
id="test-job-123",
@@ -296,9 +313,9 @@ class TestZKProofIntegration:
constraints={},
requested_at=None,
completed=True,
target_chain=2
target_chain=2,
)
# Create receipt with ZK proof
receipt_payload = {
"version": "1.0",
@@ -306,24 +323,20 @@ class TestZKProofIntegration:
"job_id": job.id,
"provider": "miner-001",
"client": job.client_id,
"zk_proof": {
"proof": {"a": ["1", "2"]},
"public_signals": ["0x1234"]
}
"zk_proof": {"proof": {"a": ["1", "2"]}, "public_signals": ["0x1234"]},
}
job.receipt = JobReceipt(
job_id=job.id,
receipt_id=receipt_payload["receipt_id"],
payload=receipt_payload
payload=receipt_payload,
)
# Test settlement message creation
message = await settlement_hook._create_settlement_message(
job,
options={"use_zk_proof": True, "privacy_level": "basic"}
job, options={"use_zk_proof": True, "privacy_level": "basic"}
)
assert message.zk_proof is not None
assert message.privacy_level == "basic"
@@ -332,71 +345,70 @@ class TestZKProofIntegration:
def mock_open(read_data=""):
"""Mock open function for file operations"""
from unittest.mock import mock_open
return mock_open(read_data=read_data)
# Benchmark tests
class TestZKProofPerformance:
"""Performance benchmarks for ZK proof operations"""
@pytest.mark.asyncio
async def test_proof_generation_time(self):
"""Benchmark proof generation time"""
import time
if not Path("apps/zk-circuits/receipt.wasm").exists():
pytest.skip("ZK circuits not built")
service = ZKProofService()
if not service.enabled:
pytest.skip("ZK service not enabled")
# Create test data
receipt = JobReceipt(
job_id="benchmark-job",
receipt_id="benchmark-receipt",
payload={"test": "data"}
payload={"test": "data"},
)
job_result = {"result": "benchmark"}
# Measure proof generation time
start_time = time.time()
proof = await service.generate_receipt_proof(
receipt=receipt,
job_result=job_result,
privacy_level="basic"
receipt=receipt, job_result=job_result, privacy_level="basic"
)
end_time = time.time()
generation_time = end_time - start_time
assert proof is not None
assert generation_time < 30 # Should complete within 30 seconds
print(f"Proof generation time: {generation_time:.2f} seconds")
@pytest.mark.asyncio
async def test_proof_verification_time(self):
"""Benchmark proof verification time"""
import time
service = ZKProofService()
if not service.enabled:
pytest.skip("ZK service not enabled")
# Create test proof
proof = {"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]}
public_signals = ["0x1234", "1000"]
# Measure verification time
start_time = time.time()
result = await service.verify_proof(proof, public_signals)
end_time = time.time()
verification_time = end_time - start_time
assert isinstance(result, bool)
assert verification_time < 1 # Should complete within 1 second
print(f"Proof verification time: {verification_time:.3f} seconds")