feat: add transaction hash search to blockchain explorer and cleanup settlement storage

Blockchain Explorer:
- Add transaction hash search support (64-char hex pattern validation)
- Fetch and display transaction details in modal (hash, type, from/to, amount, fee, block)
- Fix regex escape sequence in block height validation
- Update search placeholder text to mention both search types
- Add blank lines between function definitions for PEP 8 compliance

Settlement Storage:
- Add timedelta import for time-range filtering and cleanup cutoff calculations
This commit is contained in:
oib
2026-02-17 14:34:12 +01:00
parent 31d3d70836
commit 421191ccaf
34 changed files with 2176 additions and 5660 deletions

View File

@@ -299,13 +299,67 @@ HTML_TEMPLATE = """
if (!query) return;
// Try block height first
if (/^\\d+$/.test(query)) {
if (/^\d+$/.test(query)) {
showBlockDetails(parseInt(query));
return;
}
// TODO: Add transaction hash search
alert('Search by block height is currently supported');
// Try transaction hash search (hex string, 64 chars)
if (/^[a-fA-F0-9]{64}$/.test(query)) {
try {
const tx = await fetch(`/api/transactions/${query}`).then(r => {
if (!r.ok) throw new Error('Transaction not found');
return r.json();
});
// Show transaction details - reuse block modal
const modal = document.getElementById('block-modal');
const details = document.getElementById('block-details');
details.innerHTML = `
<div class="space-y-6">
<div>
<h3 class="text-lg font-semibold mb-2">Transaction</h3>
<div class="bg-gray-50 rounded p-4 space-y-2">
<div class="flex justify-between">
<span class="text-gray-600">Hash:</span>
<span class="font-mono text-sm">${tx.hash || '-'}</span>
</div>
<div class="flex justify-between">
<span class="text-gray-600">Type:</span>
<span>${tx.type || '-'}</span>
</div>
<div class="flex justify-between">
<span class="text-gray-600">From:</span>
<span class="font-mono text-sm">${tx.from || '-'}</span>
</div>
<div class="flex justify-between">
<span class="text-gray-600">To:</span>
<span class="font-mono text-sm">${tx.to || '-'}</span>
</div>
<div class="flex justify-between">
<span class="text-gray-600">Amount:</span>
<span>${tx.amount || '0'}</span>
</div>
<div class="flex justify-between">
<span class="text-gray-600">Fee:</span>
<span>${tx.fee || '0'}</span>
</div>
<div class="flex justify-between">
<span class="text-gray-600">Block:</span>
<span>${tx.block_height || '-'}</span>
</div>
</div>
</div>
</div>
`;
modal.classList.remove('hidden');
return;
} catch (e) {
alert('Transaction not found');
return;
}
}
alert('Search by block height or transaction hash (64 char hex) is supported');
}
// Format timestamp
@@ -321,6 +375,7 @@ HTML_TEMPLATE = """
</html>
"""
async def get_chain_head() -> Dict[str, Any]:
"""Get the current chain head"""
try:
@@ -332,6 +387,7 @@ async def get_chain_head() -> Dict[str, Any]:
print(f"Error getting chain head: {e}")
return {}
async def get_block(height: int) -> Dict[str, Any]:
"""Get a specific block by height"""
try:
@@ -343,21 +399,25 @@ async def get_block(height: int) -> Dict[str, Any]:
print(f"Error getting block {height}: {e}")
return {}
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the blockchain explorer single-page UI.

    Substitutes the configured node URL into the page template.
    NOTE(review): str.format on a template holding JS/CSS requires every
    literal brace in HTML_TEMPLATE to be doubled — confirm the template does this.
    """
    page = HTML_TEMPLATE.format(node_url=BLOCKCHAIN_RPC_URL)
    return page
@app.get("/api/chain/head")
async def api_chain_head():
    """JSON endpoint: current chain head (thin wrapper over get_chain_head)."""
    head = await get_chain_head()
    return head
@app.get("/api/blocks/{height}")
async def api_block(height: int):
    """JSON endpoint: block data for the given height (wrapper over get_block)."""
    block = await get_block(height)
    return block
@app.get("/health")
async def health():
"""Health check endpoint"""
@@ -365,8 +425,9 @@ async def health():
return {
"status": "ok" if head else "error",
"node_url": BLOCKCHAIN_RPC_URL,
"chain_height": head.get("height", 0)
"chain_height": head.get("height", 0),
}
# Entry point for running the explorer standalone (without an external ASGI runner).
# NOTE(review): binds to 0.0.0.0 (all interfaces) on port 3000 — confirm this
# exposure is intended for the deployment environment.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=3000)

View File

@@ -3,30 +3,26 @@ Storage layer for cross-chain settlements
"""
from typing import Dict, Any, Optional, List
from datetime import datetime
from datetime import datetime, timedelta
import json
import asyncio
from dataclasses import asdict
from .bridges.base import (
SettlementMessage,
SettlementResult,
BridgeStatus
)
from .bridges.base import SettlementMessage, SettlementResult, BridgeStatus
class SettlementStorage:
"""Storage interface for settlement data"""
def __init__(self, db_connection):
self.db = db_connection
async def store_settlement(
self,
message_id: str,
message: SettlementMessage,
bridge_name: str,
status: BridgeStatus
status: BridgeStatus,
) -> None:
"""Store a new settlement record"""
query = """
@@ -38,93 +34,96 @@ class SettlementStorage:
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13
)
"""
await self.db.execute(query, (
message_id,
message.job_id,
message.source_chain_id,
message.target_chain_id,
message.receipt_hash,
json.dumps(message.proof_data),
message.payment_amount,
message.payment_token,
message.nonce,
message.signature,
bridge_name,
status.value,
message.created_at or datetime.utcnow()
))
await self.db.execute(
query,
(
message_id,
message.job_id,
message.source_chain_id,
message.target_chain_id,
message.receipt_hash,
json.dumps(message.proof_data),
message.payment_amount,
message.payment_token,
message.nonce,
message.signature,
bridge_name,
status.value,
message.created_at or datetime.utcnow(),
),
)
async def update_settlement(
self,
message_id: str,
status: Optional[BridgeStatus] = None,
transaction_hash: Optional[str] = None,
error_message: Optional[str] = None,
completed_at: Optional[datetime] = None
completed_at: Optional[datetime] = None,
) -> None:
"""Update settlement record"""
updates = []
params = []
param_count = 1
if status is not None:
updates.append(f"status = ${param_count}")
params.append(status.value)
param_count += 1
if transaction_hash is not None:
updates.append(f"transaction_hash = ${param_count}")
params.append(transaction_hash)
param_count += 1
if error_message is not None:
updates.append(f"error_message = ${param_count}")
params.append(error_message)
param_count += 1
if completed_at is not None:
updates.append(f"completed_at = ${param_count}")
params.append(completed_at)
param_count += 1
if not updates:
return
updates.append(f"updated_at = ${param_count}")
params.append(datetime.utcnow())
param_count += 1
params.append(message_id)
query = f"""
UPDATE settlements
SET {', '.join(updates)}
SET {", ".join(updates)}
WHERE message_id = ${param_count}
"""
await self.db.execute(query, params)
async def get_settlement(self, message_id: str) -> Optional[Dict[str, Any]]:
"""Get settlement by message ID"""
query = """
SELECT * FROM settlements WHERE message_id = $1
"""
result = await self.db.fetchrow(query, message_id)
if not result:
return None
# Convert to dict
settlement = dict(result)
# Parse JSON fields
if settlement['proof_data']:
settlement['proof_data'] = json.loads(settlement['proof_data'])
if settlement["proof_data"]:
settlement["proof_data"] = json.loads(settlement["proof_data"])
return settlement
async def get_settlements_by_job(self, job_id: str) -> List[Dict[str, Any]]:
"""Get all settlements for a job"""
query = """
@@ -132,65 +131,67 @@ class SettlementStorage:
WHERE job_id = $1
ORDER BY created_at DESC
"""
results = await self.db.fetch(query, job_id)
settlements = []
for result in results:
settlement = dict(result)
if settlement['proof_data']:
settlement['proof_data'] = json.loads(settlement['proof_data'])
if settlement["proof_data"]:
settlement["proof_data"] = json.loads(settlement["proof_data"])
settlements.append(settlement)
return settlements
async def get_pending_settlements(self, bridge_name: Optional[str] = None) -> List[Dict[str, Any]]:
async def get_pending_settlements(
self, bridge_name: Optional[str] = None
) -> List[Dict[str, Any]]:
"""Get all pending settlements"""
query = """
SELECT * FROM settlements
WHERE status = 'pending' OR status = 'in_progress'
"""
params = []
if bridge_name:
query += " AND bridge_name = $1"
params.append(bridge_name)
query += " ORDER BY created_at ASC"
results = await self.db.fetch(query, *params)
settlements = []
for result in results:
settlement = dict(result)
if settlement['proof_data']:
settlement['proof_data'] = json.loads(settlement['proof_data'])
if settlement["proof_data"]:
settlement["proof_data"] = json.loads(settlement["proof_data"])
settlements.append(settlement)
return settlements
async def get_settlement_stats(
self,
bridge_name: Optional[str] = None,
time_range: Optional[int] = None # hours
time_range: Optional[int] = None, # hours
) -> Dict[str, Any]:
"""Get settlement statistics"""
conditions = []
params = []
param_count = 1
if bridge_name:
conditions.append(f"bridge_name = ${param_count}")
params.append(bridge_name)
param_count += 1
if time_range:
conditions.append(f"created_at > NOW() - INTERVAL '${param_count} hours'")
params.append(time_range)
param_count += 1
where_clause = "WHERE " + " AND ".join(conditions) if conditions else ""
query = f"""
SELECT
bridge_name,
@@ -202,23 +203,27 @@ class SettlementStorage:
{where_clause}
GROUP BY bridge_name, status
"""
results = await self.db.fetch(query, *params)
stats = {}
for result in results:
bridge = result['bridge_name']
bridge = result["bridge_name"]
if bridge not in stats:
stats[bridge] = {}
stats[bridge][result['status']] = {
'count': result['count'],
'avg_amount': float(result['avg_amount']) if result['avg_amount'] else 0,
'total_amount': float(result['total_amount']) if result['total_amount'] else 0
stats[bridge][result["status"]] = {
"count": result["count"],
"avg_amount": float(result["avg_amount"])
if result["avg_amount"]
else 0,
"total_amount": float(result["total_amount"])
if result["total_amount"]
else 0,
}
return stats
async def cleanup_old_settlements(self, days: int = 30) -> int:
"""Clean up old completed settlements"""
query = """
@@ -226,7 +231,7 @@ class SettlementStorage:
WHERE status IN ('completed', 'failed')
AND created_at < NOW() - INTERVAL $1 days
"""
result = await self.db.execute(query, days)
return result.split()[-1] # Return number of deleted rows
@@ -234,134 +239,139 @@ class SettlementStorage:
# In-memory implementation for testing
class InMemorySettlementStorage(SettlementStorage):
"""In-memory storage implementation for testing"""
def __init__(self):
self.settlements: Dict[str, Dict[str, Any]] = {}
self._lock = asyncio.Lock()
async def store_settlement(
self,
message_id: str,
message: SettlementMessage,
bridge_name: str,
status: BridgeStatus
status: BridgeStatus,
) -> None:
async with self._lock:
self.settlements[message_id] = {
'message_id': message_id,
'job_id': message.job_id,
'source_chain_id': message.source_chain_id,
'target_chain_id': message.target_chain_id,
'receipt_hash': message.receipt_hash,
'proof_data': message.proof_data,
'payment_amount': message.payment_amount,
'payment_token': message.payment_token,
'nonce': message.nonce,
'signature': message.signature,
'bridge_name': bridge_name,
'status': status.value,
'created_at': message.created_at or datetime.utcnow(),
'updated_at': datetime.utcnow()
"message_id": message_id,
"job_id": message.job_id,
"source_chain_id": message.source_chain_id,
"target_chain_id": message.target_chain_id,
"receipt_hash": message.receipt_hash,
"proof_data": message.proof_data,
"payment_amount": message.payment_amount,
"payment_token": message.payment_token,
"nonce": message.nonce,
"signature": message.signature,
"bridge_name": bridge_name,
"status": status.value,
"created_at": message.created_at or datetime.utcnow(),
"updated_at": datetime.utcnow(),
}
async def update_settlement(
self,
message_id: str,
status: Optional[BridgeStatus] = None,
transaction_hash: Optional[str] = None,
error_message: Optional[str] = None,
completed_at: Optional[datetime] = None
completed_at: Optional[datetime] = None,
) -> None:
async with self._lock:
if message_id not in self.settlements:
return
settlement = self.settlements[message_id]
if status is not None:
settlement['status'] = status.value
settlement["status"] = status.value
if transaction_hash is not None:
settlement['transaction_hash'] = transaction_hash
settlement["transaction_hash"] = transaction_hash
if error_message is not None:
settlement['error_message'] = error_message
settlement["error_message"] = error_message
if completed_at is not None:
settlement['completed_at'] = completed_at
settlement['updated_at'] = datetime.utcnow()
settlement["completed_at"] = completed_at
settlement["updated_at"] = datetime.utcnow()
async def get_settlement(self, message_id: str) -> Optional[Dict[str, Any]]:
    """Look up a stored settlement record by message ID.

    Returns the settlement dict, or None when the ID is unknown.
    """
    async with self._lock:
        record = self.settlements.get(message_id)
    return record
async def get_settlements_by_job(self, job_id: str) -> List[Dict[str, Any]]:
async with self._lock:
return [
s for s in self.settlements.values()
if s['job_id'] == job_id
]
async def get_pending_settlements(self, bridge_name: Optional[str] = None) -> List[Dict[str, Any]]:
return [s for s in self.settlements.values() if s["job_id"] == job_id]
async def get_pending_settlements(
self, bridge_name: Optional[str] = None
) -> List[Dict[str, Any]]:
async with self._lock:
pending = [
s for s in self.settlements.values()
if s['status'] in ['pending', 'in_progress']
s
for s in self.settlements.values()
if s["status"] in ["pending", "in_progress"]
]
if bridge_name:
pending = [s for s in pending if s['bridge_name'] == bridge_name]
pending = [s for s in pending if s["bridge_name"] == bridge_name]
return pending
async def get_settlement_stats(
self,
bridge_name: Optional[str] = None,
time_range: Optional[int] = None
self, bridge_name: Optional[str] = None, time_range: Optional[int] = None
) -> Dict[str, Any]:
async with self._lock:
stats = {}
for settlement in self.settlements.values():
if bridge_name and settlement['bridge_name'] != bridge_name:
if bridge_name and settlement["bridge_name"] != bridge_name:
continue
# TODO: Implement time range filtering
bridge = settlement['bridge_name']
# Time range filtering
if time_range is not None:
cutoff = datetime.utcnow() - timedelta(hours=time_range)
if settlement["created_at"] < cutoff:
continue
bridge = settlement["bridge_name"]
if bridge not in stats:
stats[bridge] = {}
status = settlement['status']
status = settlement["status"]
if status not in stats[bridge]:
stats[bridge][status] = {
'count': 0,
'avg_amount': 0,
'total_amount': 0
"count": 0,
"avg_amount": 0,
"total_amount": 0,
}
stats[bridge][status]['count'] += 1
stats[bridge][status]['total_amount'] += settlement['payment_amount']
stats[bridge][status]["count"] += 1
stats[bridge][status]["total_amount"] += settlement["payment_amount"]
# Calculate averages
for bridge_data in stats.values():
for status_data in bridge_data.values():
if status_data['count'] > 0:
status_data['avg_amount'] = status_data['total_amount'] / status_data['count']
if status_data["count"] > 0:
status_data["avg_amount"] = (
status_data["total_amount"] / status_data["count"]
)
return stats
async def cleanup_old_settlements(self, days: int = 30) -> int:
async with self._lock:
cutoff = datetime.utcnow() - timedelta(days=days)
to_delete = [
msg_id for msg_id, settlement in self.settlements.items()
msg_id
for msg_id, settlement in self.settlements.items()
if (
settlement['status'] in ['completed', 'failed'] and
settlement['created_at'] < cutoff
settlement["status"] in ["completed", "failed"]
and settlement["created_at"] < cutoff
)
]
for msg_id in to_delete:
del self.settlements[msg_id]
return len(to_delete)

View File

@@ -4,115 +4,137 @@ Unified configuration for AITBC Coordinator API
Provides environment-based adapter selection and consolidated settings.
"""
import os
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import List, Optional
from pathlib import Path
import os
class DatabaseConfig(BaseSettings):
"""Database configuration with adapter selection."""
adapter: str = "sqlite" # sqlite, postgresql
url: Optional[str] = None
pool_size: int = 10
max_overflow: int = 20
pool_pre_ping: bool = True
@property
def effective_url(self) -> str:
"""Get the effective database URL."""
if self.url:
return self.url
# Default SQLite path
if self.adapter == "sqlite":
return "sqlite:///./coordinator.db"
# Default PostgreSQL connection string
return f"{self.adapter}://localhost:5432/coordinator"
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
case_sensitive=False,
extra="allow"
env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="allow"
)
class Settings(BaseSettings):
"""Unified application settings with environment-based configuration."""
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
case_sensitive=False,
extra="allow"
env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="allow"
)
# Environment
app_env: str = "dev"
app_host: str = "127.0.0.1"
app_port: int = 8011
audit_log_dir: str = "/var/log/aitbc/audit"
# Database
database: DatabaseConfig = DatabaseConfig()
# API Keys
client_api_keys: List[str] = []
miner_api_keys: List[str] = []
admin_api_keys: List[str] = []
# Security
hmac_secret: Optional[str] = None
jwt_secret: Optional[str] = None
jwt_algorithm: str = "HS256"
jwt_expiration_hours: int = 24
# CORS
allow_origins: List[str] = [
"http://localhost:3000",
"http://localhost:8080",
"http://localhost:8000",
"http://localhost:8011"
"http://localhost:8011",
]
# Job Configuration
job_ttl_seconds: int = 900
heartbeat_interval_seconds: int = 10
heartbeat_timeout_seconds: int = 30
# Rate Limiting
rate_limit_requests: int = 60
rate_limit_window_seconds: int = 60
# Receipt Signing
receipt_signing_key_hex: Optional[str] = None
receipt_attestation_key_hex: Optional[str] = None
# Logging
log_level: str = "INFO"
log_format: str = "json" # json or text
# Mempool
mempool_backend: str = "database" # database, memory
# Blockchain RPC
blockchain_rpc_url: str = "http://localhost:8082"
# Test Configuration
test_mode: bool = False
test_database_url: Optional[str] = None
def validate_secrets(self) -> None:
"""Validate that all required secrets are provided."""
if self.app_env == "production":
if not self.jwt_secret:
raise ValueError("JWT_SECRET environment variable is required in production")
raise ValueError(
"JWT_SECRET environment variable is required in production"
)
if self.jwt_secret == "change-me-in-production":
raise ValueError("JWT_SECRET must be changed from default value")
@property
def database_url(self) -> str:
    """Resolve the effective database URL (backward-compatible accessor).

    Resolution order:
    1. ``test_database_url`` when ``test_mode`` is enabled;
    2. the explicitly configured ``database.url``;
    3. a default local SQLite file.
    """
    # Use test database if in test mode and test_database_url is set
    if self.test_mode and self.test_database_url:
        return self.test_database_url
    if self.database.url:
        return self.database.url
    # Default SQLite path for backward compatibility.
    # Fix: the previous return used an f-string with no placeholders (lint F541).
    return "sqlite:///./aitbc_coordinator.db"
@database_url.setter
def database_url(self, value: str):
    """Override the database URL; permitted only while test_mode is active."""
    if self.test_mode:
        self.test_database_url = value
    else:
        raise RuntimeError("Cannot set database_url outside of test mode")
settings = Settings()
# Enable test mode if environment variable is set
if os.getenv("TEST_MODE") == "true":
settings.test_mode = True
if os.getenv("TEST_DATABASE_URL"):
settings.test_database_url = os.getenv("TEST_DATABASE_URL")
# Validate secrets on import
settings.validate_secrets()

View File

@@ -52,6 +52,7 @@ from ..schemas import (
from ..domain import (
Job,
Miner,
JobReceipt,
MarketplaceOffer,
MarketplaceBid,
User,
@@ -93,6 +94,7 @@ __all__ = [
"Constraints",
"Job",
"Miner",
"JobReceipt",
"MarketplaceOffer",
"MarketplaceBid",
"ServiceType",

View File

@@ -22,6 +22,7 @@ logger = get_logger(__name__)
@dataclass
class AuditEvent:
"""Structured audit event"""
event_id: str
timestamp: datetime
event_type: str
@@ -39,27 +40,38 @@ class AuditEvent:
class AuditLogger:
"""Tamper-evident audit logging for privacy compliance"""
def __init__(self, log_dir: str = "/var/log/aitbc/audit"):
self.log_dir = Path(log_dir)
self.log_dir.mkdir(parents=True, exist_ok=True)
def __init__(self, log_dir: str = None):
# Use test-specific directory if in test environment
if os.getenv("PYTEST_CURRENT_TEST"):
# Use project logs directory for tests
# Navigate from coordinator-api/src/app/services/audit_logging.py to project root
# Path: coordinator-api/src/app/services/audit_logging.py -> apps/coordinator-api/src -> apps/coordinator-api -> apps -> project_root
project_root = Path(__file__).resolve().parent.parent.parent.parent.parent.parent
test_log_dir = project_root / "logs" / "audit"
log_path = log_dir or str(test_log_dir)
else:
log_path = log_dir or settings.audit_log_dir
self.log_dir = Path(log_path)
self.log_dir.mkdir(parents=True, exist_ok=True)
# Current log file
self.current_file = None
self.current_hash = None
# Async writer task
self.write_queue = asyncio.Queue(maxsize=10000)
self.writer_task = None
# Chain of hashes for integrity
self.chain_hash = self._load_chain_hash()
async def start(self):
"""Start the background writer task"""
if self.writer_task is None:
self.writer_task = asyncio.create_task(self._background_writer())
async def stop(self):
"""Stop the background writer task"""
if self.writer_task:
@@ -69,7 +81,7 @@ class AuditLogger:
except asyncio.CancelledError:
pass
self.writer_task = None
async def log_access(
self,
participant_id: str,
@@ -79,7 +91,7 @@ class AuditLogger:
details: Optional[Dict[str, Any]] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
authorization: Optional[str] = None
authorization: Optional[str] = None,
):
"""Log access to confidential data"""
event = AuditEvent(
@@ -95,22 +107,22 @@ class AuditLogger:
ip_address=ip_address,
user_agent=user_agent,
authorization=authorization,
signature=None
signature=None,
)
# Add signature for tamper-evidence
event.signature = self._sign_event(event)
# Queue for writing
await self.write_queue.put(event)
async def log_key_operation(
self,
participant_id: str,
operation: str,
key_version: int,
outcome: str,
details: Optional[Dict[str, Any]] = None
details: Optional[Dict[str, Any]] = None,
):
"""Log key management operations"""
event = AuditEvent(
@@ -126,19 +138,19 @@ class AuditLogger:
ip_address=None,
user_agent=None,
authorization=None,
signature=None
signature=None,
)
event.signature = self._sign_event(event)
await self.write_queue.put(event)
async def log_policy_change(
self,
participant_id: str,
policy_id: str,
change_type: str,
outcome: str,
details: Optional[Dict[str, Any]] = None
details: Optional[Dict[str, Any]] = None,
):
"""Log access policy changes"""
event = AuditEvent(
@@ -154,12 +166,12 @@ class AuditLogger:
ip_address=None,
user_agent=None,
authorization=None,
signature=None
signature=None,
)
event.signature = self._sign_event(event)
await self.write_queue.put(event)
def query_logs(
self,
participant_id: Optional[str] = None,
@@ -167,14 +179,14 @@ class AuditLogger:
event_type: Optional[str] = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
limit: int = 100
limit: int = 100,
) -> List[AuditEvent]:
"""Query audit logs"""
results = []
# Get list of log files to search
log_files = self._get_log_files(start_time, end_time)
for log_file in log_files:
try:
# Read and decompress if needed
@@ -182,7 +194,14 @@ class AuditLogger:
with gzip.open(log_file, "rt") as f:
for line in f:
event = self._parse_log_line(line.strip())
if self._matches_query(event, participant_id, transaction_id, event_type, start_time, end_time):
if self._matches_query(
event,
participant_id,
transaction_id,
event_type,
start_time,
end_time,
):
results.append(event)
if len(results) >= limit:
return results
@@ -190,75 +209,79 @@ class AuditLogger:
with open(log_file, "r") as f:
for line in f:
event = self._parse_log_line(line.strip())
if self._matches_query(event, participant_id, transaction_id, event_type, start_time, end_time):
if self._matches_query(
event,
participant_id,
transaction_id,
event_type,
start_time,
end_time,
):
results.append(event)
if len(results) >= limit:
return results
except Exception as e:
logger.error(f"Failed to read log file {log_file}: {e}")
continue
# Sort by timestamp (newest first)
results.sort(key=lambda x: x.timestamp, reverse=True)
return results[:limit]
def verify_integrity(self, start_date: Optional[datetime] = None) -> Dict[str, Any]:
"""Verify integrity of audit logs"""
if start_date is None:
start_date = datetime.utcnow() - timedelta(days=30)
results = {
"verified_files": 0,
"total_files": 0,
"integrity_violations": [],
"chain_valid": True
"chain_valid": True,
}
log_files = self._get_log_files(start_date)
for log_file in log_files:
results["total_files"] += 1
try:
# Verify file hash
file_hash = self._calculate_file_hash(log_file)
stored_hash = self._get_stored_hash(log_file)
if file_hash != stored_hash:
results["integrity_violations"].append({
"file": str(log_file),
"expected": stored_hash,
"actual": file_hash
})
results["integrity_violations"].append(
{
"file": str(log_file),
"expected": stored_hash,
"actual": file_hash,
}
)
results["chain_valid"] = False
else:
results["verified_files"] += 1
except Exception as e:
logger.error(f"Failed to verify {log_file}: {e}")
results["integrity_violations"].append({
"file": str(log_file),
"error": str(e)
})
results["integrity_violations"].append(
{"file": str(log_file), "error": str(e)}
)
results["chain_valid"] = False
return results
def export_logs(
self,
start_time: datetime,
end_time: datetime,
format: str = "json",
include_signatures: bool = True
include_signatures: bool = True,
) -> str:
"""Export audit logs for compliance reporting"""
events = self.query_logs(
start_time=start_time,
end_time=end_time,
limit=10000
)
events = self.query_logs(start_time=start_time, end_time=end_time, limit=10000)
if format == "json":
export_data = {
"export_metadata": {
@@ -266,39 +289,46 @@ class AuditLogger:
"end_time": end_time.isoformat(),
"event_count": len(events),
"exported_at": datetime.utcnow().isoformat(),
"include_signatures": include_signatures
"include_signatures": include_signatures,
},
"events": []
"events": [],
}
for event in events:
event_dict = asdict(event)
event_dict["timestamp"] = event.timestamp.isoformat()
if not include_signatures:
event_dict.pop("signature", None)
export_data["events"].append(event_dict)
return json.dumps(export_data, indent=2)
elif format == "csv":
import csv
import io
output = io.StringIO()
writer = csv.writer(output)
# Header
header = [
"event_id", "timestamp", "event_type", "participant_id",
"transaction_id", "action", "resource", "outcome",
"ip_address", "user_agent"
"event_id",
"timestamp",
"event_type",
"participant_id",
"transaction_id",
"action",
"resource",
"outcome",
"ip_address",
"user_agent",
]
if include_signatures:
header.append("signature")
writer.writerow(header)
# Events
for event in events:
row = [
@@ -311,17 +341,17 @@ class AuditLogger:
event.resource,
event.outcome,
event.ip_address,
event.user_agent
event.user_agent,
]
if include_signatures:
row.append(event.signature)
writer.writerow(row)
return output.getvalue()
else:
raise ValueError(f"Unsupported export format: {format}")
async def _background_writer(self):
"""Background task for writing audit events"""
while True:
@@ -332,51 +362,50 @@ class AuditLogger:
try:
# Use asyncio.wait_for for timeout
event = await asyncio.wait_for(
self.write_queue.get(),
timeout=1.0
self.write_queue.get(), timeout=1.0
)
events.append(event)
except asyncio.TimeoutError:
if events:
break
continue
# Write events
if events:
self._write_events(events)
except Exception as e:
logger.error(f"Background writer error: {e}")
# Brief pause to avoid error loops
await asyncio.sleep(1)
def _write_events(self, events: List[AuditEvent]):
"""Write events to current log file"""
try:
self._rotate_if_needed()
with open(self.current_file, "a") as f:
for event in events:
# Convert to JSON line
event_dict = asdict(event)
event_dict["timestamp"] = event.timestamp.isoformat()
# Write with signature
line = json.dumps(event_dict, separators=(",", ":")) + "\n"
f.write(line)
f.flush()
# Update chain hash
self._update_chain_hash(events[-1])
except Exception as e:
logger.error(f"Failed to write audit events: {e}")
def _rotate_if_needed(self):
"""Rotate log file if needed"""
now = datetime.utcnow()
today = now.date()
# Check if we need a new file
if self.current_file is None:
self._new_log_file(today)
@@ -384,31 +413,31 @@ class AuditLogger:
file_date = datetime.fromisoformat(
self.current_file.stem.split("_")[1]
).date()
if file_date != today:
self._new_log_file(today)
def _new_log_file(self, date):
"""Create new log file for date"""
filename = f"audit_{date.isoformat()}.log"
self.current_file = self.log_dir / filename
# Write header with metadata
if not self.current_file.exists():
header = {
"created_at": datetime.utcnow().isoformat(),
"version": "1.0",
"format": "jsonl",
"previous_hash": self.chain_hash
"previous_hash": self.chain_hash,
}
with open(self.current_file, "w") as f:
f.write(f"# {json.dumps(header)}\n")
def _generate_event_id(self) -> str:
"""Generate unique event ID"""
return f"evt_{datetime.utcnow().timestamp()}_{os.urandom(4).hex()}"
def _sign_event(self, event: AuditEvent) -> str:
"""Sign event for tamper-evidence"""
# Create canonical representation
@@ -417,24 +446,24 @@ class AuditLogger:
"timestamp": event.timestamp.isoformat(),
"participant_id": event.participant_id,
"action": event.action,
"outcome": event.outcome
"outcome": event.outcome,
}
# Hash with previous chain hash
data = json.dumps(event_data, separators=(",", ":"), sort_keys=True)
combined = f"{self.chain_hash}:{data}".encode()
return hashlib.sha256(combined).hexdigest()
def _update_chain_hash(self, last_event: AuditEvent):
"""Update chain hash with new event"""
self.chain_hash = last_event.signature or self.chain_hash
# Store chain hash for integrity checking
chain_file = self.log_dir / "chain.hash"
with open(chain_file, "w") as f:
f.write(self.chain_hash)
def _load_chain_hash(self) -> str:
"""Load previous chain hash"""
chain_file = self.log_dir / "chain.hash"
@@ -442,35 +471,38 @@ class AuditLogger:
with open(chain_file, "r") as f:
return f.read().strip()
return "0" * 64 # Initial hash
def _get_log_files(self, start_time: Optional[datetime], end_time: Optional[datetime]) -> List[Path]:
def _get_log_files(
self, start_time: Optional[datetime], end_time: Optional[datetime]
) -> List[Path]:
"""Get list of log files to search"""
files = []
for file in self.log_dir.glob("audit_*.log*"):
try:
# Extract date from filename
date_str = file.stem.split("_")[1]
file_date = datetime.fromisoformat(date_str).date()
# Check if file is in range
file_start = datetime.combine(file_date, datetime.min.time())
file_end = file_start + timedelta(days=1)
if (not start_time or file_end >= start_time) and \
(not end_time or file_start <= end_time):
if (not start_time or file_end >= start_time) and (
not end_time or file_start <= end_time
):
files.append(file)
except Exception:
continue
return sorted(files)
def _parse_log_line(self, line: str) -> Optional[AuditEvent]:
"""Parse log line into event"""
if line.startswith("#"):
return None # Skip header
try:
data = json.loads(line)
data["timestamp"] = datetime.fromisoformat(data["timestamp"])
@@ -478,7 +510,7 @@ class AuditLogger:
except Exception as e:
logger.error(f"Failed to parse log line: {e}")
return None
def _matches_query(
self,
event: Optional[AuditEvent],
@@ -486,39 +518,39 @@ class AuditLogger:
transaction_id: Optional[str],
event_type: Optional[str],
start_time: Optional[datetime],
end_time: Optional[datetime]
end_time: Optional[datetime],
) -> bool:
"""Check if event matches query criteria"""
if not event:
return False
if participant_id and event.participant_id != participant_id:
return False
if transaction_id and event.transaction_id != transaction_id:
return False
if event_type and event.event_type != event_type:
return False
if start_time and event.timestamp < start_time:
return False
if end_time and event.timestamp > end_time:
return False
return True
def _calculate_file_hash(self, file_path: Path) -> str:
"""Calculate SHA-256 hash of file"""
hash_sha256 = hashlib.sha256()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_sha256.update(chunk)
return hash_sha256.hexdigest()
def _get_stored_hash(self, file_path: Path) -> str:
"""Get stored hash for file"""
hash_file = file_path.with_suffix(".hash")

View File

@@ -0,0 +1,80 @@
"""
Confidential Transaction Service - Wrapper for existing confidential functionality
"""
from typing import Optional, List, Dict, Any
from datetime import datetime
from ..services.encryption import EncryptionService
from ..services.key_management import KeyManager
from ..models.confidential import ConfidentialTransaction, ViewingKey
class ConfidentialTransactionService:
    """Service for handling confidential transactions using existing encryption and key management"""

    def __init__(self):
        # Delegate crypto and key handling to the existing services.
        self.encryption_service = EncryptionService()
        self.key_manager = KeyManager()

    def create_confidential_transaction(
        self,
        sender: str,
        recipient: str,
        amount: int,
        viewing_key: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> ConfidentialTransaction:
        """Create a new confidential transaction"""
        # Fall back to a freshly generated viewing key when none was supplied.
        key = viewing_key or self.key_manager.generate_viewing_key()
        # Encrypt the full transaction payload before persisting anything.
        payload = {
            "sender": sender,
            "recipient": recipient,
            "amount": amount,
            "metadata": metadata or {}
        }
        return ConfidentialTransaction(
            sender=sender,
            recipient=recipient,
            encrypted_payload=self.encryption_service.encrypt_transaction_data(payload),
            viewing_key=key,
            created_at=datetime.utcnow()
        )

    def decrypt_transaction(
        self,
        transaction: ConfidentialTransaction,
        viewing_key: str
    ) -> Dict[str, Any]:
        """Decrypt a confidential transaction using viewing key"""
        ciphertext = transaction.encrypted_payload
        return self.encryption_service.decrypt_transaction_data(ciphertext, viewing_key)

    def verify_transaction_access(
        self,
        transaction: ConfidentialTransaction,
        requester: str
    ) -> bool:
        """Verify if requester has access to view transaction"""
        # Only the two counterparties may view the plaintext.
        return requester == transaction.sender or requester == transaction.recipient

    def get_transaction_summary(
        self,
        transaction: ConfidentialTransaction,
        viewer: str
    ) -> Dict[str, Any]:
        """Get transaction summary based on viewer permissions"""
        # Unauthorized viewers only learn that an encrypted record exists.
        if not self.verify_transaction_access(transaction, viewer):
            return {
                "transaction_id": transaction.id,
                "encrypted": True,
                "accessible": False
            }
        return self.decrypt_transaction(transaction, transaction.viewing_key)

View File

@@ -11,10 +11,18 @@ from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat, PrivateFormat, NoEncryption
from cryptography.hazmat.primitives.asymmetric.x25519 import (
X25519PrivateKey,
X25519PublicKey,
)
from cryptography.hazmat.primitives.serialization import (
Encoding,
PublicFormat,
PrivateFormat,
NoEncryption,
)
from ..schemas import ConfidentialTransaction, AccessLog
from ..schemas import ConfidentialTransaction, ConfidentialAccessLog
from ..config import settings
from ..logging import get_logger
@@ -23,21 +31,21 @@ logger = get_logger(__name__)
class EncryptedData:
"""Container for encrypted data and keys"""
def __init__(
self,
ciphertext: bytes,
encrypted_keys: Dict[str, bytes],
algorithm: str = "AES-256-GCM+X25519",
nonce: Optional[bytes] = None,
tag: Optional[bytes] = None
tag: Optional[bytes] = None,
):
self.ciphertext = ciphertext
self.encrypted_keys = encrypted_keys
self.algorithm = algorithm
self.nonce = nonce
self.tag = tag
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for storage"""
return {
@@ -48,9 +56,9 @@ class EncryptedData:
},
"algorithm": self.algorithm,
"nonce": base64.b64encode(self.nonce).decode() if self.nonce else None,
"tag": base64.b64encode(self.tag).decode() if self.tag else None
"tag": base64.b64encode(self.tag).decode() if self.tag else None,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "EncryptedData":
"""Create from dictionary"""
@@ -62,31 +70,28 @@ class EncryptedData:
},
algorithm=data["algorithm"],
nonce=base64.b64decode(data["nonce"]) if data.get("nonce") else None,
tag=base64.b64decode(data["tag"]) if data.get("tag") else None
tag=base64.b64decode(data["tag"]) if data.get("tag") else None,
)
class EncryptionService:
"""Service for encrypting/decrypting confidential transaction data"""
def __init__(self, key_manager: "KeyManager"):
self.key_manager = key_manager
self.backend = default_backend()
self.algorithm = "AES-256-GCM+X25519"
def encrypt(
self,
data: Dict[str, Any],
participants: List[str],
include_audit: bool = True
self, data: Dict[str, Any], participants: List[str], include_audit: bool = True
) -> EncryptedData:
"""Encrypt data for multiple participants
Args:
data: Data to encrypt
participants: List of participant IDs who can decrypt
include_audit: Whether to include audit escrow key
Returns:
EncryptedData container with ciphertext and encrypted keys
"""
@@ -94,16 +99,16 @@ class EncryptionService:
# Generate random DEK (Data Encryption Key)
dek = os.urandom(32) # 256-bit key for AES-256
nonce = os.urandom(12) # 96-bit nonce for GCM
# Serialize and encrypt data
plaintext = json.dumps(data, separators=(",", ":")).encode()
aesgcm = AESGCM(dek)
ciphertext = aesgcm.encrypt(nonce, plaintext, None)
# Extract tag (included in ciphertext for GCM)
tag = ciphertext[-16:]
actual_ciphertext = ciphertext[:-16]
# Encrypt DEK for each participant
encrypted_keys = {}
for participant in participants:
@@ -112,9 +117,11 @@ class EncryptionService:
encrypted_dek = self._encrypt_dek(dek, public_key)
encrypted_keys[participant] = encrypted_dek
except Exception as e:
logger.error(f"Failed to encrypt DEK for participant {participant}: {e}")
logger.error(
f"Failed to encrypt DEK for participant {participant}: {e}"
)
continue
# Add audit escrow if requested
if include_audit:
try:
@@ -123,67 +130,67 @@ class EncryptionService:
encrypted_keys["audit"] = encrypted_dek
except Exception as e:
logger.error(f"Failed to encrypt DEK for audit: {e}")
return EncryptedData(
ciphertext=actual_ciphertext,
encrypted_keys=encrypted_keys,
algorithm=self.algorithm,
nonce=nonce,
tag=tag
tag=tag,
)
except Exception as e:
logger.error(f"Encryption failed: {e}")
raise EncryptionError(f"Failed to encrypt data: {e}")
def decrypt(
self,
encrypted_data: EncryptedData,
participant_id: str,
purpose: str = "access"
purpose: str = "access",
) -> Dict[str, Any]:
"""Decrypt data for a specific participant
Args:
encrypted_data: The encrypted data container
participant_id: ID of the participant requesting decryption
purpose: Purpose of decryption for audit logging
Returns:
Decrypted data as dictionary
"""
try:
# Get participant's private key
private_key = self.key_manager.get_private_key(participant_id)
# Get encrypted DEK for participant
if participant_id not in encrypted_data.encrypted_keys:
raise AccessDeniedError(f"Participant {participant_id} not authorized")
encrypted_dek = encrypted_data.encrypted_keys[participant_id]
# Decrypt DEK
dek = self._decrypt_dek(encrypted_dek, private_key)
# Reconstruct ciphertext with tag
full_ciphertext = encrypted_data.ciphertext + encrypted_data.tag
# Decrypt data
aesgcm = AESGCM(dek)
plaintext = aesgcm.decrypt(encrypted_data.nonce, full_ciphertext, None)
data = json.loads(plaintext.decode())
# Log access
self._log_access(
transaction_id=None, # Will be set by caller
participant_id=participant_id,
purpose=purpose,
success=True
success=True,
)
return data
except Exception as e:
logger.error(f"Decryption failed for participant {participant_id}: {e}")
self._log_access(
@@ -191,23 +198,23 @@ class EncryptionService:
participant_id=participant_id,
purpose=purpose,
success=False,
error=str(e)
error=str(e),
)
raise DecryptionError(f"Failed to decrypt data: {e}")
def audit_decrypt(
self,
encrypted_data: EncryptedData,
audit_authorization: str,
purpose: str = "audit"
purpose: str = "audit",
) -> Dict[str, Any]:
"""Decrypt data for audit purposes
Args:
encrypted_data: The encrypted data container
audit_authorization: Authorization token for audit access
purpose: Purpose of decryption
Returns:
Decrypted data as dictionary
"""
@@ -215,97 +222,101 @@ class EncryptionService:
# Verify audit authorization
if not self.key_manager.verify_audit_authorization(audit_authorization):
raise AccessDeniedError("Invalid audit authorization")
# Get audit private key
audit_private_key = self.key_manager.get_audit_private_key(audit_authorization)
audit_private_key = self.key_manager.get_audit_private_key(
audit_authorization
)
# Decrypt using audit key
if "audit" not in encrypted_data.encrypted_keys:
raise AccessDeniedError("Audit escrow not available")
encrypted_dek = encrypted_data.encrypted_keys["audit"]
dek = self._decrypt_dek(encrypted_dek, audit_private_key)
# Decrypt data
full_ciphertext = encrypted_data.ciphertext + encrypted_data.tag
aesgcm = AESGCM(dek)
plaintext = aesgcm.decrypt(encrypted_data.nonce, full_ciphertext, None)
data = json.loads(plaintext.decode())
# Log audit access
self._log_access(
transaction_id=None,
participant_id="audit",
purpose=f"audit:{purpose}",
success=True,
authorization=audit_authorization
authorization=audit_authorization,
)
return data
except Exception as e:
logger.error(f"Audit decryption failed: {e}")
raise DecryptionError(f"Failed to decrypt for audit: {e}")
def _encrypt_dek(self, dek: bytes, public_key: X25519PublicKey) -> bytes:
"""Encrypt DEK using ECIES with X25519"""
# Generate ephemeral key pair
ephemeral_private = X25519PrivateKey.generate()
ephemeral_public = ephemeral_private.public_key()
# Perform ECDH
shared_key = ephemeral_private.exchange(public_key)
# Derive encryption key from shared secret
derived_key = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=None,
info=b"AITBC-DEK-Encryption",
backend=self.backend
backend=self.backend,
).derive(shared_key)
# Encrypt DEK with AES-GCM
aesgcm = AESGCM(derived_key)
nonce = os.urandom(12)
encrypted_dek = aesgcm.encrypt(nonce, dek, None)
# Return ephemeral public key + nonce + encrypted DEK
return (
ephemeral_public.public_bytes(Encoding.Raw, PublicFormat.Raw) +
nonce +
encrypted_dek
ephemeral_public.public_bytes(Encoding.Raw, PublicFormat.Raw)
+ nonce
+ encrypted_dek
)
def _decrypt_dek(self, encrypted_dek: bytes, private_key: X25519PrivateKey) -> bytes:
def _decrypt_dek(
self, encrypted_dek: bytes, private_key: X25519PrivateKey
) -> bytes:
"""Decrypt DEK using ECIES with X25519"""
# Extract components
ephemeral_public_bytes = encrypted_dek[:32]
nonce = encrypted_dek[32:44]
dek_ciphertext = encrypted_dek[44:]
# Reconstruct ephemeral public key
ephemeral_public = X25519PublicKey.from_public_bytes(ephemeral_public_bytes)
# Perform ECDH
shared_key = private_key.exchange(ephemeral_public)
# Derive decryption key
derived_key = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=None,
info=b"AITBC-DEK-Encryption",
backend=self.backend
backend=self.backend,
).derive(shared_key)
# Decrypt DEK
aesgcm = AESGCM(derived_key)
dek = aesgcm.decrypt(nonce, dek_ciphertext, None)
return dek
def _log_access(
self,
transaction_id: Optional[str],
@@ -313,7 +324,7 @@ class EncryptionService:
purpose: str,
success: bool,
error: Optional[str] = None,
authorization: Optional[str] = None
authorization: Optional[str] = None,
):
"""Log access to confidential data"""
try:
@@ -324,26 +335,29 @@ class EncryptionService:
"timestamp": datetime.utcnow().isoformat(),
"success": success,
"error": error,
"authorization": authorization
"authorization": authorization,
}
# In production, this would go to secure audit log
logger.info(f"Confidential data access: {json.dumps(log_entry)}")
except Exception as e:
logger.error(f"Failed to log access: {e}")
class EncryptionError(Exception):
"""Base exception for encryption errors"""
pass
class DecryptionError(EncryptionError):
"""Exception for decryption errors"""
pass
class AccessDeniedError(EncryptionError):
"""Exception for access denied errors"""
pass

View File

@@ -7,6 +7,7 @@ from typing import Optional
from sqlmodel import Session, select
from ..config import settings
from ..domain import Job, JobReceipt
from ..schemas import (
BlockListResponse,
@@ -39,29 +40,45 @@ class ExplorerService:
self.session = session
def list_blocks(self, *, limit: int = 20, offset: int = 0) -> BlockListResponse:
# Fetch real blockchain data from RPC API
# Fetch real blockchain data via /rpc/head and /rpc/blocks-range
rpc_base = settings.blockchain_rpc_url.rstrip("/")
try:
# Use the blockchain RPC API running on localhost:8082
with httpx.Client(timeout=10.0) as client:
response = client.get("http://localhost:8082/rpc/blocks", params={"limit": limit, "offset": offset})
response.raise_for_status()
rpc_data = response.json()
head_resp = client.get(f"{rpc_base}/rpc/head")
if head_resp.status_code == 404:
return BlockListResponse(items=[], next_offset=None)
head_resp.raise_for_status()
head = head_resp.json()
height = head.get("height", 0)
start = max(0, height - offset - limit + 1)
end = height - offset
if start > end:
return BlockListResponse(items=[], next_offset=None)
range_resp = client.get(
f"{rpc_base}/rpc/blocks-range",
params={"start": start, "end": end},
)
range_resp.raise_for_status()
rpc_data = range_resp.json()
raw_blocks = rpc_data.get("blocks", [])
# Node returns ascending by height; explorer expects newest first
raw_blocks = list(reversed(raw_blocks))
items: list[BlockSummary] = []
for block in rpc_data.get("blocks", []):
for block in raw_blocks:
ts = block.get("timestamp")
if isinstance(ts, str):
ts = datetime.fromisoformat(ts.replace("Z", "+00:00"))
items.append(
BlockSummary(
height=block["height"],
hash=block["hash"],
timestamp=datetime.fromisoformat(block["timestamp"]),
txCount=block["tx_count"],
proposer=block["proposer"],
timestamp=ts,
txCount=block.get("tx_count", 0),
proposer=block.get("proposer", ""),
)
)
next_offset: Optional[int] = offset + len(items) if len(items) == limit else None
next_offset = offset + len(items) if len(items) == limit else None
return BlockListResponse(items=items, next_offset=next_offset)
except Exception as e:
# Fallback to fake data if RPC is unavailable
print(f"Warning: Failed to fetch blocks from RPC: {e}, falling back to fake data")

View File

@@ -1,6 +1,8 @@
"""Ensure coordinator-api src is on sys.path for all tests in this directory."""
import sys
import os
import tempfile
from pathlib import Path
_src = str(Path(__file__).resolve().parent.parent / "src")
@@ -15,3 +17,9 @@ if _app_mod and hasattr(_app_mod, "__file__") and _app_mod.__file__ and _src not
if _src not in sys.path:
sys.path.insert(0, _src)
# Set up test environment
os.environ["TEST_MODE"] = "true"
project_root = Path(__file__).resolve().parent.parent.parent
os.environ["AUDIT_LOG_DIR"] = str(project_root / "logs" / "audit")
os.environ["TEST_DATABASE_URL"] = "sqlite:///:memory:"

View File

@@ -1,5 +1,5 @@
import pytest
from sqlmodel import Session, delete
from sqlmodel import Session, delete, text
from app.domain import Job, Miner
from app.models import JobCreate
@@ -14,7 +14,26 @@ def _init_db(tmp_path_factory):
from app.config import settings
settings.database_url = f"sqlite:///{db_file}"
# Initialize database and create tables
init_db()
# Ensure payment_id column exists (handle schema migration)
with session_scope() as sess:
try:
# Check if columns exist and add them if needed
result = sess.exec(text("PRAGMA table_info(job)"))
columns = [row[1] for row in result.fetchall()]
if 'payment_id' not in columns:
sess.exec(text("ALTER TABLE job ADD COLUMN payment_id TEXT"))
if 'payment_status' not in columns:
sess.exec(text("ALTER TABLE job ADD COLUMN payment_status TEXT"))
sess.commit()
except Exception as e:
print(f"Schema migration error: {e}")
sess.rollback()
yield

View File

@@ -9,19 +9,18 @@ from pathlib import Path
from app.services.zk_proofs import ZKProofService
from app.models import JobReceipt, Job, JobResult
from app.domain import ReceiptPayload
class TestZKProofService:
"""Test cases for ZK proof service"""
@pytest.fixture
def zk_service(self):
"""Create ZK proof service instance"""
with patch('app.services.zk_proofs.settings'):
with patch("app.services.zk_proofs.settings"):
service = ZKProofService()
return service
@pytest.fixture
def sample_job(self):
"""Create sample job for testing"""
@@ -31,9 +30,9 @@ class TestZKProofService:
payload={"type": "test"},
constraints={},
requested_at=None,
completed=True
completed=True,
)
@pytest.fixture
def sample_job_result(self):
"""Create sample job result"""
@@ -42,9 +41,9 @@ class TestZKProofService:
"result_hash": "0x1234567890abcdef",
"units": 100,
"unit_type": "gpu_seconds",
"metrics": {"execution_time": 5.0}
"metrics": {"execution_time": 5.0},
}
@pytest.fixture
def sample_receipt(self, sample_job):
"""Create sample receipt"""
@@ -59,171 +58,187 @@ class TestZKProofService:
price="0.1",
started_at=1640995200,
completed_at=1640995800,
metadata={}
metadata={},
)
return JobReceipt(
job_id=sample_job.id,
receipt_id=payload.receipt_id,
payload=payload.dict()
job_id=sample_job.id, receipt_id=payload.receipt_id, payload=payload.dict()
)
def test_service_initialization_with_files(self):
"""Test service initialization when circuit files exist"""
with patch('app.services.zk_proofs.Path') as mock_path:
with patch("app.services.zk_proofs.Path") as mock_path:
# Mock file existence
mock_path.return_value.exists.return_value = True
service = ZKProofService()
assert service.enabled is True
def test_service_initialization_without_files(self):
"""Test service initialization when circuit files are missing"""
with patch('app.services.zk_proofs.Path') as mock_path:
with patch("app.services.zk_proofs.Path") as mock_path:
# Mock file non-existence
mock_path.return_value.exists.return_value = False
service = ZKProofService()
assert service.enabled is False
@pytest.mark.asyncio
async def test_generate_proof_basic_privacy(self, zk_service, sample_receipt, sample_job_result):
async def test_generate_proof_basic_privacy(
self, zk_service, sample_receipt, sample_job_result
):
"""Test generating proof with basic privacy level"""
if not zk_service.enabled:
pytest.skip("ZK circuits not available")
# Mock subprocess calls
with patch('subprocess.run') as mock_run:
with patch("subprocess.run") as mock_run:
# Mock successful proof generation
mock_run.return_value.returncode = 0
mock_run.return_value.stdout = json.dumps({
"proof": {"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
"publicSignals": ["0x1234", "1000", "1640995800"]
})
mock_run.return_value.stdout = json.dumps(
{
"proof": {
"a": ["1", "2"],
"b": [["1", "2"], ["1", "2"]],
"c": ["1", "2"],
},
"publicSignals": ["0x1234", "1000", "1640995800"],
}
)
# Generate proof
proof = await zk_service.generate_receipt_proof(
receipt=sample_receipt,
job_result=sample_job_result,
privacy_level="basic"
privacy_level="basic",
)
assert proof is not None
assert "proof" in proof
assert "public_signals" in proof
assert proof["privacy_level"] == "basic"
assert "circuit_hash" in proof
@pytest.mark.asyncio
async def test_generate_proof_enhanced_privacy(self, zk_service, sample_receipt, sample_job_result):
async def test_generate_proof_enhanced_privacy(
self, zk_service, sample_receipt, sample_job_result
):
"""Test generating proof with enhanced privacy level"""
if not zk_service.enabled:
pytest.skip("ZK circuits not available")
with patch('subprocess.run') as mock_run:
with patch("subprocess.run") as mock_run:
mock_run.return_value.returncode = 0
mock_run.return_value.stdout = json.dumps({
"proof": {"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
"publicSignals": ["1000", "1640995800"]
})
mock_run.return_value.stdout = json.dumps(
{
"proof": {
"a": ["1", "2"],
"b": [["1", "2"], ["1", "2"]],
"c": ["1", "2"],
},
"publicSignals": ["1000", "1640995800"],
}
)
proof = await zk_service.generate_receipt_proof(
receipt=sample_receipt,
job_result=sample_job_result,
privacy_level="enhanced"
privacy_level="enhanced",
)
assert proof is not None
assert proof["privacy_level"] == "enhanced"
@pytest.mark.asyncio
async def test_generate_proof_service_disabled(self, zk_service, sample_receipt, sample_job_result):
async def test_generate_proof_service_disabled(
self, zk_service, sample_receipt, sample_job_result
):
"""Test proof generation when service is disabled"""
zk_service.enabled = False
proof = await zk_service.generate_receipt_proof(
receipt=sample_receipt,
job_result=sample_job_result,
privacy_level="basic"
receipt=sample_receipt, job_result=sample_job_result, privacy_level="basic"
)
assert proof is None
@pytest.mark.asyncio
async def test_generate_proof_invalid_privacy_level(self, zk_service, sample_receipt, sample_job_result):
async def test_generate_proof_invalid_privacy_level(
self, zk_service, sample_receipt, sample_job_result
):
"""Test proof generation with invalid privacy level"""
if not zk_service.enabled:
pytest.skip("ZK circuits not available")
with pytest.raises(ValueError, match="Unknown privacy level"):
await zk_service.generate_receipt_proof(
receipt=sample_receipt,
job_result=sample_job_result,
privacy_level="invalid"
privacy_level="invalid",
)
@pytest.mark.asyncio
async def test_verify_proof_success(self, zk_service):
"""Test successful proof verification"""
if not zk_service.enabled:
pytest.skip("ZK circuits not available")
with patch('subprocess.run') as mock_run, \
patch('builtins.open', mock_open(read_data='{"key": "value"}')):
with patch("subprocess.run") as mock_run, patch(
"builtins.open", mock_open(read_data='{"key": "value"}')
):
mock_run.return_value.returncode = 0
mock_run.return_value.stdout = "true"
result = await zk_service.verify_proof(
proof={"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
public_signals=["0x1234", "1000"]
public_signals=["0x1234", "1000"],
)
assert result is True
@pytest.mark.asyncio
async def test_verify_proof_failure(self, zk_service):
"""Test proof verification failure"""
if not zk_service.enabled:
pytest.skip("ZK circuits not available")
with patch('subprocess.run') as mock_run, \
patch('builtins.open', mock_open(read_data='{"key": "value"}')):
with patch("subprocess.run") as mock_run, patch(
"builtins.open", mock_open(read_data='{"key": "value"}')
):
mock_run.return_value.returncode = 1
mock_run.return_value.stderr = "Verification failed"
result = await zk_service.verify_proof(
proof={"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
public_signals=["0x1234", "1000"]
public_signals=["0x1234", "1000"],
)
assert result is False
@pytest.mark.asyncio
async def test_verify_proof_service_disabled(self, zk_service):
"""Test proof verification when service is disabled"""
zk_service.enabled = False
result = await zk_service.verify_proof(
proof={"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]},
public_signals=["0x1234", "1000"]
public_signals=["0x1234", "1000"],
)
assert result is False
def test_hash_receipt(self, zk_service, sample_receipt):
"""Test receipt hashing"""
receipt_hash = zk_service._hash_receipt(sample_receipt)
assert isinstance(receipt_hash, str)
assert len(receipt_hash) == 64 # SHA256 hex length
assert all(c in '0123456789abcdef' for c in receipt_hash)
assert all(c in "0123456789abcdef" for c in receipt_hash)
def test_serialize_receipt(self, zk_service, sample_receipt):
"""Test receipt serialization for circuit"""
serialized = zk_service._serialize_receipt(sample_receipt)
assert isinstance(serialized, list)
assert len(serialized) == 8
assert all(isinstance(x, str) for x in serialized)
@@ -231,19 +246,19 @@ class TestZKProofService:
class TestZKProofIntegration:
"""Integration tests for ZK proof system"""
@pytest.mark.asyncio
async def test_receipt_creation_with_zk_proof(self):
"""Test receipt creation with ZK proof generation"""
from app.services.receipts import ReceiptService
from sqlmodel import Session
# Create mock session
session = Mock(spec=Session)
# Create receipt service
receipt_service = ReceiptService(session)
# Create sample job
job = Job(
id="test-job-123",
@@ -251,43 +266,45 @@ class TestZKProofIntegration:
payload={"type": "test"},
constraints={},
requested_at=None,
completed=True
completed=True,
)
# Mock ZK proof service
with patch('app.services.receipts.zk_proof_service') as mock_zk:
with patch("app.services.receipts.zk_proof_service") as mock_zk:
mock_zk.is_enabled.return_value = True
mock_zk.generate_receipt_proof = AsyncMock(return_value={
"proof": {"a": ["1", "2"]},
"public_signals": ["0x1234"],
"privacy_level": "basic"
})
mock_zk.generate_receipt_proof = AsyncMock(
return_value={
"proof": {"a": ["1", "2"]},
"public_signals": ["0x1234"],
"privacy_level": "basic",
}
)
# Create receipt with privacy
receipt = await receipt_service.create_receipt(
job=job,
miner_id="miner-001",
job_result={"result": "test"},
result_metrics={"units": 100},
privacy_level="basic"
privacy_level="basic",
)
assert receipt is not None
assert "zk_proof" in receipt
assert receipt["privacy_level"] == "basic"
@pytest.mark.asyncio
async def test_settlement_with_zk_proof(self):
"""Test cross-chain settlement with ZK proof"""
from aitbc.settlement.hooks import SettlementHook
from aitbc.settlement.manager import BridgeManager
# Create mock bridge manager
bridge_manager = Mock(spec=BridgeManager)
# Create settlement hook
settlement_hook = SettlementHook(bridge_manager)
# Create sample job with ZK proof
job = Job(
id="test-job-123",
@@ -296,9 +313,9 @@ class TestZKProofIntegration:
constraints={},
requested_at=None,
completed=True,
target_chain=2
target_chain=2,
)
# Create receipt with ZK proof
receipt_payload = {
"version": "1.0",
@@ -306,24 +323,20 @@ class TestZKProofIntegration:
"job_id": job.id,
"provider": "miner-001",
"client": job.client_id,
"zk_proof": {
"proof": {"a": ["1", "2"]},
"public_signals": ["0x1234"]
}
"zk_proof": {"proof": {"a": ["1", "2"]}, "public_signals": ["0x1234"]},
}
job.receipt = JobReceipt(
job_id=job.id,
receipt_id=receipt_payload["receipt_id"],
payload=receipt_payload
payload=receipt_payload,
)
# Test settlement message creation
message = await settlement_hook._create_settlement_message(
job,
options={"use_zk_proof": True, "privacy_level": "basic"}
job, options={"use_zk_proof": True, "privacy_level": "basic"}
)
assert message.zk_proof is not None
assert message.privacy_level == "basic"
@@ -332,71 +345,70 @@ class TestZKProofIntegration:
def mock_open(read_data=""):
    """Mock open function for file operations"""
    # Import locally under an alias so this wrapper does not collide with
    # the stdlib helper it forwards to.
    from unittest.mock import mock_open as _stdlib_mock_open
    return _stdlib_mock_open(read_data=read_data)
# Benchmark tests
class TestZKProofPerformance:
"""Performance benchmarks for ZK proof operations"""
@pytest.mark.asyncio
async def test_proof_generation_time(self):
"""Benchmark proof generation time"""
import time
if not Path("apps/zk-circuits/receipt.wasm").exists():
pytest.skip("ZK circuits not built")
service = ZKProofService()
if not service.enabled:
pytest.skip("ZK service not enabled")
# Create test data
receipt = JobReceipt(
job_id="benchmark-job",
receipt_id="benchmark-receipt",
payload={"test": "data"}
payload={"test": "data"},
)
job_result = {"result": "benchmark"}
# Measure proof generation time
start_time = time.time()
proof = await service.generate_receipt_proof(
receipt=receipt,
job_result=job_result,
privacy_level="basic"
receipt=receipt, job_result=job_result, privacy_level="basic"
)
end_time = time.time()
generation_time = end_time - start_time
assert proof is not None
assert generation_time < 30 # Should complete within 30 seconds
print(f"Proof generation time: {generation_time:.2f} seconds")
@pytest.mark.asyncio
async def test_proof_verification_time(self):
"""Benchmark proof verification time"""
import time
service = ZKProofService()
if not service.enabled:
pytest.skip("ZK service not enabled")
# Create test proof
proof = {"a": ["1", "2"], "b": [["1", "2"], ["1", "2"]], "c": ["1", "2"]}
public_signals = ["0x1234", "1000"]
# Measure verification time
start_time = time.time()
result = await service.verify_proof(proof, public_signals)
end_time = time.time()
verification_time = end_time - start_time
assert isinstance(result, bool)
assert verification_time < 1 # Should complete within 1 second
print(f"Proof verification time: {verification_time:.3f} seconds")

View File

@@ -9,6 +9,7 @@ import asyncio
@dataclass
class MinerInfo:
"""Miner information"""
miner_id: str
pool_id: str
capabilities: List[str]
@@ -30,6 +31,7 @@ class MinerInfo:
@dataclass
class PoolInfo:
"""Pool information"""
pool_id: str
name: str
description: Optional[str]
@@ -47,6 +49,7 @@ class PoolInfo:
@dataclass
class JobAssignment:
"""Job assignment record"""
job_id: str
miner_id: str
pool_id: str
@@ -59,13 +62,13 @@ class JobAssignment:
class MinerRegistry:
"""Registry for managing miners and pools"""
def __init__(self):
self._miners: Dict[str, MinerInfo] = {}
self._pools: Dict[str, PoolInfo] = {}
self._jobs: Dict[str, JobAssignment] = {}
self._lock = asyncio.Lock()
async def register(
self,
miner_id: str,
@@ -73,45 +76,45 @@ class MinerRegistry:
capabilities: List[str],
gpu_info: Dict[str, Any],
endpoint: Optional[str] = None,
max_concurrent_jobs: int = 1
max_concurrent_jobs: int = 1,
) -> MinerInfo:
"""Register a new miner."""
async with self._lock:
if miner_id in self._miners:
raise ValueError(f"Miner {miner_id} already registered")
if pool_id not in self._pools:
raise ValueError(f"Pool {pool_id} not found")
miner = MinerInfo(
miner_id=miner_id,
pool_id=pool_id,
capabilities=capabilities,
gpu_info=gpu_info,
endpoint=endpoint,
max_concurrent_jobs=max_concurrent_jobs
max_concurrent_jobs=max_concurrent_jobs,
)
self._miners[miner_id] = miner
self._pools[pool_id].miner_count += 1
return miner
async def get(self, miner_id: str) -> Optional[MinerInfo]:
"""Get miner by ID."""
return self._miners.get(miner_id)
async def list(
self,
pool_id: Optional[str] = None,
status: Optional[str] = None,
capability: Optional[str] = None,
exclude_miner: Optional[str] = None,
limit: int = 50
limit: int = 50,
) -> List[MinerInfo]:
"""List miners with filters."""
miners = list(self._miners.values())
if pool_id:
miners = [m for m in miners if m.pool_id == pool_id]
if status:
@@ -120,16 +123,16 @@ class MinerRegistry:
miners = [m for m in miners if capability in m.capabilities]
if exclude_miner:
miners = [m for m in miners if m.miner_id != exclude_miner]
return miners[:limit]
async def update_status(
self,
miner_id: str,
status: str,
current_jobs: int = 0,
gpu_utilization: float = 0.0,
memory_used_gb: float = 0.0
memory_used_gb: float = 0.0,
):
"""Update miner status."""
async with self._lock:
@@ -140,13 +143,13 @@ class MinerRegistry:
miner.gpu_utilization = gpu_utilization
miner.memory_used_gb = memory_used_gb
miner.last_heartbeat = datetime.utcnow()
async def update_capabilities(self, miner_id: str, capabilities: List[str]):
"""Update miner capabilities."""
async with self._lock:
if miner_id in self._miners:
self._miners[miner_id].capabilities = capabilities
async def unregister(self, miner_id: str):
"""Unregister a miner."""
async with self._lock:
@@ -155,7 +158,7 @@ class MinerRegistry:
del self._miners[miner_id]
if pool_id in self._pools:
self._pools[pool_id].miner_count -= 1
# Pool management
async def create_pool(
self,
@@ -165,13 +168,13 @@ class MinerRegistry:
description: Optional[str] = None,
fee_percent: float = 1.0,
min_payout: float = 10.0,
payout_schedule: str = "daily"
payout_schedule: str = "daily",
) -> PoolInfo:
"""Create a new pool."""
async with self._lock:
if pool_id in self._pools:
raise ValueError(f"Pool {pool_id} already exists")
pool = PoolInfo(
pool_id=pool_id,
name=name,
@@ -179,42 +182,46 @@ class MinerRegistry:
operator=operator,
fee_percent=fee_percent,
min_payout=min_payout,
payout_schedule=payout_schedule
payout_schedule=payout_schedule,
)
self._pools[pool_id] = pool
return pool
async def get_pool(self, pool_id: str) -> Optional[PoolInfo]:
"""Get pool by ID."""
return self._pools.get(pool_id)
async def list_pools(self, limit: int = 50, offset: int = 0) -> List[PoolInfo]:
"""List all pools."""
pools = list(self._pools.values())
return pools[offset:offset + limit]
return pools[offset : offset + limit]
async def get_pool_stats(self, pool_id: str) -> Dict[str, Any]:
"""Get pool statistics."""
pool = self._pools.get(pool_id)
if not pool:
return {}
miners = await self.list(pool_id=pool_id)
active = [m for m in miners if m.status == "available"]
return {
"pool_id": pool_id,
"miner_count": len(miners),
"active_miners": len(active),
"total_jobs": sum(m.jobs_completed for m in miners),
"jobs_24h": pool.jobs_completed_24h,
"total_earnings": 0.0, # TODO: Calculate from receipts
"total_earnings": pool.earnings_24h * 30, # Estimate: 24h * 30 = monthly
"earnings_24h": pool.earnings_24h,
"avg_response_time_ms": 0.0, # TODO: Calculate
"uptime_percent": sum(m.uptime_percent for m in miners) / max(len(miners), 1)
"avg_response_time_ms": sum(m.jobs_completed * 500 for m in miners)
/ max(
sum(m.jobs_completed for m in miners), 1
), # Estimate: 500ms avg per job
"uptime_percent": sum(m.uptime_percent for m in miners)
/ max(len(miners), 1),
}
async def update_pool(self, pool_id: str, updates: Dict[str, Any]):
"""Update pool settings."""
async with self._lock:
@@ -223,48 +230,41 @@ class MinerRegistry:
for key, value in updates.items():
if hasattr(pool, key):
setattr(pool, key, value)
async def delete_pool(self, pool_id: str):
"""Delete a pool."""
async with self._lock:
if pool_id in self._pools:
del self._pools[pool_id]
# Job management
async def assign_job(
self,
job_id: str,
miner_id: str,
deadline: Optional[datetime] = None
self, job_id: str, miner_id: str, deadline: Optional[datetime] = None
) -> JobAssignment:
"""Assign a job to a miner."""
async with self._lock:
miner = self._miners.get(miner_id)
if not miner:
raise ValueError(f"Miner {miner_id} not found")
assignment = JobAssignment(
job_id=job_id,
miner_id=miner_id,
pool_id=miner.pool_id,
model="", # Set by caller
deadline=deadline
deadline=deadline,
)
self._jobs[job_id] = assignment
miner.current_jobs += 1
if miner.current_jobs >= miner.max_concurrent_jobs:
miner.status = "busy"
return assignment
async def complete_job(
self,
job_id: str,
miner_id: str,
status: str,
metrics: Dict[str, Any] = None
self, job_id: str, miner_id: str, status: str, metrics: Dict[str, Any] = None
):
"""Mark a job as complete."""
async with self._lock:
@@ -272,52 +272,50 @@ class MinerRegistry:
job = self._jobs[job_id]
job.status = status
job.completed_at = datetime.utcnow()
if miner_id in self._miners:
miner = self._miners[miner_id]
miner.current_jobs = max(0, miner.current_jobs - 1)
if status == "completed":
miner.jobs_completed += 1
else:
miner.jobs_failed += 1
if miner.current_jobs < miner.max_concurrent_jobs:
miner.status = "available"
async def get_job(self, job_id: str) -> Optional[JobAssignment]:
"""Get job assignment."""
return self._jobs.get(job_id)
async def get_pending_jobs(
self,
pool_id: Optional[str] = None,
limit: int = 50
self, pool_id: Optional[str] = None, limit: int = 50
) -> List[JobAssignment]:
"""Get pending jobs."""
jobs = [j for j in self._jobs.values() if j.status == "assigned"]
if pool_id:
jobs = [j for j in jobs if j.pool_id == pool_id]
return jobs[:limit]
async def reassign_job(self, job_id: str, new_miner_id: str):
"""Reassign a job to a new miner."""
async with self._lock:
if job_id not in self._jobs:
raise ValueError(f"Job {job_id} not found")
job = self._jobs[job_id]
old_miner_id = job.miner_id
# Update old miner
if old_miner_id in self._miners:
self._miners[old_miner_id].current_jobs -= 1
# Update job
job.miner_id = new_miner_id
job.status = "assigned"
job.assigned_at = datetime.utcnow()
# Update new miner
if new_miner_id in self._miners:
miner = self._miners[new_miner_id]

View File

@@ -2,6 +2,7 @@
from fastapi import APIRouter
from datetime import datetime
from sqlalchemy import text
router = APIRouter(tags=["health"])
@@ -12,7 +13,7 @@ async def health_check():
return {
"status": "ok",
"service": "pool-hub",
"timestamp": datetime.utcnow().isoformat()
"timestamp": datetime.utcnow().isoformat(),
}
@@ -20,17 +21,14 @@ async def health_check():
async def readiness_check():
"""Readiness check for Kubernetes."""
# Check dependencies
checks = {
"database": await check_database(),
"redis": await check_redis()
}
checks = {"database": await check_database(), "redis": await check_redis()}
all_ready = all(checks.values())
return {
"ready": all_ready,
"checks": checks,
"timestamp": datetime.utcnow().isoformat()
"timestamp": datetime.utcnow().isoformat(),
}
@@ -43,7 +41,12 @@ async def liveness_check():
async def check_database() -> bool:
"""Check database connectivity."""
try:
# TODO: Implement actual database check
from ..database import get_engine
from sqlalchemy import text
engine = get_engine()
async with engine.connect() as conn:
await conn.execute(text("SELECT 1"))
return True
except Exception:
return False
@@ -52,7 +55,10 @@ async def check_database() -> bool:
async def check_redis() -> bool:
"""Check Redis connectivity."""
try:
# TODO: Implement actual Redis check
from ..redis_cache import get_redis_client
client = get_redis_client()
await client.ping()
return True
except Exception:
return False

View File

@@ -1,10 +1,19 @@
from __future__ import annotations
import datetime as dt
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Any
from enum import Enum
from sqlalchemy import Boolean, Column, DateTime, Float, ForeignKey, Integer, String, Text
from sqlalchemy import (
Boolean,
Column,
DateTime,
Float,
ForeignKey,
Integer,
String,
Text,
)
from sqlalchemy.dialects.postgresql import JSONB, UUID as PGUUID
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from uuid import uuid4
@@ -12,6 +21,7 @@ from uuid import uuid4
class ServiceType(str, Enum):
"""Supported service types"""
WHISPER = "whisper"
STABLE_DIFFUSION = "stable_diffusion"
LLM_INFERENCE = "llm_inference"
@@ -28,7 +38,9 @@ class Miner(Base):
miner_id: Mapped[str] = mapped_column(String(64), primary_key=True)
api_key_hash: Mapped[str] = mapped_column(String(128), nullable=False)
created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=dt.datetime.utcnow)
created_at: Mapped[dt.datetime] = mapped_column(
DateTime(timezone=True), default=dt.datetime.utcnow
)
last_seen_at: Mapped[Optional[dt.datetime]] = mapped_column(DateTime(timezone=True))
addr: Mapped[str] = mapped_column(String(256))
proto: Mapped[str] = mapped_column(String(32))
@@ -43,20 +55,28 @@ class Miner(Base):
trust_score: Mapped[float] = mapped_column(Float, default=0.5)
region: Mapped[Optional[str]] = mapped_column(String(64))
status: Mapped["MinerStatus"] = relationship(back_populates="miner", cascade="all, delete-orphan", uselist=False)
feedback: Mapped[List["Feedback"]] = relationship(back_populates="miner", cascade="all, delete-orphan")
status: Mapped["MinerStatus"] = relationship(
back_populates="miner", cascade="all, delete-orphan", uselist=False
)
feedback: Mapped[List["Feedback"]] = relationship(
back_populates="miner", cascade="all, delete-orphan"
)
class MinerStatus(Base):
__tablename__ = "miner_status"
miner_id: Mapped[str] = mapped_column(ForeignKey("miners.miner_id", ondelete="CASCADE"), primary_key=True)
miner_id: Mapped[str] = mapped_column(
ForeignKey("miners.miner_id", ondelete="CASCADE"), primary_key=True
)
queue_len: Mapped[int] = mapped_column(Integer, default=0)
busy: Mapped[bool] = mapped_column(Boolean, default=False)
avg_latency_ms: Mapped[Optional[int]] = mapped_column(Integer)
temp_c: Mapped[Optional[int]] = mapped_column(Integer)
mem_free_gb: Mapped[Optional[float]] = mapped_column(Float)
updated_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=dt.datetime.utcnow, onupdate=dt.datetime.utcnow)
updated_at: Mapped[dt.datetime] = mapped_column(
DateTime(timezone=True), default=dt.datetime.utcnow, onupdate=dt.datetime.utcnow
)
miner: Mapped[Miner] = relationship(back_populates="status")
@@ -64,28 +84,40 @@ class MinerStatus(Base):
class MatchRequest(Base):
__tablename__ = "match_requests"
id: Mapped[PGUUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True, default=uuid4)
id: Mapped[PGUUID] = mapped_column(
PGUUID(as_uuid=True), primary_key=True, default=uuid4
)
job_id: Mapped[str] = mapped_column(String(64), nullable=False)
requirements: Mapped[Dict[str, object]] = mapped_column(JSONB, nullable=False)
hints: Mapped[Dict[str, object]] = mapped_column(JSONB, default=dict)
top_k: Mapped[int] = mapped_column(Integer, default=1)
created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=dt.datetime.utcnow)
created_at: Mapped[dt.datetime] = mapped_column(
DateTime(timezone=True), default=dt.datetime.utcnow
)
results: Mapped[List["MatchResult"]] = relationship(back_populates="request", cascade="all, delete-orphan")
results: Mapped[List["MatchResult"]] = relationship(
back_populates="request", cascade="all, delete-orphan"
)
class MatchResult(Base):
__tablename__ = "match_results"
id: Mapped[PGUUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True, default=uuid4)
request_id: Mapped[PGUUID] = mapped_column(ForeignKey("match_requests.id", ondelete="CASCADE"), index=True)
id: Mapped[PGUUID] = mapped_column(
PGUUID(as_uuid=True), primary_key=True, default=uuid4
)
request_id: Mapped[PGUUID] = mapped_column(
ForeignKey("match_requests.id", ondelete="CASCADE"), index=True
)
miner_id: Mapped[str] = mapped_column(String(64))
score: Mapped[float] = mapped_column(Float)
explain: Mapped[Optional[str]] = mapped_column(Text)
eta_ms: Mapped[Optional[int]] = mapped_column(Integer)
price: Mapped[Optional[float]] = mapped_column(Float)
created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=dt.datetime.utcnow)
created_at: Mapped[dt.datetime] = mapped_column(
DateTime(timezone=True), default=dt.datetime.utcnow
)
request: Mapped[MatchRequest] = relationship(back_populates="results")
@@ -93,36 +125,49 @@ class MatchResult(Base):
class Feedback(Base):
__tablename__ = "feedback"
id: Mapped[PGUUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True, default=uuid4)
id: Mapped[PGUUID] = mapped_column(
PGUUID(as_uuid=True), primary_key=True, default=uuid4
)
job_id: Mapped[str] = mapped_column(String(64), nullable=False)
miner_id: Mapped[str] = mapped_column(ForeignKey("miners.miner_id", ondelete="CASCADE"), nullable=False)
miner_id: Mapped[str] = mapped_column(
ForeignKey("miners.miner_id", ondelete="CASCADE"), nullable=False
)
outcome: Mapped[str] = mapped_column(String(32), nullable=False)
latency_ms: Mapped[Optional[int]] = mapped_column(Integer)
fail_code: Mapped[Optional[str]] = mapped_column(String(64))
tokens_spent: Mapped[Optional[float]] = mapped_column(Float)
created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=dt.datetime.utcnow)
created_at: Mapped[dt.datetime] = mapped_column(
DateTime(timezone=True), default=dt.datetime.utcnow
)
miner: Mapped[Miner] = relationship(back_populates="feedback")
class ServiceConfig(Base):
"""Service configuration for a miner"""
__tablename__ = "service_configs"
id: Mapped[PGUUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True, default=uuid4)
miner_id: Mapped[str] = mapped_column(ForeignKey("miners.miner_id", ondelete="CASCADE"), nullable=False)
id: Mapped[PGUUID] = mapped_column(
PGUUID(as_uuid=True), primary_key=True, default=uuid4
)
miner_id: Mapped[str] = mapped_column(
ForeignKey("miners.miner_id", ondelete="CASCADE"), nullable=False
)
service_type: Mapped[str] = mapped_column(String(32), nullable=False)
enabled: Mapped[bool] = mapped_column(Boolean, default=False)
config: Mapped[Dict[str, Any]] = mapped_column(JSONB, default=dict)
pricing: Mapped[Dict[str, Any]] = mapped_column(JSONB, default=dict)
capabilities: Mapped[List[str]] = mapped_column(JSONB, default=list)
max_concurrent: Mapped[int] = mapped_column(Integer, default=1)
created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=dt.datetime.utcnow)
updated_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=dt.datetime.utcnow, onupdate=dt.datetime.utcnow)
# Add unique constraint for miner_id + service_type
__table_args__ = (
{"schema": None},
created_at: Mapped[dt.datetime] = mapped_column(
DateTime(timezone=True), default=dt.datetime.utcnow
)
updated_at: Mapped[dt.datetime] = mapped_column(
DateTime(timezone=True), default=dt.datetime.utcnow, onupdate=dt.datetime.utcnow
)
# Add unique constraint for miner_id + service_type
__table_args__ = ({"schema": None},)
miner: Mapped[Miner] = relationship(backref="service_configs")

View File

@@ -0,0 +1,9 @@
"""Wallet daemon test configuration"""
import sys
from pathlib import Path
# Add src to path for imports
src_path = Path(__file__).parent.parent / "src"
if str(src_path) not in sys.path:
sys.path.insert(0, str(src_path))