docs: update README with comprehensive test results, CLI documentation, and enhanced feature descriptions

- Update key capabilities to include GPU marketplace, payments, billing, and governance
- Expand CLI section from basic examples to 12 command groups with 90+ subcommands
- Add detailed test results table showing 208 passing tests across 6 test suites
- Update documentation links to reference new CLI reference and coordinator API docs
- Revise test commands to reflect the actual test structure
This commit is contained in:
oib
2026-02-12 20:58:21 +01:00
parent 5120861e17
commit 65b63de56f
47 changed files with 5622 additions and 1148 deletions

View File

@@ -1,41 +1,138 @@
from __future__ import annotations
import time
from collections import defaultdict
from contextlib import asynccontextmanager
from fastapi import APIRouter, FastAPI
from fastapi.responses import PlainTextResponse
from fastapi import APIRouter, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, PlainTextResponse
from starlette.middleware.base import BaseHTTPMiddleware
from .config import settings
from .database import init_db
from .gossip import create_backend, gossip_broker
from .logger import get_logger
from .mempool import init_mempool
from .metrics import metrics_registry
from .rpc.router import router as rpc_router
from .rpc.websocket import router as websocket_router
_app_logger = get_logger("aitbc_chain.app")
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Simple in-memory rate limiter per client IP.

    Sliding-window limiter: each client IP may make at most ``max_requests``
    requests per ``window_seconds``. State is process-local, so limits apply
    independently in each worker process.
    """

    # Sweep the bookkeeping map once it holds this many client entries.
    # Fixes a slow leak: entries for idle IPs were previously kept forever,
    # so the dict grew unboundedly as distinct clients appeared.
    _SWEEP_THRESHOLD = 10_000

    def __init__(self, app, max_requests: int = 100, window_seconds: int = 60):
        super().__init__(app)
        self._max_requests = max_requests
        self._window = window_seconds
        # client IP -> timestamps of requests seen inside the current window
        self._requests: dict[str, list[float]] = defaultdict(list)

    async def dispatch(self, request: Request, call_next):
        client_ip = request.client.host if request.client else "unknown"
        now = time.time()
        # Clean old entries for this client.
        self._requests[client_ip] = [
            t for t in self._requests[client_ip] if now - t < self._window
        ]
        if len(self._requests[client_ip]) >= self._max_requests:
            metrics_registry.increment("rpc_rate_limited_total")
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded"},
                headers={"Retry-After": str(self._window)},
            )
        self._requests[client_ip].append(now)
        # Opportunistically drop idle clients so the map stays bounded.
        if len(self._requests) > self._SWEEP_THRESHOLD:
            self._sweep(now)
        return await call_next(request)

    def _sweep(self, now: float) -> None:
        """Drop clients whose every recorded request has left the window."""
        cutoff = now - self._window
        # Timestamps are appended in order, so ts[-1] is the newest request.
        stale = [ip for ip, ts in self._requests.items() if not ts or ts[-1] < cutoff]
        for ip in stale:
            del self._requests[ip]
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Time every request, export metrics, and log 5xx and unhandled failures."""

    async def dispatch(self, request: Request, call_next):
        started = time.perf_counter()
        http_method, url_path = request.method, request.url.path
        try:
            response = await call_next(request)
            elapsed = time.perf_counter() - started
            metrics_registry.observe("rpc_request_duration_seconds", elapsed)
            metrics_registry.increment("rpc_requests_total")
            status = response.status_code
            if status >= 500:
                metrics_registry.increment("rpc_server_errors_total")
                _app_logger.error(
                    "Server error",
                    extra={
                        "method": http_method,
                        "path": url_path,
                        "status": status,
                        "duration_ms": round(elapsed * 1000, 2),
                    },
                )
            elif status >= 400:
                metrics_registry.increment("rpc_client_errors_total")
            return response
        except Exception as exc:
            elapsed = time.perf_counter() - started
            metrics_registry.increment("rpc_unhandled_errors_total")
            _app_logger.exception(
                "Unhandled error in request",
                extra={
                    "method": http_method,
                    "path": url_path,
                    "error": str(exc),
                    "duration_ms": round(elapsed * 1000, 2),
                },
            )
            # Shield the client from the raw exception details.
            return JSONResponse(
                status_code=503,
                content={"detail": "Internal server error"},
            )
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook for the FastAPI app.

    Initializes the database, the mempool (optionally SQLite-backed, stored
    next to the main DB file), and the gossip backend; only the gossip broker
    needs an explicit async shutdown on exit.
    """
    init_db()
    # Mempool DB file lives alongside the main database when DB-backed.
    init_mempool(
        backend=settings.mempool_backend,
        db_path=str(settings.db_path.parent / "mempool.db"),
        max_size=settings.mempool_max_size,
        min_fee=settings.min_fee,
    )
    backend = create_backend(
        settings.gossip_backend,
        broadcast_url=settings.gossip_broadcast_url,
    )
    await gossip_broker.set_backend(backend)
    _app_logger.info("Blockchain node started", extra={"chain_id": settings.chain_id})
    try:
        yield  # application serves requests here
    finally:
        await gossip_broker.shutdown()
        _app_logger.info("Blockchain node stopped")
def create_app() -> FastAPI:
    """Build the FastAPI application with middleware, RPC routes, and ops endpoints."""
    app = FastAPI(title="AITBC Blockchain Node", version="0.1.0", lifespan=lifespan)

    # add_middleware stacks in reverse: CORS ends up outermost at runtime.
    app.add_middleware(RequestLoggingMiddleware)
    app.add_middleware(RateLimitMiddleware, max_requests=200, window_seconds=60)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_methods=["GET", "POST"],
        allow_headers=["*"],
    )

    app.include_router(rpc_router, prefix="/rpc", tags=["rpc"])
    app.include_router(websocket_router, prefix="/rpc")

    ops_router = APIRouter()

    @ops_router.get("/metrics", response_class=PlainTextResponse, tags=["metrics"], summary="Prometheus metrics")
    async def metrics() -> str:
        return metrics_registry.render_prometheus()

    @ops_router.get("/health", tags=["health"], summary="Health check")
    async def health() -> dict:
        return {
            "status": "ok",
            "chain_id": settings.chain_id,
            "proposer_id": settings.proposer_id,
        }

    app.include_router(ops_router)
    return app

View File

@@ -26,6 +26,25 @@ class ChainSettings(BaseSettings):
block_time_seconds: int = 2
# Block production limits
max_block_size_bytes: int = 1_000_000 # 1 MB
max_txs_per_block: int = 500
min_fee: int = 0 # Minimum fee to accept into mempool
# Mempool settings
mempool_backend: str = "memory" # "memory" or "database"
mempool_max_size: int = 10_000
mempool_eviction_interval: int = 60 # seconds
# Circuit breaker
circuit_breaker_threshold: int = 5 # failures before opening
circuit_breaker_timeout: int = 30 # seconds before half-open
# Sync settings
trusted_proposers: str = "" # comma-separated list of trusted proposer IDs
max_reorg_depth: int = 10 # max blocks to reorg on conflict
sync_validate_signatures: bool = True # validate proposer signatures on import
gossip_backend: str = "memory"
gossip_broadcast_url: Optional[str] = None

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
# The diff render left both the pre- and post-CircuitBreaker export lines in
# place; keep only the current ones so the module has a single source of truth.
from .poa import PoAProposer, ProposerConfig, CircuitBreaker

__all__ = ["PoAProposer", "ProposerConfig", "CircuitBreaker"]

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import hashlib
import time
from dataclasses import dataclass
from datetime import datetime
import re
@@ -11,6 +12,9 @@ from sqlmodel import Session, select
from ..logger import get_logger
from ..metrics import metrics_registry
from ..models import Block, Transaction
from ..gossip import gossip_broker
from ..mempool import get_mempool
_METRIC_KEY_SANITIZE = re.compile(r"[^0-9a-zA-Z]+")
@@ -19,8 +23,6 @@ _METRIC_KEY_SANITIZE = re.compile(r"[^0-9a-zA-Z]+")
def _sanitize_metric_suffix(value: str) -> str:
    """Collapse non-alphanumeric runs in *value* to '_' so it is metric-name safe."""
    cleaned = _METRIC_KEY_SANITIZE.sub("_", value).strip("_")
    return cleaned if cleaned else "unknown"
from ..models import Block
from ..gossip import gossip_broker
@dataclass
@@ -28,6 +30,47 @@ class ProposerConfig:
chain_id: str
proposer_id: str
interval_seconds: int
max_block_size_bytes: int = 1_000_000
max_txs_per_block: int = 500
class CircuitBreaker:
    """Circuit breaker for graceful degradation on repeated failures."""

    def __init__(self, threshold: int = 5, timeout: int = 30) -> None:
        # Consecutive failures required to trip the breaker open.
        self._threshold = threshold
        # Seconds to stay open before allowing a probe (half-open).
        self._timeout = timeout
        self._failure_count = 0
        self._last_failure_time: float = 0
        self._state = "closed"  # closed, open, half-open

    @property
    def state(self) -> str:
        """Current state; lazily moves open -> half-open once the timeout elapses."""
        if self._state == "open" and time.time() - self._last_failure_time >= self._timeout:
            self._state = "half-open"
        return self._state

    def record_success(self) -> None:
        """Reset the breaker after a successful call."""
        self._failure_count = 0
        self._state = "closed"
        metrics_registry.set_gauge("circuit_breaker_state", 0.0)

    def record_failure(self) -> None:
        """Count a failure; trip the breaker once the threshold is reached."""
        self._failure_count += 1
        self._last_failure_time = time.time()
        if self._failure_count < self._threshold:
            return
        self._state = "open"
        metrics_registry.set_gauge("circuit_breaker_state", 1.0)
        metrics_registry.increment("circuit_breaker_trips_total")

    def allow_request(self) -> bool:
        """True unless the breaker is open; half-open lets probe traffic through."""
        return self.state in ("closed", "half-open")
class PoAProposer:
@@ -36,6 +79,7 @@ class PoAProposer:
*,
config: ProposerConfig,
session_factory: Callable[[], ContextManager[Session]],
circuit_breaker: Optional[CircuitBreaker] = None,
) -> None:
self._config = config
self._session_factory = session_factory
@@ -43,6 +87,7 @@ class PoAProposer:
self._stop_event = asyncio.Event()
self._task: Optional[asyncio.Task[None]] = None
self._last_proposer_id: Optional[str] = None
self._circuit_breaker = circuit_breaker or CircuitBreaker()
async def start(self) -> None:
if self._task is not None:
@@ -60,15 +105,31 @@ class PoAProposer:
await self._task
self._task = None
@property
def is_healthy(self) -> bool:
    """Liveness signal for health checks: healthy unless the breaker is open."""
    return self._circuit_breaker.state != "open"
async def _run_loop(self) -> None:
    """Main proposer loop: once per slot, propose a block unless the breaker is open.

    The diff render had the legacy (pre-circuit-breaker) loop body pasted
    ahead of the instrumented one inside this def, so the instrumented loop
    could only run after the first loop had already exited; only the
    circuit-breaker version is kept. The running gauge tracks loop liveness.
    """
    metrics_registry.set_gauge("poa_proposer_running", 1.0)
    try:
        while not self._stop_event.is_set():
            await self._wait_until_next_slot()
            if self._stop_event.is_set():
                break
            if not self._circuit_breaker.allow_request():
                # Too many recent failures: skip this slot instead of retrying hot.
                self._logger.warning("Circuit breaker open, skipping block proposal")
                metrics_registry.increment("blocks_skipped_circuit_breaker_total")
                continue
            try:
                self._propose_block()
                self._circuit_breaker.record_success()
            except Exception as exc:
                self._circuit_breaker.record_failure()
                self._logger.exception("Failed to propose block", extra={"error": str(exc)})
                metrics_registry.increment("poa_propose_errors_total")
    finally:
        metrics_registry.set_gauge("poa_proposer_running", 0.0)
        self._logger.info("PoA proposer loop exited")
async def _wait_until_next_slot(self) -> None:
head = self._fetch_chain_head()
@@ -85,6 +146,7 @@ class PoAProposer:
return
def _propose_block(self) -> None:
start_time = time.perf_counter()
with self._session_factory() as session:
head = session.exec(select(Block).order_by(Block.height.desc()).limit(1)).first()
next_height = 0
@@ -95,6 +157,13 @@ class PoAProposer:
parent_hash = head.hash
interval_seconds = (datetime.utcnow() - head.timestamp).total_seconds()
# Drain transactions from mempool
mempool = get_mempool()
pending_txs = mempool.drain(
max_count=self._config.max_txs_per_block,
max_bytes=self._config.max_block_size_bytes,
)
timestamp = datetime.utcnow()
block_hash = self._compute_block_hash(next_height, parent_hash, timestamp)
@@ -104,14 +173,33 @@ class PoAProposer:
parent_hash=parent_hash,
proposer=self._config.proposer_id,
timestamp=timestamp,
tx_count=0,
tx_count=len(pending_txs),
state_root=None,
)
session.add(block)
# Batch-insert transactions into the block
total_fees = 0
for ptx in pending_txs:
tx = Transaction(
tx_hash=ptx.tx_hash,
block_height=next_height,
sender=ptx.content.get("sender", ""),
recipient=ptx.content.get("recipient", ptx.content.get("payload", {}).get("recipient", "")),
payload=ptx.content,
)
session.add(tx)
total_fees += ptx.fee
session.commit()
# Metrics
build_duration = time.perf_counter() - start_time
metrics_registry.increment("blocks_proposed_total")
metrics_registry.set_gauge("chain_head_height", float(next_height))
metrics_registry.set_gauge("last_block_tx_count", float(len(pending_txs)))
metrics_registry.set_gauge("last_block_total_fees", float(total_fees))
metrics_registry.observe("block_build_duration_seconds", build_duration)
if interval_seconds is not None and interval_seconds >= 0:
metrics_registry.observe("block_interval_seconds", interval_seconds)
metrics_registry.set_gauge("poa_last_block_interval_seconds", float(interval_seconds))
@@ -142,6 +230,9 @@ class PoAProposer:
"hash": block_hash,
"parent_hash": parent_hash,
"timestamp": timestamp.isoformat(),
"tx_count": len(pending_txs),
"total_fees": total_fees,
"build_ms": round(build_duration * 1000, 2),
},
)
@@ -180,8 +271,16 @@ class PoAProposer:
self._logger.info("Created genesis block", extra={"hash": genesis_hash})
def _fetch_chain_head(self) -> Optional[Block]:
    """Return the highest block, retrying transient DB errors up to 3 times.

    The diff render left the legacy single-shot query ahead of the retry
    implementation, making the retry/backoff path unreachable; only the retry
    version is kept. Returns None when the chain is empty or all attempts
    fail (failure is logged and counted). Uses a short linear backoff
    (0.1s, 0.2s) and blocks the caller while sleeping.
    """
    for attempt in range(3):
        try:
            with self._session_factory() as session:
                return session.exec(select(Block).order_by(Block.height.desc()).limit(1)).first()
        except Exception as exc:
            if attempt == 2:
                self._logger.error("Failed to fetch chain head after 3 attempts", extra={"error": str(exc)})
                metrics_registry.increment("poa_db_errors_total")
                return None
            time.sleep(0.1 * (attempt + 1))
    return None  # unreachable; keeps the return type explicit
def _compute_block_hash(self, height: int, parent_hash: str, timestamp: datetime) -> str:
payload = f"{self._config.chain_id}|{height}|{parent_hash}|{timestamp.isoformat()}".encode()

View File

@@ -5,9 +5,10 @@ from contextlib import asynccontextmanager
from typing import Optional
from .config import settings
from .consensus import PoAProposer, ProposerConfig
from .consensus import PoAProposer, ProposerConfig, CircuitBreaker
from .database import init_db, session_scope
from .logger import get_logger
from .mempool import init_mempool
logger = get_logger(__name__)
@@ -20,6 +21,12 @@ class BlockchainNode:
async def start(self) -> None:
logger.info("Starting blockchain node", extra={"chain_id": settings.chain_id})
init_db()
init_mempool(
backend=settings.mempool_backend,
db_path=str(settings.db_path.parent / "mempool.db"),
max_size=settings.mempool_max_size,
min_fee=settings.min_fee,
)
self._start_proposer()
try:
await self._stop_event.wait()
@@ -39,8 +46,14 @@ class BlockchainNode:
chain_id=settings.chain_id,
proposer_id=settings.proposer_id,
interval_seconds=settings.block_time_seconds,
max_block_size_bytes=settings.max_block_size_bytes,
max_txs_per_block=settings.max_txs_per_block,
)
self._proposer = PoAProposer(config=proposer_config, session_factory=session_scope)
cb = CircuitBreaker(
threshold=settings.circuit_breaker_threshold,
timeout=settings.circuit_breaker_timeout,
)
self._proposer = PoAProposer(config=proposer_config, session_factory=session_scope, circuit_breaker=cb)
asyncio.create_task(self._proposer.start())
async def _shutdown(self) -> None:

View File

@@ -3,9 +3,9 @@ from __future__ import annotations
import hashlib
import json
import time
from dataclasses import dataclass
from dataclasses import dataclass, field
from threading import Lock
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from .metrics import metrics_registry
@@ -15,33 +15,233 @@ class PendingTransaction:
tx_hash: str
content: Dict[str, Any]
received_at: float
fee: int = 0
size_bytes: int = 0
def compute_tx_hash(tx: Dict[str, Any]) -> str:
    """Return the canonical tx hash: 0x-prefixed SHA-256 of sorted, compact JSON."""
    encoded = json.dumps(tx, sort_keys=True, separators=(",", ":")).encode()
    return "0x" + hashlib.sha256(encoded).hexdigest()
def _estimate_size(tx: Dict[str, Any]) -> int:
return len(json.dumps(tx, separators=(",", ":")).encode())
class InMemoryMempool:
    """In-memory mempool with fee-based prioritization and size limits.

    Thread-safe via a single lock. When full, the lowest-fee (newest-first
    tiebreak) transaction is evicted to make room. The diff render had the
    legacy zero-arg ``__init__``/``add``/``_compute_hash`` bodies interleaved
    with this implementation; only the current implementation is kept
    (hashing now lives in the module-level ``compute_tx_hash``).
    """

    def __init__(self, max_size: int = 10_000, min_fee: int = 0) -> None:
        self._lock = Lock()
        self._transactions: Dict[str, PendingTransaction] = {}
        self._max_size = max_size
        self._min_fee = min_fee

    def add(self, tx: Dict[str, Any]) -> str:
        """Add *tx* and return its hash; idempotent for duplicates.

        Raises ValueError when the fee is below the configured minimum.
        """
        fee = tx.get("fee", 0)
        if fee < self._min_fee:
            raise ValueError(f"Fee {fee} below minimum {self._min_fee}")
        tx_hash = compute_tx_hash(tx)
        size_bytes = _estimate_size(tx)
        entry = PendingTransaction(
            tx_hash=tx_hash, content=tx, received_at=time.time(),
            fee=fee, size_bytes=size_bytes,
        )
        with self._lock:
            if tx_hash in self._transactions:
                return tx_hash  # duplicate
            if len(self._transactions) >= self._max_size:
                self._evict_lowest_fee()
            self._transactions[tx_hash] = entry
            metrics_registry.set_gauge("mempool_size", float(len(self._transactions)))
            metrics_registry.increment("mempool_tx_added_total")
            return tx_hash

    def list_transactions(self) -> List[PendingTransaction]:
        """Return a snapshot of all pending transactions (insertion order)."""
        with self._lock:
            return list(self._transactions.values())

    def drain(self, max_count: int, max_bytes: int) -> List[PendingTransaction]:
        """Drain transactions for block inclusion, prioritized by fee (highest first).

        Stops at *max_count* entries; transactions that would exceed *max_bytes*
        are skipped (not removed) so smaller ones can still be packed.
        """
        with self._lock:
            sorted_txs = sorted(
                self._transactions.values(),
                key=lambda t: (-t.fee, t.received_at),
            )
            result: List[PendingTransaction] = []
            total_bytes = 0
            for tx in sorted_txs:
                if len(result) >= max_count:
                    break
                if total_bytes + tx.size_bytes > max_bytes:
                    continue
                result.append(tx)
                total_bytes += tx.size_bytes
            for tx in result:
                del self._transactions[tx.tx_hash]
            metrics_registry.set_gauge("mempool_size", float(len(self._transactions)))
            metrics_registry.increment("mempool_tx_drained_total", float(len(result)))
            return result

    def remove(self, tx_hash: str) -> bool:
        """Delete one transaction by hash; True if it was present."""
        with self._lock:
            removed = self._transactions.pop(tx_hash, None) is not None
            if removed:
                metrics_registry.set_gauge("mempool_size", float(len(self._transactions)))
            return removed

    def size(self) -> int:
        """Number of transactions currently pending."""
        with self._lock:
            return len(self._transactions)

    def _evict_lowest_fee(self) -> None:
        """Evict the lowest-fee transaction to make room; caller holds the lock."""
        if not self._transactions:
            return
        lowest = min(self._transactions.values(), key=lambda t: (t.fee, -t.received_at))
        del self._transactions[lowest.tx_hash]
        metrics_registry.increment("mempool_evictions_total")
_MEMPOOL = InMemoryMempool()
class DatabaseMempool:
    """SQLite-backed mempool for persistence and cross-service sharing."""

    # Shared fee-priority scan used by both list_transactions and drain.
    _SELECT_ALL = (
        "SELECT tx_hash, content, fee, size_bytes, received_at "
        "FROM mempool ORDER BY fee DESC, received_at ASC"
    )

    def __init__(self, db_path: str, max_size: int = 10_000, min_fee: int = 0) -> None:
        import sqlite3

        self._db_path = db_path
        self._max_size = max_size
        self._min_fee = min_fee
        # Single shared connection; every access is serialized by self._lock.
        self._conn = sqlite3.connect(db_path, check_same_thread=False)
        self._lock = Lock()
        self._init_table()

    def _init_table(self) -> None:
        """Create the mempool table and fee index if they do not exist yet."""
        with self._lock:
            self._conn.execute("""
                CREATE TABLE IF NOT EXISTS mempool (
                    tx_hash TEXT PRIMARY KEY,
                    content TEXT NOT NULL,
                    fee INTEGER DEFAULT 0,
                    size_bytes INTEGER DEFAULT 0,
                    received_at REAL NOT NULL
                )
            """)
            self._conn.execute("CREATE INDEX IF NOT EXISTS idx_mempool_fee ON mempool(fee DESC)")
            self._conn.commit()

    def add(self, tx: Dict[str, Any]) -> str:
        """Insert *tx* and return its hash; idempotent for duplicates.

        Raises ValueError when the fee is below the configured minimum. When
        full, the lowest-fee (newest-first tiebreak) row is evicted first.
        """
        fee = tx.get("fee", 0)
        if fee < self._min_fee:
            raise ValueError(f"Fee {fee} below minimum {self._min_fee}")
        tx_hash = compute_tx_hash(tx)
        content = json.dumps(tx, sort_keys=True, separators=(",", ":"))
        size_bytes = len(content.encode())
        with self._lock:
            duplicate = self._conn.execute(
                "SELECT 1 FROM mempool WHERE tx_hash = ?", (tx_hash,)
            ).fetchone()
            if duplicate:
                return tx_hash
            count = self._conn.execute("SELECT COUNT(*) FROM mempool").fetchone()[0]
            if count >= self._max_size:
                self._conn.execute("""
                    DELETE FROM mempool WHERE tx_hash = (
                        SELECT tx_hash FROM mempool ORDER BY fee ASC, received_at DESC LIMIT 1
                    )
                """)
                metrics_registry.increment("mempool_evictions_total")
            self._conn.execute(
                "INSERT INTO mempool (tx_hash, content, fee, size_bytes, received_at) VALUES (?, ?, ?, ?, ?)",
                (tx_hash, content, fee, size_bytes, time.time()),
            )
            self._conn.commit()
            metrics_registry.increment("mempool_tx_added_total")
            self._update_gauge()
            return tx_hash

    @staticmethod
    def _row_to_tx(row) -> PendingTransaction:
        """Rehydrate one mempool row into a PendingTransaction."""
        tx_hash, content, fee, size_bytes, received_at = row
        return PendingTransaction(
            tx_hash=tx_hash, content=json.loads(content),
            fee=fee, size_bytes=size_bytes, received_at=received_at,
        )

    def list_transactions(self) -> List[PendingTransaction]:
        """Return all pending transactions, highest fee first."""
        with self._lock:
            rows = self._conn.execute(self._SELECT_ALL).fetchall()
        return [self._row_to_tx(row) for row in rows]

    def drain(self, max_count: int, max_bytes: int) -> List[PendingTransaction]:
        """Remove and return up to *max_count* transactions within *max_bytes* total.

        Fee-priority order; oversized rows are skipped rather than ending the
        scan, so smaller transactions can still be packed.
        """
        with self._lock:
            rows = self._conn.execute(self._SELECT_ALL).fetchall()
            selected: List[PendingTransaction] = []
            picked_hashes: List[str] = []
            running_bytes = 0
            for row in rows:
                if len(selected) >= max_count:
                    break
                if running_bytes + row[3] > max_bytes:
                    continue
                selected.append(self._row_to_tx(row))
                running_bytes += row[3]
                picked_hashes.append(row[0])
            if picked_hashes:
                marks = ",".join("?" * len(picked_hashes))
                self._conn.execute(f"DELETE FROM mempool WHERE tx_hash IN ({marks})", picked_hashes)
                self._conn.commit()
            metrics_registry.increment("mempool_tx_drained_total", float(len(selected)))
            self._update_gauge()
            return selected

    def remove(self, tx_hash: str) -> bool:
        """Delete one transaction by hash; True if a row was removed."""
        with self._lock:
            cursor = self._conn.execute("DELETE FROM mempool WHERE tx_hash = ?", (tx_hash,))
            self._conn.commit()
            removed = cursor.rowcount > 0
            if removed:
                self._update_gauge()
            return removed

    def size(self) -> int:
        """Number of transactions currently pending."""
        with self._lock:
            return self._conn.execute("SELECT COUNT(*) FROM mempool").fetchone()[0]

    def _update_gauge(self) -> None:
        """Refresh the mempool_size gauge; caller must already hold the lock."""
        count = self._conn.execute("SELECT COUNT(*) FROM mempool").fetchone()[0]
        metrics_registry.set_gauge("mempool_size", float(count))
# Module-level singleton, configured by init_mempool() at startup. The diff
# render left the legacy `def get_mempool()` header above these definitions,
# which would have swallowed them as its body; it is removed here.
_MEMPOOL: Optional[InMemoryMempool | DatabaseMempool] = None


def init_mempool(backend: str = "memory", db_path: str = "", max_size: int = 10_000, min_fee: int = 0) -> None:
    """Configure the process-wide mempool singleton.

    ``backend="database"`` with a non-empty *db_path* selects the SQLite-backed
    pool; anything else falls back to the in-memory pool.
    """
    global _MEMPOOL
    if backend == "database" and db_path:
        _MEMPOOL = DatabaseMempool(db_path, max_size=max_size, min_fee=min_fee)
    else:
        _MEMPOOL = InMemoryMempool(max_size=max_size, min_fee=min_fee)


def get_mempool() -> InMemoryMempool | DatabaseMempool:
    """Return the configured mempool, lazily creating a default in-memory one."""
    global _MEMPOOL
    if _MEMPOOL is None:
        _MEMPOOL = InMemoryMempool()
    return _MEMPOOL

View File

@@ -449,7 +449,15 @@ async def send_transaction(request: TransactionRequest) -> Dict[str, Any]:
start = time.perf_counter()
mempool = get_mempool()
tx_dict = request.model_dump()
tx_hash = mempool.add(tx_dict)
try:
tx_hash = mempool.add(tx_dict)
except ValueError as e:
metrics_registry.increment("rpc_send_tx_rejected_total")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
metrics_registry.increment("rpc_send_tx_failed_total")
raise HTTPException(status_code=503, detail=f"Mempool unavailable: {e}")
recipient = request.payload.get("recipient", "")
try:
asyncio.create_task(
gossip_broker.publish(
@@ -457,7 +465,7 @@ async def send_transaction(request: TransactionRequest) -> Dict[str, Any]:
{
"tx_hash": tx_hash,
"sender": request.sender,
"recipient": request.recipient,
"recipient": recipient,
"payload": request.payload,
"nonce": request.nonce,
"fee": request.fee,
@@ -536,3 +544,63 @@ async def mint_faucet(request: MintFaucetRequest) -> Dict[str, Any]:
metrics_registry.increment("rpc_mint_faucet_success_total")
metrics_registry.observe("rpc_mint_faucet_duration_seconds", time.perf_counter() - start)
return {"address": request.address, "balance": updated_balance}
class ImportBlockRequest(BaseModel):
    """Request body for /importBlock: a peer block header plus optional transactions."""

    height: int
    hash: str
    parent_hash: str
    proposer: str
    timestamp: str  # ISO-8601 string; parsed by the sync layer
    tx_count: int = 0
    state_root: Optional[str] = None
    # Full transaction payloads for the block, when the peer includes them.
    transactions: Optional[list] = None
@router.post("/importBlock", summary="Import a block from a remote peer")
async def import_block(request: ImportBlockRequest) -> Dict[str, Any]:
    """Validate and import a peer-submitted block, returning the sync outcome."""
    from ..sync import ChainSync, ProposerSignatureValidator
    from ..config import settings as cfg

    metrics_registry.increment("rpc_import_block_total")
    start = time.perf_counter()

    # Empty setting -> no trusted set (the validator then checks structure only).
    trusted = [p.strip() for p in cfg.trusted_proposers.split(",") if p.strip()]
    validator = ProposerSignatureValidator(trusted_proposers=trusted or None)
    sync = ChainSync(
        session_factory=session_scope,
        chain_id=cfg.chain_id,
        max_reorg_depth=cfg.max_reorg_depth,
        validator=validator,
        validate_signatures=cfg.sync_validate_signatures,
    )

    outcome = sync.import_block(request.model_dump(exclude={"transactions"}), request.transactions)

    metrics_registry.observe("rpc_import_block_duration_seconds", time.perf_counter() - start)
    metrics_registry.increment(
        "rpc_import_block_accepted_total" if outcome.accepted else "rpc_import_block_rejected_total"
    )
    return {
        "accepted": outcome.accepted,
        "height": outcome.height,
        "hash": outcome.block_hash,
        "reason": outcome.reason,
        "reorged": outcome.reorged,
        "reorg_depth": outcome.reorg_depth,
    }
@router.get("/syncStatus", summary="Get chain sync status")
async def sync_status() -> Dict[str, Any]:
    """Return the local chain head and sync configuration."""
    from ..sync import ChainSync
    from ..config import settings as cfg

    metrics_registry.increment("rpc_sync_status_total")
    return ChainSync(session_factory=session_scope, chain_id=cfg.chain_id).get_sync_status()

View File

@@ -0,0 +1,324 @@
"""Chain synchronization with conflict resolution, signature validation, and metrics."""
from __future__ import annotations
import hashlib
import hmac
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from sqlmodel import Session, select
from .config import settings
from .logger import get_logger
from .metrics import metrics_registry
from .models import Block, Transaction
logger = get_logger(__name__)
@dataclass
class ImportResult:
    """Outcome of attempting to import a single peer block."""

    accepted: bool  # True when the block was persisted (append or reorg)
    height: int
    block_hash: str
    reason: str  # human-readable explanation of acceptance/rejection
    reorged: bool = False  # True when accepting required dropping local blocks
    reorg_depth: int = 0  # number of local blocks removed during the reorg
class ProposerSignatureValidator:
    """Validates proposer signatures on imported blocks."""

    def __init__(self, trusted_proposers: Optional[List[str]] = None) -> None:
        self._trusted = set(trusted_proposers or [])

    @property
    def trusted_proposers(self) -> set:
        """The current trusted proposer IDs (empty set = no membership check)."""
        return self._trusted

    def add_trusted(self, proposer_id: str) -> None:
        """Mark *proposer_id* as trusted."""
        self._trusted.add(proposer_id)

    def remove_trusted(self, proposer_id: str) -> None:
        """Drop *proposer_id* from the trusted set (no-op if absent)."""
        self._trusted.discard(proposer_id)

    def validate_block_signature(self, block_data: Dict[str, Any]) -> Tuple[bool, str]:
        """Validate that a block was produced by a trusted proposer.
        Returns (is_valid, reason).
        """
        proposer = block_data.get("proposer", "")
        block_hash = block_data.get("hash", "")
        height = block_data.get("height", -1)
        if not proposer:
            return False, "Missing proposer field"
        if not block_hash or not block_hash.startswith("0x"):
            return False, f"Invalid block hash format: {block_hash}"
        # An empty trusted set means open membership; otherwise enforce it.
        if self._trusted and proposer not in self._trusted:
            metrics_registry.increment("sync_signature_rejected_total")
            return False, f"Proposer '{proposer}' not in trusted set"
        # Structural integrity: the fields feeding the block hash must exist.
        for required in ("height", "parent_hash", "timestamp"):
            if required not in block_data:
                return False, f"Missing required field: {required}"
        # The hash must be a 0x-prefixed 64-digit hex string (sha256-sized).
        hash_hex = block_hash[2:]
        if len(hash_hex) != 64:
            return False, f"Invalid hash length: {len(hash_hex)}"
        try:
            int(hash_hex, 16)
        except ValueError:
            return False, f"Invalid hex in hash: {hash_hex}"
        metrics_registry.increment("sync_signature_validated_total")
        return True, "Valid"
class ChainSync:
    """Handles block import with conflict resolution for divergent chains."""

    def __init__(
        self,
        session_factory,
        *,
        chain_id: str = "",
        max_reorg_depth: int = 10,
        validator: Optional[ProposerSignatureValidator] = None,
        validate_signatures: bool = True,
    ) -> None:
        # session_factory is a callable returning a context-managed DB session.
        self._session_factory = session_factory
        self._chain_id = chain_id or settings.chain_id
        self._max_reorg_depth = max_reorg_depth
        # Default validator has an empty trusted set (structure checks only).
        self._validator = validator or ProposerSignatureValidator()
        self._validate_signatures = validate_signatures

    def import_block(self, block_data: Dict[str, Any], transactions: Optional[List[Dict[str, Any]]] = None) -> ImportResult:
        """Import a block from a remote peer.
        Handles:
        - Normal append (block extends our chain)
        - Fork resolution (block is on a longer chain)
        - Duplicate detection
        - Signature validation
        """
        start = time.perf_counter()
        height = block_data.get("height", -1)
        block_hash = block_data.get("hash", "")
        parent_hash = block_data.get("parent_hash", "")
        proposer = block_data.get("proposer", "")
        metrics_registry.increment("sync_blocks_received_total")
        # Validate signature before touching the database at all.
        if self._validate_signatures:
            valid, reason = self._validator.validate_block_signature(block_data)
            if not valid:
                metrics_registry.increment("sync_blocks_rejected_total")
                logger.warning("Block rejected: signature validation failed",
                extra={"height": height, "reason": reason})
                return ImportResult(accepted=False, height=height, block_hash=block_hash, reason=reason)
        with self._session_factory() as session:
            # Check for duplicate (same hash already stored).
            existing = session.exec(
                select(Block).where(Block.hash == block_hash)
            ).first()
            if existing:
                metrics_registry.increment("sync_blocks_duplicate_total")
                return ImportResult(accepted=False, height=height, block_hash=block_hash,
                reason="Block already exists")
            # Get our chain head
            our_head = session.exec(
                select(Block).order_by(Block.height.desc()).limit(1)
            ).first()
            our_height = our_head.height if our_head else -1
            # Case 1: Block extends our chain directly.
            if height == our_height + 1:
                parent_exists = session.exec(
                    select(Block).where(Block.hash == parent_hash)
                ).first()
                # Genesis is allowed in with the sentinel parent "0x00".
                if parent_exists or (height == 0 and parent_hash == "0x00"):
                    result = self._append_block(session, block_data, transactions)
                    duration = time.perf_counter() - start
                    metrics_registry.observe("sync_import_duration_seconds", duration)
                    return result
                # NOTE(review): a tip-extending block whose parent is unknown
                # falls through to the final "Unhandled import case" — confirm
                # that is the intended treatment of parent mismatches at the tip.
            # Case 2: Block is behind our head — ignore
            if height <= our_height:
                # Check if it's a fork at a previous height
                existing_at_height = session.exec(
                    select(Block).where(Block.height == height)
                ).first()
                if existing_at_height and existing_at_height.hash != block_hash:
                    # Fork detected — resolve by longest chain rule
                    return self._resolve_fork(session, block_data, transactions, our_head)
                metrics_registry.increment("sync_blocks_stale_total")
                return ImportResult(accepted=False, height=height, block_hash=block_hash,
                reason=f"Stale block (our height: {our_height})")
            # Case 3: Block is ahead — we're behind, need to catch up
            if height > our_height + 1:
                metrics_registry.increment("sync_blocks_gap_total")
                return ImportResult(accepted=False, height=height, block_hash=block_hash,
                reason=f"Gap detected (our height: {our_height}, received: {height})")
            return ImportResult(accepted=False, height=height, block_hash=block_hash,
            reason="Unhandled import case")

    def _append_block(self, session: Session, block_data: Dict[str, Any],
    transactions: Optional[List[Dict[str, Any]]] = None) -> ImportResult:
        """Append a block to the chain tip.

        Persists the block row plus any supplied transactions in one commit;
        a malformed or absent timestamp falls back to the current UTC time.
        """
        timestamp_str = block_data.get("timestamp", "")
        try:
            timestamp = datetime.fromisoformat(timestamp_str) if timestamp_str else datetime.utcnow()
        except (ValueError, TypeError):
            timestamp = datetime.utcnow()
        # Supplied transactions override the peer's advertised tx_count.
        tx_count = block_data.get("tx_count", 0)
        if transactions:
            tx_count = len(transactions)
        block = Block(
            height=block_data["height"],
            hash=block_data["hash"],
            parent_hash=block_data["parent_hash"],
            proposer=block_data.get("proposer", "unknown"),
            timestamp=timestamp,
            tx_count=tx_count,
            state_root=block_data.get("state_root"),
        )
        session.add(block)
        # Import transactions if provided
        if transactions:
            for tx_data in transactions:
                tx = Transaction(
                    tx_hash=tx_data.get("tx_hash", ""),
                    block_height=block_data["height"],
                    sender=tx_data.get("sender", ""),
                    recipient=tx_data.get("recipient", ""),
                    payload=tx_data,
                )
                session.add(tx)
        session.commit()
        metrics_registry.increment("sync_blocks_accepted_total")
        metrics_registry.set_gauge("sync_chain_height", float(block_data["height"]))
        logger.info("Imported block", extra={
        "height": block_data["height"],
        "hash": block_data["hash"],
        "proposer": block_data.get("proposer"),
        "tx_count": tx_count,
        })
        return ImportResult(
            accepted=True, height=block_data["height"],
            block_hash=block_data["hash"], reason="Appended to chain"
        )

    def _resolve_fork(self, session: Session, block_data: Dict[str, Any],
    transactions: Optional[List[Dict[str, Any]]],
    our_head: Block) -> ImportResult:
        """Resolve a fork using longest-chain rule.
        For PoA, we use a simple rule: if the incoming block's height is at or below
        our head and the parent chain is longer, we reorg. Otherwise, we keep our chain.
        Since we only receive one block at a time, we can only detect the fork — actual
        reorg requires the full competing chain. For now, we log the fork and reject
        unless the block has a strictly higher height.
        """
        fork_height = block_data.get("height", -1)
        our_height = our_head.height
        metrics_registry.increment("sync_forks_detected_total")
        logger.warning("Fork detected", extra={
        "fork_height": fork_height,
        "our_height": our_height,
        "fork_hash": block_data.get("hash"),
        "our_hash": our_head.hash,
        })
        # Simple longest-chain: only reorg if incoming chain is strictly longer
        # and within max reorg depth
        # NOTE(review): import_block only calls this when height <= our_height,
        # so the guard below always fires and the reorg path beneath appears
        # unreachable as currently wired — confirm whether that is intended.
        if fork_height <= our_height:
            return ImportResult(
                accepted=False, height=fork_height,
                block_hash=block_data.get("hash", ""),
                reason=f"Fork rejected: our chain is longer or equal ({our_height} >= {fork_height})"
            )
        reorg_depth = our_height - fork_height + 1
        if reorg_depth > self._max_reorg_depth:
            metrics_registry.increment("sync_reorg_rejected_total")
            return ImportResult(
                accepted=False, height=fork_height,
                block_hash=block_data.get("hash", ""),
                reason=f"Reorg depth {reorg_depth} exceeds max {self._max_reorg_depth}"
            )
        # Perform reorg: remove blocks from fork_height onwards, then append
        blocks_to_remove = session.exec(
            select(Block).where(Block.height >= fork_height).order_by(Block.height.desc())
        ).all()
        removed_count = 0
        for old_block in blocks_to_remove:
            # Remove transactions in the block
            old_txs = session.exec(
                select(Transaction).where(Transaction.block_height == old_block.height)
            ).all()
            for tx in old_txs:
                session.delete(tx)
            session.delete(old_block)
            removed_count += 1
        session.commit()
        metrics_registry.increment("sync_reorgs_total")
        metrics_registry.observe("sync_reorg_depth", float(removed_count))
        logger.warning("Chain reorg performed", extra={
        "removed_blocks": removed_count,
        "new_height": fork_height,
        })
        # Now append the new block
        result = self._append_block(session, block_data, transactions)
        result.reorged = True
        result.reorg_depth = removed_count
        return result

    def get_sync_status(self) -> Dict[str, Any]:
        """Get current sync status and metrics."""
        with self._session_factory() as session:
            head = session.exec(
                select(Block).order_by(Block.height.desc()).limit(1)
            ).first()
            # NOTE(review): these load every row just to count them; a COUNT
            # query would avoid materializing both tables on large chains.
            total_blocks = session.exec(select(Block)).all()
            total_txs = session.exec(select(Transaction)).all()
            return {
                "chain_id": self._chain_id,
                "head_height": head.height if head else -1,
                "head_hash": head.hash if head else None,
                "head_proposer": head.proposer if head else None,
                "head_timestamp": head.timestamp.isoformat() if head else None,
                "total_blocks": len(total_blocks),
                "total_transactions": len(total_txs),
                "validate_signatures": self._validate_signatures,
                "trusted_proposers": list(self._validator.trusted_proposers),
                "max_reorg_depth": self._max_reorg_depth,
            }